markov: Finish up chain database implementation
Markov now uses a sqlite3 database instead of flat JSON files. This should significantly speed up saving time, plus reduce the amount of RAM that it uses. Saving and loading large JSON files was very slow and caused issues with other plugins, especially when messages were received. Additionally, in order to save RAM, a cache was used and periodically flushed when not used, adding some complications to the implementation. This has all been removed since things get committed on the fly with the database implementation. The main trade-off we have to make is the disk space used by the database. This is OK though, because disk space is cheap while RAM is not. Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
This commit is contained in:
@@ -1,13 +1,10 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
import random
|
||||
import sqlite3
|
||||
from typing import Any, List, Mapping, Sequence
|
||||
from typing import Any, List, Sequence
|
||||
|
||||
from asyncirc.protocol import IrcProtocol
|
||||
from irclib.parser import Prefix
|
||||
@@ -35,12 +32,20 @@ def windows(items: Sequence[Any], size: int):
|
||||
yield items[i : i + size]
|
||||
|
||||
|
||||
class DbChain:
|
||||
def __init__(self, order: int, path: Path):
|
||||
class Chain:
|
||||
def __init__(self, order: int, reply_chance: float, path: Path, sql_path: Path):
|
||||
self.order = order
|
||||
self.reply_chance = reply_chance
|
||||
self.path = path
|
||||
self.db = sqlite3.connect(self.path)
|
||||
|
||||
# Run the initial database creation script
|
||||
cursor = self.db.cursor()
|
||||
with open(sql_path) as fp:
|
||||
cursor.executescript(fp.read())
|
||||
cursor.close()
|
||||
self.db.commit()
|
||||
|
||||
def commit(self):
|
||||
self.db.commit()
|
||||
|
||||
@@ -57,10 +62,29 @@ class DbChain:
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_reply_chance(self, channel: str, nick: str) -> float:
|
||||
if result := self.execute(
|
||||
"SELECT reply_chance FROM user WHERE channel = ? AND nick = ?",
|
||||
(channel, nick),
|
||||
):
|
||||
return result[0][0]
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def set_reply_chance(self, channel: str, nick: str, chance: float):
|
||||
self.ensure_user(channel, nick)
|
||||
self.execute(
|
||||
"UPDATE user SET reply_chance = ? WHERE channel = ? AND nick = ?",
|
||||
(chance, channel, nick),
|
||||
)
|
||||
|
||||
def ensure_user(self, channel: str, nick: str):
|
||||
if self.get_user_id(channel, nick):
|
||||
return
|
||||
self.execute("INSERT INTO user (channel, nick) VALUES (?, ?)", (channel, nick))
|
||||
self.execute(
|
||||
"INSERT INTO user (channel, nick, reply_chance) VALUES (?, ?, ?)",
|
||||
(channel, nick, self.reply_chance),
|
||||
)
|
||||
|
||||
def ensure_key(self, channel: str, nick: str, key: str, next: str):
|
||||
assert next is not None
|
||||
@@ -168,172 +192,22 @@ class DbChain:
|
||||
return random.choices(words, weights)[0]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Chain:
|
||||
def __init__(self, order: int, chance: float, path: Path):
|
||||
self.order = order
|
||||
self.__reply_chance = chance
|
||||
self.path = path
|
||||
self.__cache = chain_default()
|
||||
self.__last_access = 0.0
|
||||
self.__dirty = False
|
||||
|
||||
def __touch(self):
|
||||
self.__last_access = asyncio.get_running_loop().time()
|
||||
|
||||
@property
|
||||
def reply_chance(self) -> float:
|
||||
self.__load()
|
||||
return self.__reply_chance
|
||||
|
||||
@reply_chance.setter
|
||||
def reply_chance(self, val: float):
|
||||
if not (isinstance(val, float) or isinstance(val, int)):
|
||||
return NotImplemented
|
||||
self.__load()
|
||||
self.__reply_chance = val
|
||||
self.__dirty = True
|
||||
|
||||
@property
|
||||
def last_access(self) -> float:
|
||||
return self.__last_access
|
||||
|
||||
def add(self, text: str):
|
||||
parts: List[Any] = text.strip().split()
|
||||
if not parts:
|
||||
return
|
||||
self.__touch()
|
||||
self.__load()
|
||||
self.__dirty = True
|
||||
for fragment in windows(parts + [None], self.order + 1):
|
||||
head = fragment[0:-1]
|
||||
tail = fragment[-1]
|
||||
self.__cache[" ".join(head)][tail] += 1
|
||||
|
||||
def get(self, key: str) -> dict[str | None, int]:
|
||||
if self.__cache:
|
||||
if key in self.__cache:
|
||||
self.__touch()
|
||||
return self.__cache[key]
|
||||
else:
|
||||
raise KeyError(key)
|
||||
else:
|
||||
# Load cache, then return key
|
||||
self.__load()
|
||||
self.__touch()
|
||||
return self.__cache[key]
|
||||
|
||||
def set(self, key: str, value: Mapping[str | None, int]):
|
||||
self.__touch()
|
||||
if not self.__cache:
|
||||
# Attempt the cache before writing to it
|
||||
self.__load()
|
||||
self.__cache[key] = defaultdict(int, value)
|
||||
self.__dirty = True
|
||||
|
||||
def __load(self):
|
||||
self.__touch()
|
||||
if self.__cache:
|
||||
return
|
||||
if not self.path.exists():
|
||||
return
|
||||
with open(self.path) as fp:
|
||||
log.info("Loading markov chain from %s", self.path)
|
||||
obj = json.load(fp)
|
||||
|
||||
# Load the save object
|
||||
self.__reply_chance = obj["reply_chance"]
|
||||
self.__cache = defaultdict(
|
||||
chain_inner_default,
|
||||
{
|
||||
key: defaultdict(
|
||||
int,
|
||||
{
|
||||
(None if not word else word): weight
|
||||
for word, weight in value.items()
|
||||
},
|
||||
)
|
||||
for key, value in obj["chain"].items()
|
||||
},
|
||||
)
|
||||
self.__dirty = False
|
||||
|
||||
def save(self):
|
||||
if not self.__cache:
|
||||
return
|
||||
if self.__dirty:
|
||||
log.info("Saving markov chain to %s", self.path)
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Build the save object
|
||||
obj = {
|
||||
"reply_chance": self.__reply_chance,
|
||||
"chain": {
|
||||
key: {
|
||||
("" if word is None else word): weight
|
||||
for word, weight in value.items()
|
||||
}
|
||||
for key, value in self.__cache.items()
|
||||
},
|
||||
}
|
||||
with open(self.path, "w") as fp:
|
||||
json.dump(obj, fp)
|
||||
self.__dirty = False
|
||||
else:
|
||||
log.info("Chain %s is not dirty, not saving", self.path)
|
||||
|
||||
def clear_cache(self):
|
||||
self.__cache.clear()
|
||||
self.__dirty = False
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return self.path.exists() or bool(self.__cache)
|
||||
|
||||
def generate(self) -> str:
|
||||
self.__load()
|
||||
if not self.__cache:
|
||||
return ""
|
||||
|
||||
words: List[str] = []
|
||||
|
||||
node = random.choice(list(self.__cache.keys())).split(" ")
|
||||
words += node
|
||||
next: str | None = self.choose_next(" ".join(node))
|
||||
while next:
|
||||
words += [next]
|
||||
node = [*node[1:], next]
|
||||
next = self.choose_next(" ".join(node))
|
||||
return " ".join(words)
|
||||
|
||||
def choose_next(self, head: str) -> str | None:
|
||||
self.__load()
|
||||
choices = self.__cache[head]
|
||||
words = list(choices.keys())
|
||||
weights = list(choices.values())
|
||||
if not words:
|
||||
return None
|
||||
return random.choices(words, weights)[0]
|
||||
|
||||
|
||||
class Markov(Plugin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Markov, self).__init__(*args, **kwargs)
|
||||
|
||||
self.order = int(self.plugin_config.get("order", 1))
|
||||
self.data_path = Path(self.plugin_config.get("data_path", "data/markov"))
|
||||
self.save_every = int(self.plugin_config.get("save_every", 1800))
|
||||
self.data_path = Path(
|
||||
self.plugin_config.get("data_path", "data/markov/markov.db")
|
||||
)
|
||||
self.sql_path = Path(self.plugin_config.get("sql_path", "data/markov/db.sql"))
|
||||
self.reply_chance = float(self.plugin_config.get("reply_chance", 0.01))
|
||||
self.__chains = {}
|
||||
self.__save_loop_task = None
|
||||
self.__saving = asyncio.Lock()
|
||||
|
||||
async def on_load(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
self.__save_loop_task = loop.create_task(self.__save_loop())
|
||||
self.chain = Chain(self.order, self.reply_chance, self.data_path, self.sql_path)
|
||||
|
||||
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
|
||||
line = line.strip()
|
||||
parts = line.split()
|
||||
if not parts:
|
||||
if not parts or who.nick == self.nick:
|
||||
return
|
||||
elif parts[0] == "!markov":
|
||||
self.handle_command(conn, channel, who, parts)
|
||||
@@ -342,27 +216,17 @@ class Markov(Plugin):
|
||||
self.add(channel, who.nick, line)
|
||||
# also, maybe generate a sentence
|
||||
chosen = random.random()
|
||||
chain = self.get_chain(channel, who.nick)
|
||||
if chosen <= chain.reply_chance:
|
||||
message = chain.generate()
|
||||
reply_chance = self.chain.get_reply_chance(channel, who.nick)
|
||||
if chosen <= reply_chance:
|
||||
message = self.chain.generate(channel, who.nick)
|
||||
if message:
|
||||
self.send_to(conn, channel, f"{who.nick}: {message}")
|
||||
|
||||
def get_chain(self, channel: str, who: str) -> Chain:
|
||||
if channel not in self.__chains:
|
||||
self.__chains[channel] = {}
|
||||
if who not in self.__chains[channel]:
|
||||
path = self.data_path / channel / who
|
||||
self.__chains[channel][who] = Chain(self.order, self.reply_chance, path)
|
||||
return self.__chains[channel][who]
|
||||
|
||||
def add(self, channel: str, who: str, line: str):
|
||||
def add(self, channel: str, who: str, line: str, commit=True):
|
||||
if who == self.server_config.nick:
|
||||
return
|
||||
chain = self.get_chain(channel, who)
|
||||
chain.add(line)
|
||||
allchain = self.get_chain(channel, ALLCHAIN)
|
||||
allchain.add(line)
|
||||
self.chain.add(channel, who, line, commit)
|
||||
self.chain.add(channel, ALLCHAIN, line, commit)
|
||||
|
||||
def handle_command(
|
||||
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
|
||||
@@ -370,87 +234,48 @@ class Markov(Plugin):
|
||||
# handle markov commands
|
||||
match parts[1:]:
|
||||
case ["force"]:
|
||||
chain = self.get_chain(channel, who.nick)
|
||||
message = chain.generate()
|
||||
message = self.chain.generate(channel, who.nick)
|
||||
if message:
|
||||
self.send_to(conn, channel, f"{who.nick}: {message}")
|
||||
case ["force", nick] | ["emulate", nick]:
|
||||
chain = self.get_chain(channel, nick)
|
||||
if not chain:
|
||||
if not self.chain.get_user_id(channel, nick):
|
||||
return
|
||||
message = chain.generate()
|
||||
message = self.chain.generate(channel, nick)
|
||||
if message:
|
||||
self.send_to(conn, channel, f"{who.nick}: {message}")
|
||||
case ["all"]:
|
||||
chain = self.get_chain(channel, ALLCHAIN)
|
||||
message = chain.generate()
|
||||
message = self.chain.generate(channel, ALLCHAIN)
|
||||
if message:
|
||||
self.send_to(conn, channel, f"{who.nick}: {message}")
|
||||
case ["chance"]:
|
||||
chain = self.get_chain(channel, who.nick)
|
||||
chance = self.chain.get_reply_chance(channel, who.nick)
|
||||
self.send_to(
|
||||
conn,
|
||||
channel,
|
||||
f"{who.nick}: current reply chance is {chain.reply_chance}",
|
||||
f"{who.nick}: current reply chance is {chance}",
|
||||
)
|
||||
case ["chance", chance]:
|
||||
chain = self.get_chain(channel, who.nick)
|
||||
try:
|
||||
reply_chance = float(chance)
|
||||
except ValueError:
|
||||
log.error("Couldn't parse %r as a float", chance)
|
||||
return
|
||||
if not math.isnan(reply_chance):
|
||||
chain.reply_chance = min(
|
||||
max(float(reply_chance), 0.0), self.reply_chance
|
||||
)
|
||||
reply_chance = min(max(float(reply_chance), 0.0), self.reply_chance)
|
||||
self.chain.set_reply_chance(channel, who.nick, reply_chance)
|
||||
self.send_to(
|
||||
conn,
|
||||
channel,
|
||||
f"{who.nick}: reply chance set to {chain.reply_chance}",
|
||||
f"{who.nick}: reply chance set to {self.chain.get_reply_chance(channel, who.nick)}",
|
||||
)
|
||||
case _:
|
||||
# command not recognized
|
||||
pass
|
||||
|
||||
async def __save_loop(self):
|
||||
while True:
|
||||
log.debug("Saving markov chains in %s seconds", self.save_every)
|
||||
await asyncio.sleep(self.save_every)
|
||||
retain_after = asyncio.get_running_loop().time() - self.save_every
|
||||
await self.save(retain_after=retain_after)
|
||||
|
||||
async def save(self, retain_after: float | None = None):
|
||||
async with self.__saving:
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
log.info("Saving markov chains")
|
||||
coros = []
|
||||
loop = asyncio.get_running_loop()
|
||||
# ProcessPoolExecutor is an explicit decision I've made to use,
|
||||
# because it allows us to save in a different process, with
|
||||
# different memory, and simultaneously clear it if it needs to be
|
||||
# cleared.
|
||||
with ProcessPoolExecutor() as pool:
|
||||
for chains in self.__chains.values():
|
||||
for chain in chains.values():
|
||||
# Start the save in a new process, in a new task.
|
||||
log.debug("Starting process to save %s", chain.path)
|
||||
coro = loop.run_in_executor(pool, chain.save)
|
||||
coros += [coro]
|
||||
# Prune
|
||||
retain = True
|
||||
if retain_after is not None:
|
||||
retain = chain.last_access > retain_after
|
||||
if not retain:
|
||||
log.info("Pruning markov chain %s from memory", chain.path)
|
||||
chain.clear_cache()
|
||||
if coros:
|
||||
await asyncio.gather(*coros)
|
||||
log.info("Done")
|
||||
async def save(self):
|
||||
self.chain.commit()
|
||||
|
||||
async def on_unload(self, conn: IrcProtocol):
|
||||
self.__save_loop_task.cancel()
|
||||
await self.save()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user