from collections import defaultdict import dataclasses from functools import partial import logging from pathlib import Path import pickle import random import shutil from typing import Any, List, Sequence from asyncirc.protocol import IrcProtocol from irclib.parser import Prefix from omnibot.plugin import Plugin log = logging.getLogger(__name__) def chain_inner_default() -> defaultdict[str | None, int]: return defaultdict(int) def chain_default() -> defaultdict[tuple[str, ...], defaultdict[str | None, int]]: return defaultdict(chain_inner_default) def windows(items: Sequence[Any], size: int): if len(items) < size: yield items else: for i in range(0, len(items) - size + 1): yield items[i : i + size] @dataclasses.dataclass class Chain: order: int chain: defaultdict[ tuple[str, ...], defaultdict[str | None, int] ] = dataclasses.field(default_factory=chain_default) def add(self, text: str): parts: List[Any] = text.strip().split() if not parts: return if len(parts) < self.order: self.chain[tuple(parts)][None] += 1 else: for fragment in windows(parts + [None], self.order + 1): head = tuple(fragment[0:-1]) tail = fragment[-1] self.chain[head][tail] += 1 def generate(self) -> str: words: List[str] = [] if not self.chain: return "" node = random.choice(list(self.chain.keys())) words += node next: str | None = self.choose_next(node) while next: words += [next] node = (*node[1:], next) next = self.choose_next(node) return " ".join(words) def choose_next(self, head: tuple[str, ...]) -> str | None: if head not in self.chain: return None choices = self.chain[head] words = list(choices.keys()) weights = list(choices.values()) if not words: return None return random.choices(words, weights)[0] class Markov(Plugin): def __init__(self, *args, **kwargs): super(Markov, self).__init__(*args, **kwargs) self.order = self.plugin_config.get("order", 1) self.db_path = Path(self.plugin_config.get("db_path", "data/markov.pkl")) self.backup_db_path = Path(str(self.db_path) + ".backup") # Load chain from data dir if self.db_path.exists(): with open(self.db_path, "rb") as fp: self.chains = pickle.load(fp) else: self.chains = defaultdict( partial(defaultdict, partial(Chain, order=self.order)) ) async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str): line = line.strip() parts = line.split() if not parts: return elif parts[0] == "!markov": self.handle_command(conn, channel, who, parts) elif line[0] != "!": # ignore other commands self.add(channel, who.nick, line) def add(self, channel: str, who: str, line: str): self.chains[channel][who].add(line) def handle_command( self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str] ): # handle markov commands match parts[1:]: case ["force"]: chain = self.chains[channel][who.nick] message = chain.generate() if message: self.send_to(conn, channel, f"{who.nick}: {message}") case ["force", nick] | ["emulate", nick]: if nick not in self.chains[channel]: return chain = self.chains[channel][nick] message = chain.generate() if message: self.send_to(conn, channel, f"{who.nick}: {message}") case _: # command not recognized pass def save(self): if self.db_path.exists(): log.info("Copying backup of markov chain to %s") shutil.copyfile(self.db_path, self.backup_db_path) log.info("Done") log.info("Saving markov chain to %s", self.db_path) with open(self.db_path, "wb") as fp: pickle.dump(self.chains, fp) log.info("Done") async def on_unload(self, conn: IrcProtocol): self.save() PLUGIN_TYPE = Markov