diff --git a/plugins/markov.py b/plugins/markov.py index ae6fc53..f297ae9 100644 --- a/plugins/markov.py +++ b/plugins/markov.py @@ -1,12 +1,11 @@ +import asyncio from collections import defaultdict import dataclasses -from functools import partial +import json import logging from pathlib import Path -import pickle import random -import shutil -from typing import Any, List, Sequence +from typing import Any, List, Mapping, Sequence from asyncirc.protocol import IrcProtocol from irclib.parser import Prefix @@ -35,29 +34,103 @@ def windows(items: Sequence[Any], size: int): @dataclasses.dataclass class Chain: - order: int - chain: defaultdict[ - tuple[str, ...], defaultdict[str | None, int] - ] = dataclasses.field(default_factory=chain_default) + def __init__(self, order: int, path: Path): + self.order = order + self.path = path + self.__cache = chain_default() + self.__last_access = 0.0 + + @property + def last_access(self) -> float: + return self.__last_access def add(self, text: str): parts: List[Any] = text.strip().split() if not parts: return - if len(parts) < self.order: - self.chain[tuple(parts)][None] += 1 + self.__load() + for fragment in windows(parts + [None], self.order + 1): + head = tuple(fragment[0:-1]) + tail = fragment[-1] + self.__cache[head][tail] += 1 + + def get(self, key: tuple[str, ...]) -> dict[str | None, int]: + self.__last_access = asyncio.get_running_loop().time() + if self.__cache: + if key in self.__cache: + return self.__cache[key] + else: + raise KeyError(key) else: - for fragment in windows(parts + [None], self.order + 1): - head = tuple(fragment[0:-1]) - tail = fragment[-1] - self.chain[head][tail] += 1 + # Load cache, then return key + self.__load() + return self.__cache[key] + + def set(self, key: tuple[str, ...], value: Mapping[str | None, int]): + self.__last_access = asyncio.get_running_loop().time() + if not self.__cache: + # Attempt the cache before writing to it + self.__load() + self.__cache[key] = defaultdict(int, value) + + 
def __load(self): + self.__last_access = asyncio.get_running_loop().time() + if self.__cache: + return + if not self.path.exists(): + return + with open(self.path) as fp: + log.info("Loading markov chain from %s", self.path) + self.__cache = defaultdict( + chain_inner_default, + { + tuple(key.split(" ")): defaultdict( + int, + { + (None if not word else word): weight + for word, weight in value.items() + }, + ) + for key, value in json.load(fp).items() + }, + ) + + def save(self, retain: bool = True): + if not self.__cache: + return + if retain: + log.info("Saving markov chain to %s", self.path) + else: + log.info("Saving markov chain to %s (not retaining)", self.path) + self.path.parent.mkdir(parents=True, exist_ok=True) + with open(self.path, "w") as fp: + json.dump( + { + " ".join(key): { + ("" if word is None else word): weight + for word, weight in value.items() + } + for key, value in self.__cache.items() + }, + fp, + ) + if not retain: + self.clear() + + def clear(self): + self.__cache.clear() + + def __bool__(self) -> bool: + return self.path.exists() or bool(self.__cache) def generate(self) -> str: - words: List[str] = [] - if not self.chain: + self.__load() + if not self.__cache: return "" - node = random.choice(list(self.chain.keys())) + words: List[str] = [] + + node = random.choice(list(self.__cache.keys())) words += node next: str | None = self.choose_next(node) while next: @@ -68,9 +141,8 @@ class Chain: return " ".join(words) def choose_next(self, head: tuple[str, ...]) -> str | None: - if head not in self.chain: - return None - choices = self.chain[head] + self.__load() + choices = self.__cache[head] words = list(choices.keys()) weights = list(choices.values()) if not words: @@ -83,17 +155,15 @@ class Markov(Plugin): super(Markov, self).__init__(*args, **kwargs) self.order = self.plugin_config.get("order", 1) - self.db_path = Path(self.plugin_config.get("db_path", "data/markov.pkl")) - self.backup_db_path = Path(str(self.db_path) + ".backup") + 
self.data_path = Path(self.plugin_config.get("data_path", "data/markov")) + self.save_every = self.plugin_config.get("save_every", 300) + self.__chains = {} + self.__save_loop_task = None + self.__saving = asyncio.Lock() - # Load chain from data dir - if self.db_path.exists(): - with open(self.db_path, "rb") as fp: - self.chains = pickle.load(fp) - else: - self.chains = defaultdict( - partial(defaultdict, partial(Chain, order=self.order)) - ) + async def on_load(self): + loop = asyncio.get_running_loop() + self.__save_loop_task = loop.create_task(self.__save_loop()) async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str): line = line.strip() @@ -106,8 +176,17 @@ class Markov(Plugin): # ignore other commands self.add(channel, who.nick, line) + def get_chain(self, channel: str, who: str) -> Chain: + if channel not in self.__chains: + self.__chains[channel] = {} + if who not in self.__chains[channel]: + path = self.data_path / channel / who + self.__chains[channel][who] = Chain(self.order, path) + return self.__chains[channel][who] + def add(self, channel: str, who: str, line: str): - self.chains[channel][who].add(line) + chain = self.get_chain(channel, who) + chain.add(line) def handle_command( self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str] @@ -115,14 +194,14 @@ class Markov(Plugin): # handle markov commands match parts[1:]: case ["force"]: - chain = self.chains[channel][who.nick] + chain = self.get_chain(channel, who.nick) message = chain.generate() if message: self.send_to(conn, channel, f"{who.nick}: {message}") case ["force", nick] | ["emulate", nick]: - if nick not in self.chains[channel]: + chain = self.get_chain(channel, nick) + if not chain: return - chain = self.chains[channel][nick] message = chain.generate() if message: self.send_to(conn, channel, f"{who.nick}: {message}") @@ -130,19 +209,28 @@ class Markov(Plugin): # command not recognized pass - def save(self): - if self.db_path.exists(): - 
log.info("Copying backup of markov chain to %s") - shutil.copyfile(self.db_path, self.backup_db_path) + async def __save_loop(self): + while True: + log.debug("Pruning inactive markov chains in %s seconds", self.save_every) + await asyncio.sleep(self.save_every) + retain_after = asyncio.get_running_loop().time() - self.save_every + await self.save(retain_after=retain_after) + + async def save(self, retain_after: float | None = None): + async with self.__saving: + log.info("Saving markov chains") + for chains in self.__chains.values(): + for chain in chains.values(): + if retain_after is not None: + retain = chain.last_access > retain_after + else: + retain = True + chain.save(retain=retain) log.info("Done") - log.info("Saving markov chain to %s", self.db_path) - with open(self.db_path, "wb") as fp: - pickle.dump(self.chains, fp) - log.info("Done") - async def on_unload(self, conn: IrcProtocol): - self.save() + self.__save_loop_task.cancel() + await self.save() PLUGIN_TYPE = Markov diff --git a/tools/markov_import.py b/tools/markov_import.py index 5e9a192..43af4fd 100644 --- a/tools/markov_import.py +++ b/tools/markov_import.py @@ -1,11 +1,13 @@ -from omnibot.config import ServerConfig -from plugins.markov import Markov import logging +import asyncio import sys import re +from omnibot.config import ServerConfig +from plugins.markov import Markov -if __name__ == "__main__": + +async def main(): """ Hacky "load my IRC logs" script """ @@ -38,4 +40,8 @@ if __name__ == "__main__": name = mat["name"] message = mat["message"] plugin.add(channel, name, message) - plugin.save() + await plugin.save() + + +if __name__ == "__main__": + asyncio.run(main())