import asyncio from collections import defaultdict import dataclasses import json import logging import math from pathlib import Path import random from typing import Any, List, Mapping, Sequence from asyncirc.protocol import IrcProtocol from irclib.parser import Prefix from omnibot.plugin import Plugin log = logging.getLogger(__name__) ALLCHAIN = "ALL!CHAIN" def chain_inner_default() -> defaultdict[str | None, int]: return defaultdict(int) def chain_default() -> defaultdict[str, defaultdict[str | None, int]]: return defaultdict(chain_inner_default) def windows(items: Sequence[Any], size: int): if len(items) < size: yield items else: for i in range(0, len(items) - size + 1): yield items[i : i + size] @dataclasses.dataclass class Chain: def __init__(self, order: int, chance: float, path: Path): self.order = order self.__reply_chance = chance self.path = path self.__cache = chain_default() self.__last_access = 0.0 self.__dirty = False def __touch(self): self.__last_access = asyncio.get_running_loop().time() @property def reply_chance(self) -> float: self.__load() return self.__reply_chance @reply_chance.setter def reply_chance(self, val: float): if not (isinstance(val, float) or isinstance(val, int)): return NotImplemented self.__load() self.__reply_chance = val self.__dirty = True @property def last_access(self) -> float: return self.__last_access def add(self, text: str): parts: List[Any] = text.strip().split() if not parts: return self.__touch() self.__load() self.__dirty = True for fragment in windows(parts + [None], self.order + 1): head = fragment[0:-1] tail = fragment[-1] self.__cache[" ".join(head)][tail] += 1 def get(self, key: str) -> dict[str | None, int]: if self.__cache: if key in self.__cache: self.__touch() return self.__cache[key] else: raise KeyError(key) else: # Load cache, then return key self.__load() self.__touch() return self.__cache[key] def set(self, key: str, value: Mapping[str | None, int]): self.__touch() if not self.__cache: # Attempt the cache before writing to it self.__load() self.__cache[key] = defaultdict(int, value) self.__dirty = True def __load(self): self.__touch() if self.__cache: return if not self.path.exists(): return with open(self.path) as fp: log.info("Loading markov chain from %s", self.path) obj = json.load(fp) # Load the save object self.__reply_chance = obj["reply_chance"] self.__cache = defaultdict( chain_inner_default, { key: defaultdict( int, { (None if not word else word): weight for word, weight in value.items() }, ) for key, value in obj["chain"].items() }, ) self.__dirty = False def save(self): if not self.__cache: return if self.__dirty: log.info("Saving markov chain to %s", self.path) self.path.parent.mkdir(parents=True, exist_ok=True) # Build the save object obj = { "reply_chance": self.__reply_chance, "chain": { key: { ("" if word is None else word): weight for word, weight in value.items() } for key, value in self.__cache.items() }, } with open(self.path, "w") as fp: json.dump(obj, fp) self.__dirty = False else: log.info("Chain %s is not dirty, not saving", self.path) def clear_cache(self): self.__cache.clear() self.__dirty = False def __bool__(self) -> bool: return self.path.exists() or bool(self.__cache) def generate(self) -> str: self.__load() if not self.__cache: return "" words: List[str] = [] node = random.choice(list(self.__cache.keys())).split(" ") words += node next: str | None = self.choose_next(" ".join(node)) while next: words += [next] node = [*node[1:], next] next = self.choose_next(" ".join(node)) return " ".join(words) def choose_next(self, head: str) -> str | None: self.__load() choices = self.__cache[head] words = list(choices.keys()) weights = list(choices.values()) if not words: return None return random.choices(words, weights)[0] class Markov(Plugin): def __init__(self, *args, **kwargs): super(Markov, self).__init__(*args, **kwargs) self.order = int(self.plugin_config.get("order", 1)) self.data_path = Path(self.plugin_config.get("data_path", "data/markov")) self.save_every = int(self.plugin_config.get("save_every", 1800)) self.reply_chance = float(self.plugin_config.get("reply_chance", 0.01)) self.__chains = {} self.__save_loop_task = None self.__saving = asyncio.Lock() async def on_load(self): loop = asyncio.get_running_loop() self.__save_loop_task = loop.create_task(self.__save_loop()) async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str): line = line.strip() parts = line.split() if not parts: return elif parts[0] == "!markov": self.handle_command(conn, channel, who, parts) elif line[0] != "!": # ignore other commands self.add(channel, who.nick, line) # also, maybe generate a sentence chosen = random.random() chain = self.get_chain(channel, who.nick) if chosen <= chain.reply_chance: message = chain.generate() if message: self.send_to(conn, channel, f"{who.nick}: {message}") def get_chain(self, channel: str, who: str) -> Chain: if channel not in self.__chains: self.__chains[channel] = {} if who not in self.__chains[channel]: path = self.data_path / channel / who self.__chains[channel][who] = Chain(self.order, self.reply_chance, path) return self.__chains[channel][who] def add(self, channel: str, who: str, line: str): if who == self.server_config.nick: return chain = self.get_chain(channel, who) chain.add(line) allchain = self.get_chain(channel, ALLCHAIN) allchain.add(line) def handle_command( self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str] ): # handle markov commands match parts[1:]: case ["force"]: chain = self.get_chain(channel, who.nick) message = chain.generate() if message: self.send_to(conn, channel, f"{who.nick}: {message}") case ["force", nick] | ["emulate", nick]: chain = self.get_chain(channel, nick) if not chain: return message = chain.generate() if message: self.send_to(conn, channel, f"{who.nick}: {message}") case ["all"]: chain = self.get_chain(channel, ALLCHAIN) message = chain.generate() if message: self.send_to(conn, channel, f"{who.nick}: {message}") case ["chance"]: chain = self.get_chain(channel, who.nick) self.send_to( conn, channel, f"{who.nick}: current reply chance is {chain.reply_chance}", ) case ["chance", chance]: chain = self.get_chain(channel, who.nick) try: reply_chance = float(chance) except ValueError: log.error("Couldn't parse %r as a float", chance) return if not math.isnan(reply_chance): chain.reply_chance = min( max(float(reply_chance), 0.0), self.reply_chance ) self.send_to( conn, channel, f"{who.nick}: reply chance set to {chain.reply_chance}", ) case _: # command not recognized pass async def __save_loop(self): while True: log.debug("Saving markov chains in %s seconds", self.save_every) await asyncio.sleep(self.save_every) retain_after = asyncio.get_running_loop().time() - self.save_every await self.save(retain_after=retain_after) async def save(self, retain_after: float | None = None): async with self.__saving: from concurrent.futures import ProcessPoolExecutor log.info("Saving markov chains") coros = [] loop = asyncio.get_running_loop() # ProcessPoolExecutor is an explicit decision I've made to use, # because it allows us to save in a different process, with # different memory, and simultaneously clear it if it needs to be # cleared. with ProcessPoolExecutor() as pool: for chains in self.__chains.values(): for chain in chains.values(): # Start the save in a new process, in a new task. log.debug("Starting process to save %s", chain.path) coro = loop.run_in_executor(pool, chain.save) coros += [coro] # Prune retain = True if retain_after is not None: retain = chain.last_access > retain_after if not retain: log.info("Pruning markov chain %s from memory", chain.path) chain.clear_cache() if coros: await asyncio.gather(*coros) log.info("Done") async def on_unload(self, conn: IrcProtocol): self.__save_loop_task.cancel() await self.save() PLUGIN_TYPE = Markov