diff --git a/plugins/markov.py b/plugins/markov.py index f24ca7b..69a2b89 100644 --- a/plugins/markov.py +++ b/plugins/markov.py @@ -1,13 +1,10 @@ -import asyncio from collections import defaultdict -import dataclasses -import json import logging import math from pathlib import Path import random import sqlite3 -from typing import Any, List, Mapping, Sequence +from typing import Any, List, Sequence from asyncirc.protocol import IrcProtocol from irclib.parser import Prefix @@ -35,12 +32,20 @@ def windows(items: Sequence[Any], size: int): yield items[i : i + size] -class DbChain: - def __init__(self, order: int, path: Path): +class Chain: + def __init__(self, order: int, reply_chance: float, path: Path, sql_path: Path): self.order = order + self.reply_chance = reply_chance self.path = path self.db = sqlite3.connect(self.path) + # Run the initial database creation script + cursor = self.db.cursor() + with open(sql_path) as fp: + cursor.executescript(fp.read()) + cursor.close() + self.db.commit() + def commit(self): self.db.commit() @@ -57,10 +62,29 @@ class DbChain: else: return None + def get_reply_chance(self, channel: str, nick: str) -> float: + if result := self.execute( + "SELECT reply_chance FROM user WHERE channel = ? AND nick = ?", + (channel, nick), + ): + return result[0][0] + else: + return 0.0 + + def set_reply_chance(self, channel: str, nick: str, chance: float): + self.ensure_user(channel, nick) + self.execute( + "UPDATE user SET reply_chance = ? WHERE channel = ? AND nick = ?", + (chance, channel, nick), + ) + def ensure_user(self, channel: str, nick: str): if self.get_user_id(channel, nick): return - self.execute("INSERT INTO user (channel, nick) VALUES (?, ?)", (channel, nick)) + self.execute( + "INSERT INTO user (channel, nick, reply_chance) VALUES (?, ?, ?)", + (channel, nick, self.reply_chance), + ) def ensure_key(self, channel: str, nick: str, key: str, next: str): assert next is not None @@ -168,172 +192,22 @@ class DbChain: return random.choices(words, weights)[0] -@dataclasses.dataclass -class Chain: - def __init__(self, order: int, chance: float, path: Path): - self.order = order - self.__reply_chance = chance - self.path = path - self.__cache = chain_default() - self.__last_access = 0.0 - self.__dirty = False - - def __touch(self): - self.__last_access = asyncio.get_running_loop().time() - - @property - def reply_chance(self) -> float: - self.__load() - return self.__reply_chance - - @reply_chance.setter - def reply_chance(self, val: float): - if not (isinstance(val, float) or isinstance(val, int)): - return NotImplemented - self.__load() - self.__reply_chance = val - self.__dirty = True - - @property - def last_access(self) -> float: - return self.__last_access - - def add(self, text: str): - parts: List[Any] = text.strip().split() - if not parts: - return - self.__touch() - self.__load() - self.__dirty = True - for fragment in windows(parts + [None], self.order + 1): - head = fragment[0:-1] - tail = fragment[-1] - self.__cache[" ".join(head)][tail] += 1 - - def get(self, key: str) -> dict[str | None, int]: - if self.__cache: - if key in self.__cache: - self.__touch() - return self.__cache[key] - else: - raise KeyError(key) - else: - # Load cache, then return key - self.__load() - self.__touch() - return self.__cache[key] - - def set(self, key: str, value: Mapping[str | None, int]): - self.__touch() - if not self.__cache: - # Attempt the cache before writing to it - self.__load() - self.__cache[key] = defaultdict(int, value) - self.__dirty = True - - def __load(self): - self.__touch() - if self.__cache: - return - if not self.path.exists(): - return - with open(self.path) as fp: - log.info("Loading markov chain from %s", self.path) - obj = json.load(fp) - - # Load the save object - self.__reply_chance = obj["reply_chance"] - self.__cache = defaultdict( - chain_inner_default, - { - key: defaultdict( - int, - { - (None if not word else word): weight - for word, weight in value.items() - }, - ) - for key, value in obj["chain"].items() - }, - ) - self.__dirty = False - - def save(self): - if not self.__cache: - return - if self.__dirty: - log.info("Saving markov chain to %s", self.path) - self.path.parent.mkdir(parents=True, exist_ok=True) - # Build the save object - obj = { - "reply_chance": self.__reply_chance, - "chain": { - key: { - ("" if word is None else word): weight - for word, weight in value.items() - } - for key, value in self.__cache.items() - }, - } - with open(self.path, "w") as fp: - json.dump(obj, fp) - self.__dirty = False - else: - log.info("Chain %s is not dirty, not saving", self.path) - - def clear_cache(self): - self.__cache.clear() - self.__dirty = False - - def __bool__(self) -> bool: - return self.path.exists() or bool(self.__cache) - - def generate(self) -> str: - self.__load() - if not self.__cache: - return "" - - words: List[str] = [] - - node = random.choice(list(self.__cache.keys())).split(" ") - words += node - next: str | None = self.choose_next(" ".join(node)) - while next: - words += [next] - node = [*node[1:], next] - next = self.choose_next(" ".join(node)) - return " ".join(words) - - def choose_next(self, head: str) -> str | None: - self.__load() - choices = self.__cache[head] - words = list(choices.keys()) - weights = list(choices.values()) - if not words: - return None - return random.choices(words, weights)[0] - - class Markov(Plugin): def __init__(self, *args, **kwargs): super(Markov, self).__init__(*args, **kwargs) self.order = int(self.plugin_config.get("order", 1)) - self.data_path = Path(self.plugin_config.get("data_path", "data/markov")) - self.save_every = int(self.plugin_config.get("save_every", 1800)) + self.data_path = Path( + self.plugin_config.get("data_path", "data/markov/markov.db") + ) + self.sql_path = Path(self.plugin_config.get("sql_path", "data/markov/db.sql")) self.reply_chance = float(self.plugin_config.get("reply_chance", 0.01)) - self.__chains = {} - self.__save_loop_task = None - self.__saving = asyncio.Lock() - - async def on_load(self): - loop = asyncio.get_running_loop() - self.__save_loop_task = loop.create_task(self.__save_loop()) + self.chain = Chain(self.order, self.reply_chance, self.data_path, self.sql_path) async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str): line = line.strip() parts = line.split() - if not parts: + if not parts or who.nick == self.nick: return elif parts[0] == "!markov": self.handle_command(conn, channel, who, parts) @@ -342,27 +216,17 @@ class Markov(Plugin): self.add(channel, who.nick, line) # also, maybe generate a sentence chosen = random.random() - chain = self.get_chain(channel, who.nick) - if chosen <= chain.reply_chance: - message = chain.generate() + reply_chance = self.chain.get_reply_chance(channel, who.nick) + if chosen <= reply_chance: + message = self.chain.generate(channel, who.nick) if message: self.send_to(conn, channel, f"{who.nick}: {message}") - def get_chain(self, channel: str, who: str) -> Chain: - if channel not in self.__chains: - self.__chains[channel] = {} - if who not in self.__chains[channel]: - path = self.data_path / channel / who - self.__chains[channel][who] = Chain(self.order, self.reply_chance, path) - return self.__chains[channel][who] - - def add(self, channel: str, who: str, line: str): + def add(self, channel: str, who: str, line: str, commit=True): if who == self.server_config.nick: return - chain = self.get_chain(channel, who) - chain.add(line) - allchain = self.get_chain(channel, ALLCHAIN) - allchain.add(line) + self.chain.add(channel, who, line, commit) + self.chain.add(channel, ALLCHAIN, line, commit) def handle_command( self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str] @@ -370,87 +234,48 @@ class Markov(Plugin): # handle markov commands match parts[1:]: case ["force"]: - chain = self.get_chain(channel, who.nick) - message = chain.generate() + message = self.chain.generate(channel, who.nick) if message: self.send_to(conn, channel, f"{who.nick}: {message}") case ["force", nick] | ["emulate", nick]: - chain = self.get_chain(channel, nick) - if not chain: + if not self.chain.get_user_id(channel, nick): return - message = chain.generate() + message = self.chain.generate(channel, nick) if message: self.send_to(conn, channel, f"{who.nick}: {message}") case ["all"]: - chain = self.get_chain(channel, ALLCHAIN) - message = chain.generate() + message = self.chain.generate(channel, ALLCHAIN) if message: self.send_to(conn, channel, f"{who.nick}: {message}") case ["chance"]: - chain = self.get_chain(channel, who.nick) + chance = self.chain.get_reply_chance(channel, who.nick) self.send_to( conn, channel, - f"{who.nick}: current reply chance is {chain.reply_chance}", + f"{who.nick}: current reply chance is {chance}", ) case ["chance", chance]: - chain = self.get_chain(channel, who.nick) try: reply_chance = float(chance) except ValueError: log.error("Couldn't parse %r as a float", chance) return if not math.isnan(reply_chance): - chain.reply_chance = min( - max(float(reply_chance), 0.0), self.reply_chance - ) + reply_chance = min(max(float(reply_chance), 0.0), self.reply_chance) + self.chain.set_reply_chance(channel, who.nick, reply_chance) self.send_to( conn, channel, - f"{who.nick}: reply chance set to {chain.reply_chance}", + f"{who.nick}: reply chance set to {self.chain.get_reply_chance(channel, who.nick)}", ) case _: # command not recognized pass - async def __save_loop(self): - while True: - log.debug("Saving markov chains in %s seconds", self.save_every) - await asyncio.sleep(self.save_every) - retain_after = asyncio.get_running_loop().time() - self.save_every - await self.save(retain_after=retain_after) - - async def save(self, retain_after: float | None = None): - async with self.__saving: - from concurrent.futures import ProcessPoolExecutor - - log.info("Saving markov chains") - coros = [] - loop = asyncio.get_running_loop() - # ProcessPoolExecutor is an explicit decision I've made to use, - # because it allows us to save in a different process, with - # different memory, and simultaneously clear it if it needs to be - # cleared. - with ProcessPoolExecutor() as pool: - for chains in self.__chains.values(): - for chain in chains.values(): - # Start the save in a new process, in a new task. - log.debug("Starting process to save %s", chain.path) - coro = loop.run_in_executor(pool, chain.save) - coros += [coro] - # Prune - retain = True - if retain_after is not None: - retain = chain.last_access > retain_after - if not retain: - log.info("Pruning markov chain %s from memory", chain.path) - chain.clear_cache() - if coros: - await asyncio.gather(*coros) - log.info("Done") + async def save(self): + self.chain.commit() async def on_unload(self, conn: IrcProtocol): - self.__save_loop_task.cancel() await self.save() diff --git a/tools/markov_import.py b/tools/markov_import.py index e876876..6f4c4b3 100644 --- a/tools/markov_import.py +++ b/tools/markov_import.py @@ -40,7 +40,7 @@ async def main(): name = mat["name"] message = mat["message"].strip() if name != server_config.nick and message and message[0] != "!": - plugin.add(channel, name, message) + plugin.add(channel, name, message, commit=False) await plugin.save()