import asyncio from collections import defaultdict import dataclasses import json import logging from pathlib import Path import random from typing import Any, List, Mapping, Sequence from asyncirc.protocol import IrcProtocol from irclib.parser import Prefix from omnibot.plugin import Plugin log = logging.getLogger(__name__) ALLCHAIN = "ALL!CHAIN" def chain_inner_default() -> defaultdict[str | None, int]: return defaultdict(int) def chain_default() -> defaultdict[tuple[str, ...], defaultdict[str | None, int]]: return defaultdict(chain_inner_default) def windows(items: Sequence[Any], size: int): if len(items) < size: yield items else: for i in range(0, len(items) - size + 1): yield items[i : i + size] @dataclasses.dataclass class Chain: def __init__(self, order: int, path: Path): self.order = order self.path = path self.__cache = chain_default() self.__last_access = 0.0 @property def last_access(self) -> float: return self.__last_access def add(self, text: str): parts: List[Any] = text.strip().split() if not parts: return self.__load() for fragment in windows(parts + [None], self.order + 1): head = tuple(fragment[0:-1]) tail = fragment[-1] self.__cache[head][tail] += 1 def get(self, key: tuple[str, ...]) -> dict[str | None, int]: self.__last_access = asyncio.get_running_loop().time() if self.__cache: if key in self.__cache: return self.__cache[key] else: raise KeyError(key) else: # Load cache, then return key self.__load() return self.__cache[key] def set(self, key: tuple[str, ...], value: Mapping[str | None, int]): self.__last_access = asyncio.get_running_loop().time() if not self.__cache: # Attempt the cache before writing to it self.__load() self.__cache[key] = defaultdict(int, value) def __load(self): self.__last_access = asyncio.get_running_loop().time() if self.__cache: return if not self.path.exists(): return with open(self.path) as fp: log.info("Loading markov chain from %s", self.path) self.__cache = defaultdict( chain_inner_default, { tuple(key.split(" ")): defaultdict( int, { (None if not word else word): weight for word, weight in value.items() }, ) for key, value in json.load(fp).items() }, ) def save(self, retain: bool = True): if not self.__cache: return if retain: log.info("Saving markov chain to %s", self.path) else: log.info("Saving markov chain to %s (not retaining)", self.path) self.path.parent.mkdir(parents=True, exist_ok=True) with open(self.path, "w") as fp: json.dump( { " ".join(key): { ("" if word is None else word): weight for word, weight in value.items() } for key, value in self.__cache.items() }, fp, ) if not retain: self.clear_cache() def clear_cache(self): self.__cache.clear() def __bool__(self) -> bool: return self.path.exists() or bool(self.__cache) def generate(self) -> str: self.__load() if not self.__cache: return "" words: List[str] = [] node = random.choice(list(self.__cache.keys())) words += node next: str | None = self.choose_next(node) while next: words += [next] node = (*node[1:], next) next = self.choose_next(node) return " ".join(words) def choose_next(self, head: tuple[str, ...]) -> str | None: self.__load() choices = self.__cache[head] words = list(choices.keys()) weights = list(choices.values()) if not words: return None return random.choices(words, weights)[0] class Markov(Plugin): def __init__(self, *args, **kwargs): super(Markov, self).__init__(*args, **kwargs) self.order = self.plugin_config.get("order", 1) self.data_path = Path(self.plugin_config.get("data_path", "data/markov")) self.save_every = self.plugin_config.get("save_every", 300) self.__chains = {} self.__save_loop_task = None self.__saving = asyncio.Lock() async def on_load(self): loop = asyncio.get_running_loop() self.__save_loop_task = loop.create_task(self.__save_loop()) async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str): line = line.strip() parts = line.split() if not parts: return elif parts[0] == "!markov": self.handle_command(conn, channel, who, parts) elif line[0] != "!": # ignore other commands self.add(channel, who.nick, line) def get_chain(self, channel: str, who: str) -> Chain: if channel not in self.__chains: self.__chains[channel] = {} if who not in self.__chains[channel]: path = self.data_path / channel / who self.__chains[channel][who] = Chain(self.order, path) return self.__chains[channel][who] def add(self, channel: str, who: str, line: str): if who == self.server_config.nick == who: return chain = self.get_chain(channel, who) chain.add(line) allchain = self.get_chain(channel, ALLCHAIN) allchain.add(line) def handle_command( self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str] ): # handle markov commands match parts[1:]: case ["force"]: chain = self.get_chain(channel, who.nick) message = chain.generate() if message: self.send_to(conn, channel, f"{who.nick}: {message}") case ["force", nick] | ["emulate", nick]: chain = self.get_chain(channel, who.nick) if not chain: return message = chain.generate() if message: self.send_to(conn, channel, f"{who.nick}: {message}") case ["all"]: chain = self.get_chain(channel, ALLCHAIN) message = chain.generate() if message: self.send_to(conn, channel, f"{who.nick}: {message}") case _: # command not recognized pass async def __save_loop(self): while True: log.debug("Pruning inactive markov chains in %s seconds", self.save_every) await asyncio.sleep(self.save_every) retain_after = asyncio.get_running_loop().time() - self.save_every await self.save(retain_after=retain_after) async def save(self, retain_after: float | None = None): async with self.__saving: log.info("Saving markov chains") for chains in self.__chains.values(): for chain in chains.values(): if retain_after is not None: retain = chain.last_access > retain_after else: retain = True chain.save(retain=retain) log.info("Done") async def on_unload(self, conn: IrcProtocol): self.__save_loop_task.cancel() await self.save() PLUGIN_TYPE = Markov