markov: Finish up chain database implementation

Markov now uses a sqlite3 database instead of flat JSON files. This
should significantly speed up saving and reduce the amount of RAM that
the plugin uses. Saving and loading large JSON files was very slow and
caused issues with other plugins, especially when messages were
received. Additionally, in order to save RAM, the old implementation
kept chains in a cache that was periodically flushed when unused, which
complicated the implementation. All of this has been removed, since
changes are now committed on the fly with the database implementation.
The main trade-off is the disk space used by the database. This is
acceptable, because disk space is cheap while RAM is not.

Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
This commit is contained in:
2022-06-23 12:29:19 -07:00
parent 086ba7706e
commit 8c4bb5ac60
2 changed files with 55 additions and 230 deletions

View File

@@ -1,13 +1,10 @@
import asyncio
from collections import defaultdict from collections import defaultdict
import dataclasses
import json
import logging import logging
import math import math
from pathlib import Path from pathlib import Path
import random import random
import sqlite3 import sqlite3
from typing import Any, List, Mapping, Sequence from typing import Any, List, Sequence
from asyncirc.protocol import IrcProtocol from asyncirc.protocol import IrcProtocol
from irclib.parser import Prefix from irclib.parser import Prefix
@@ -35,12 +32,20 @@ def windows(items: Sequence[Any], size: int):
yield items[i : i + size] yield items[i : i + size]
class DbChain: class Chain:
def __init__(self, order: int, path: Path): def __init__(self, order: int, reply_chance: float, path: Path, sql_path: Path):
self.order = order self.order = order
self.reply_chance = reply_chance
self.path = path self.path = path
self.db = sqlite3.connect(self.path) self.db = sqlite3.connect(self.path)
# Run the initial database creation script
cursor = self.db.cursor()
with open(sql_path) as fp:
cursor.executescript(fp.read())
cursor.close()
self.db.commit()
def commit(self): def commit(self):
self.db.commit() self.db.commit()
@@ -57,10 +62,29 @@ class DbChain:
else: else:
return None return None
def get_reply_chance(self, channel: str, nick: str) -> float:
if result := self.execute(
"SELECT reply_chance FROM user WHERE channel = ? AND nick = ?",
(channel, nick),
):
return result[0][0]
else:
return 0.0
def set_reply_chance(self, channel: str, nick: str, chance: float):
self.ensure_user(channel, nick)
self.execute(
"UPDATE user SET reply_chance = ? WHERE channel = ? AND nick = ?",
(chance, channel, nick),
)
def ensure_user(self, channel: str, nick: str): def ensure_user(self, channel: str, nick: str):
if self.get_user_id(channel, nick): if self.get_user_id(channel, nick):
return return
self.execute("INSERT INTO user (channel, nick) VALUES (?, ?)", (channel, nick)) self.execute(
"INSERT INTO user (channel, nick, reply_chance) VALUES (?, ?, ?)",
(channel, nick, self.reply_chance),
)
def ensure_key(self, channel: str, nick: str, key: str, next: str): def ensure_key(self, channel: str, nick: str, key: str, next: str):
assert next is not None assert next is not None
@@ -168,172 +192,22 @@ class DbChain:
return random.choices(words, weights)[0] return random.choices(words, weights)[0]
@dataclasses.dataclass
class Chain:
def __init__(self, order: int, chance: float, path: Path):
self.order = order
self.__reply_chance = chance
self.path = path
self.__cache = chain_default()
self.__last_access = 0.0
self.__dirty = False
def __touch(self):
self.__last_access = asyncio.get_running_loop().time()
@property
def reply_chance(self) -> float:
self.__load()
return self.__reply_chance
@reply_chance.setter
def reply_chance(self, val: float):
if not (isinstance(val, float) or isinstance(val, int)):
return NotImplemented
self.__load()
self.__reply_chance = val
self.__dirty = True
@property
def last_access(self) -> float:
return self.__last_access
def add(self, text: str):
parts: List[Any] = text.strip().split()
if not parts:
return
self.__touch()
self.__load()
self.__dirty = True
for fragment in windows(parts + [None], self.order + 1):
head = fragment[0:-1]
tail = fragment[-1]
self.__cache[" ".join(head)][tail] += 1
def get(self, key: str) -> dict[str | None, int]:
if self.__cache:
if key in self.__cache:
self.__touch()
return self.__cache[key]
else:
raise KeyError(key)
else:
# Load cache, then return key
self.__load()
self.__touch()
return self.__cache[key]
def set(self, key: str, value: Mapping[str | None, int]):
self.__touch()
if not self.__cache:
# Attempt the cache before writing to it
self.__load()
self.__cache[key] = defaultdict(int, value)
self.__dirty = True
def __load(self):
self.__touch()
if self.__cache:
return
if not self.path.exists():
return
with open(self.path) as fp:
log.info("Loading markov chain from %s", self.path)
obj = json.load(fp)
# Load the save object
self.__reply_chance = obj["reply_chance"]
self.__cache = defaultdict(
chain_inner_default,
{
key: defaultdict(
int,
{
(None if not word else word): weight
for word, weight in value.items()
},
)
for key, value in obj["chain"].items()
},
)
self.__dirty = False
def save(self):
if not self.__cache:
return
if self.__dirty:
log.info("Saving markov chain to %s", self.path)
self.path.parent.mkdir(parents=True, exist_ok=True)
# Build the save object
obj = {
"reply_chance": self.__reply_chance,
"chain": {
key: {
("" if word is None else word): weight
for word, weight in value.items()
}
for key, value in self.__cache.items()
},
}
with open(self.path, "w") as fp:
json.dump(obj, fp)
self.__dirty = False
else:
log.info("Chain %s is not dirty, not saving", self.path)
def clear_cache(self):
self.__cache.clear()
self.__dirty = False
def __bool__(self) -> bool:
return self.path.exists() or bool(self.__cache)
def generate(self) -> str:
self.__load()
if not self.__cache:
return ""
words: List[str] = []
node = random.choice(list(self.__cache.keys())).split(" ")
words += node
next: str | None = self.choose_next(" ".join(node))
while next:
words += [next]
node = [*node[1:], next]
next = self.choose_next(" ".join(node))
return " ".join(words)
def choose_next(self, head: str) -> str | None:
self.__load()
choices = self.__cache[head]
words = list(choices.keys())
weights = list(choices.values())
if not words:
return None
return random.choices(words, weights)[0]
class Markov(Plugin): class Markov(Plugin):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(Markov, self).__init__(*args, **kwargs) super(Markov, self).__init__(*args, **kwargs)
self.order = int(self.plugin_config.get("order", 1)) self.order = int(self.plugin_config.get("order", 1))
self.data_path = Path(self.plugin_config.get("data_path", "data/markov")) self.data_path = Path(
self.save_every = int(self.plugin_config.get("save_every", 1800)) self.plugin_config.get("data_path", "data/markov/markov.db")
)
self.sql_path = Path(self.plugin_config.get("sql_path", "data/markov/db.sql"))
self.reply_chance = float(self.plugin_config.get("reply_chance", 0.01)) self.reply_chance = float(self.plugin_config.get("reply_chance", 0.01))
self.__chains = {} self.chain = Chain(self.order, self.reply_chance, self.data_path, self.sql_path)
self.__save_loop_task = None
self.__saving = asyncio.Lock()
async def on_load(self):
loop = asyncio.get_running_loop()
self.__save_loop_task = loop.create_task(self.__save_loop())
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str): async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
line = line.strip() line = line.strip()
parts = line.split() parts = line.split()
if not parts: if not parts or who.nick == self.nick:
return return
elif parts[0] == "!markov": elif parts[0] == "!markov":
self.handle_command(conn, channel, who, parts) self.handle_command(conn, channel, who, parts)
@@ -342,27 +216,17 @@ class Markov(Plugin):
self.add(channel, who.nick, line) self.add(channel, who.nick, line)
# also, maybe generate a sentence # also, maybe generate a sentence
chosen = random.random() chosen = random.random()
chain = self.get_chain(channel, who.nick) reply_chance = self.chain.get_reply_chance(channel, who.nick)
if chosen <= chain.reply_chance: if chosen <= reply_chance:
message = chain.generate() message = self.chain.generate(channel, who.nick)
if message: if message:
self.send_to(conn, channel, f"{who.nick}: {message}") self.send_to(conn, channel, f"{who.nick}: {message}")
def get_chain(self, channel: str, who: str) -> Chain: def add(self, channel: str, who: str, line: str, commit=True):
if channel not in self.__chains:
self.__chains[channel] = {}
if who not in self.__chains[channel]:
path = self.data_path / channel / who
self.__chains[channel][who] = Chain(self.order, self.reply_chance, path)
return self.__chains[channel][who]
def add(self, channel: str, who: str, line: str):
if who == self.server_config.nick: if who == self.server_config.nick:
return return
chain = self.get_chain(channel, who) self.chain.add(channel, who, line, commit)
chain.add(line) self.chain.add(channel, ALLCHAIN, line, commit)
allchain = self.get_chain(channel, ALLCHAIN)
allchain.add(line)
def handle_command( def handle_command(
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str] self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
@@ -370,87 +234,48 @@ class Markov(Plugin):
# handle markov commands # handle markov commands
match parts[1:]: match parts[1:]:
case ["force"]: case ["force"]:
chain = self.get_chain(channel, who.nick) message = self.chain.generate(channel, who.nick)
message = chain.generate()
if message: if message:
self.send_to(conn, channel, f"{who.nick}: {message}") self.send_to(conn, channel, f"{who.nick}: {message}")
case ["force", nick] | ["emulate", nick]: case ["force", nick] | ["emulate", nick]:
chain = self.get_chain(channel, nick) if not self.chain.get_user_id(channel, nick):
if not chain:
return return
message = chain.generate() message = self.chain.generate(channel, nick)
if message: if message:
self.send_to(conn, channel, f"{who.nick}: {message}") self.send_to(conn, channel, f"{who.nick}: {message}")
case ["all"]: case ["all"]:
chain = self.get_chain(channel, ALLCHAIN) message = self.chain.generate(channel, ALLCHAIN)
message = chain.generate()
if message: if message:
self.send_to(conn, channel, f"{who.nick}: {message}") self.send_to(conn, channel, f"{who.nick}: {message}")
case ["chance"]: case ["chance"]:
chain = self.get_chain(channel, who.nick) chance = self.chain.get_reply_chance(channel, who.nick)
self.send_to( self.send_to(
conn, conn,
channel, channel,
f"{who.nick}: current reply chance is {chain.reply_chance}", f"{who.nick}: current reply chance is {chance}",
) )
case ["chance", chance]: case ["chance", chance]:
chain = self.get_chain(channel, who.nick)
try: try:
reply_chance = float(chance) reply_chance = float(chance)
except ValueError: except ValueError:
log.error("Couldn't parse %r as a float", chance) log.error("Couldn't parse %r as a float", chance)
return return
if not math.isnan(reply_chance): if not math.isnan(reply_chance):
chain.reply_chance = min( reply_chance = min(max(float(reply_chance), 0.0), self.reply_chance)
max(float(reply_chance), 0.0), self.reply_chance self.chain.set_reply_chance(channel, who.nick, reply_chance)
)
self.send_to( self.send_to(
conn, conn,
channel, channel,
f"{who.nick}: reply chance set to {chain.reply_chance}", f"{who.nick}: reply chance set to {self.chain.get_reply_chance(channel, who.nick)}",
) )
case _: case _:
# command not recognized # command not recognized
pass pass
async def __save_loop(self): async def save(self):
while True: self.chain.commit()
log.debug("Saving markov chains in %s seconds", self.save_every)
await asyncio.sleep(self.save_every)
retain_after = asyncio.get_running_loop().time() - self.save_every
await self.save(retain_after=retain_after)
async def save(self, retain_after: float | None = None):
async with self.__saving:
from concurrent.futures import ProcessPoolExecutor
log.info("Saving markov chains")
coros = []
loop = asyncio.get_running_loop()
# ProcessPoolExecutor is an explicit decision I've made to use,
# because it allows us to save in a different process, with
# different memory, and simultaneously clear it if it needs to be
# cleared.
with ProcessPoolExecutor() as pool:
for chains in self.__chains.values():
for chain in chains.values():
# Start the save in a new process, in a new task.
log.debug("Starting process to save %s", chain.path)
coro = loop.run_in_executor(pool, chain.save)
coros += [coro]
# Prune
retain = True
if retain_after is not None:
retain = chain.last_access > retain_after
if not retain:
log.info("Pruning markov chain %s from memory", chain.path)
chain.clear_cache()
if coros:
await asyncio.gather(*coros)
log.info("Done")
async def on_unload(self, conn: IrcProtocol): async def on_unload(self, conn: IrcProtocol):
self.__save_loop_task.cancel()
await self.save() await self.save()

View File

@@ -40,7 +40,7 @@ async def main():
name = mat["name"] name = mat["name"]
message = mat["message"].strip() message = mat["message"].strip()
if name != server_config.nick and message and message[0] != "!": if name != server_config.nick and message and message[0] != "!":
plugin.add(channel, name, message) plugin.add(channel, name, message, commit=False)
await plugin.save() await plugin.save()