import asyncio from collections import defaultdict import dataclasses import json import logging import math from pathlib import Path import random from typing import Any, List, Mapping, Sequence from asyncirc.protocol import IrcProtocol from irclib.parser import Prefix from omnibot.plugin import Plugin log = logging.getLogger(__name__) ALLCHAIN = "ALL!CHAIN" def chain_inner_default() -> defaultdict[str | None, int]: return defaultdict(int) def chain_default() -> defaultdict[str, defaultdict[str | None, int]]: return defaultdict(chain_inner_default) def windows(items: Sequence[Any], size: int): if len(items) < size: yield items else: for i in range(0, len(items) - size + 1): yield items[i : i + size] @dataclasses.dataclass class Chain: def __init__(self, order: int, chance: float, path: Path): self.order = order self.reply_chance = chance self.path = path self.__cache = chain_default() self.__last_access = 0.0 self.__dirty = False @property def last_access(self) -> float: return self.__last_access def add(self, text: str): parts: List[Any] = text.strip().split() if not parts: return self.__load() for fragment in windows(parts + [None], self.order + 1): head = fragment[0:-1] tail = fragment[-1] self.__cache[" ".join(head)][tail] += 1 self.__dirty = True def get(self, key: str) -> dict[str | None, int]: self.__last_access = asyncio.get_running_loop().time() if self.__cache: if key in self.__cache: return self.__cache[key] else: raise KeyError(key) else: # Load cache, then return key self.__load() return self.__cache[key] def set(self, key: str, value: Mapping[str | None, int]): self.__last_access = asyncio.get_running_loop().time() if not self.__cache: # Attempt the cache before writing to it self.__load() self.__cache[key] = defaultdict(int, value) self.__dirty = True def __load(self): self.__last_access = asyncio.get_running_loop().time() if self.__cache: return if not self.path.exists(): return with open(self.path) as fp: log.info("Loading markov chain from %s", self.path) obj = json.load(fp) # Load the save object self.reply_chance = obj["reply_chance"] self.__cache = defaultdict( chain_inner_default, { key: defaultdict( int, { (None if not word else word): weight for word, weight in value.items() }, ) for key, value in obj["chain"].items() }, ) self.__dirty = False def save(self): if not self.__cache: return if self.__dirty: log.info("Saving markov chain to %s", self.path) self.path.parent.mkdir(parents=True, exist_ok=True) # Build the save object obj = { "reply_chance": self.reply_chance, "chain": { key: { ("" if word is None else word): weight for word, weight in value.items() } for key, value in self.__cache.items() }, } with open(self.path, "w") as fp: json.dump(obj, fp) self.__dirty = False def clear_cache(self): self.__cache.clear() self.__dirty = False def __bool__(self) -> bool: return self.path.exists() or bool(self.__cache) def generate(self) -> str: self.__load() if not self.__cache: return "" words: List[str] = [] node = random.choice(list(self.__cache.keys())).split(" ") words += node next: str | None = self.choose_next(" ".join(node)) while next: words += [next] node = [*node[1:], next] next = self.choose_next(" ".join(node)) return " ".join(words) def choose_next(self, head: str) -> str | None: self.__load() choices = self.__cache[head] words = list(choices.keys()) weights = list(choices.values()) if not words: return None return random.choices(words, weights)[0] class Markov(Plugin): def __init__(self, *args, **kwargs): super(Markov, self).__init__(*args, **kwargs) self.order = int(self.plugin_config.get("order", 1)) self.data_path = Path(self.plugin_config.get("data_path", "data/markov")) self.save_every = int(self.plugin_config.get("save_every", 1800)) self.reply_chance = float(self.plugin_config.get("reply_chance", 0.01)) self.__chains = {} self.__save_loop_task = None self.__saving = asyncio.Lock() async def on_load(self): loop = asyncio.get_running_loop() self.__save_loop_task = loop.create_task(self.__save_loop()) async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str): line = line.strip() parts = line.split() if not parts: return elif parts[0] == "!markov": self.handle_command(conn, channel, who, parts) elif line[0] != "!": # ignore other commands self.add(channel, who.nick, line) # also, maybe generate a sentence chosen = random.random() chain = self.get_chain(channel, who.nick) if chosen <= chain.reply_chance: pass def get_chain(self, channel: str, who: str) -> Chain: if channel not in self.__chains: self.__chains[channel] = {} if who not in self.__chains[channel]: path = self.data_path / channel / who self.__chains[channel][who] = Chain(self.order, self.reply_chance, path) return self.__chains[channel][who] def add(self, channel: str, who: str, line: str): if who == self.server_config.nick: return chain = self.get_chain(channel, who) chain.add(line) allchain = self.get_chain(channel, ALLCHAIN) allchain.add(line) def handle_command( self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str] ): # handle markov commands match parts[1:]: case ["force"]: chain = self.get_chain(channel, who.nick) message = chain.generate() if message: self.send_to(conn, channel, f"{who.nick}: {message}") case ["force", nick] | ["emulate", nick]: chain = self.get_chain(channel, nick) if not chain: return message = chain.generate() if message: self.send_to(conn, channel, f"{who.nick}: {message}") case ["all"]: chain = self.get_chain(channel, ALLCHAIN) message = chain.generate() if message: self.send_to(conn, channel, f"{who.nick}: {message}") case ["chance", chance]: chain = self.get_chain(channel, who.nick) try: reply_chance = float(chance) except ValueError: log.error("Couldn't parse %r as a float", chance) return if not math.isnan(reply_chance): chain.reply_chance = min( max(float(reply_chance), 0.0), self.reply_chance ) self.send_to( conn, channel, f"{who.nick}: reply chance set to {chain.reply_chance}", ) case _: # command not recognized pass async def __save_loop(self): while True: log.debug("Pruning inactive markov chains in %s seconds", self.save_every) await asyncio.sleep(self.save_every) retain_after = asyncio.get_running_loop().time() - self.save_every await self.save(retain_after=retain_after) async def save(self, retain_after: float | None = None): async with self.__saving: log.info("Saving markov chains") for chains in self.__chains.values(): for chain in chains.values(): chain.save() # Prune retain = True if retain_after is not None: retain = chain.last_access > retain_after if not retain: log.debug("Pruning markov chain %s from memory", chain.path) chain.clear_cache() log.info("Done") async def on_unload(self, conn: IrcProtocol): self.__save_loop_task.cancel() await self.save() PLUGIN_TYPE = Markov