From 9c188e30b13ee7092a22f5530db28c7fa0c4f9d6 Mon Sep 17 00:00:00 2001 From: Alek Ratzloff Date: Tue, 24 May 2022 19:30:42 -0700 Subject: [PATCH] Add initial markov bot plugin Signed-off-by: Alek Ratzloff --- plugins/markov.py | 142 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 plugins/markov.py diff --git a/plugins/markov.py b/plugins/markov.py new file mode 100644 index 0000000..956773d --- /dev/null +++ b/plugins/markov.py @@ -0,0 +1,142 @@ +from collections import defaultdict +import dataclasses +from functools import partial +import logging +from pathlib import Path +import pickle +import random +import shutil +from typing import Any, List, MutableMapping, Sequence + +from asyncirc.protocol import IrcProtocol +from irclib.parser import Prefix + +from omnibot.plugin import Plugin + + +log = logging.getLogger(__name__) + + +def chain_inner_default() -> defaultdict[str | None, int]: + return defaultdict(int) + + +def chain_default() -> defaultdict[tuple[str, ...], defaultdict[str | None, int]]: + return defaultdict(chain_inner_default) + + +def windows(items: Sequence[Any], size: int): + if len(items) < size: + yield items + else: + for i in range(0, len(items) - size + 1): + yield items[i : i + size] + + +@dataclasses.dataclass +class Chain: + order: int + chain: defaultdict[ + tuple[str, ...], defaultdict[str | None, int] + ] = dataclasses.field(default_factory=chain_default) + + def add(self, text: str): + parts: List[Any] = text.strip().split() + if not parts: + return + if len(parts) < self.order: + self.chain[tuple(parts)][None] += 1 + else: + for fragment in windows(parts + [None], self.order + 1): + head = tuple(fragment[0:-1]) + tail = fragment[-1] + self.chain[head][tail] += 1 + + def generate(self) -> str: + words: List[str] = [] + if not self.chain: + return "" + + node = random.choice(list(self.chain.keys())) + words += node + next: str | None = self.choose_next(node) + while next: + words += [next] + node = (*node[1:], next) + next = self.choose_next(node) + + return " ".join(words) + + def choose_next(self, head: tuple[str, ...]) -> str | None: + if head not in self.chain: + return None + choices = self.chain[head] + words = list(choices.keys()) + weights = list(choices.values()) + if not words: + return None + return random.choices(words, weights)[0] + + +class Markov(Plugin): + def __init__(self, *args, **kwargs): + super(Markov, self).__init__(*args, **kwargs) + + self.order = self.plugin_config.get("order", 1) + self.db_path = Path(self.plugin_config.get("db_path", "data/markov.pkl")) + self.backup_db_path = Path(str(self.db_path) + ".backup") + + # Load chain from data dir + if self.db_path.exists(): + with open(self.db_path, "rb") as fp: + self.chains = pickle.load(fp) + else: + self.chains = defaultdict( + partial(defaultdict, partial(Chain, order=self.order)) + ) + + async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str): + line = line.strip() + parts = line.split() + if not parts: + return + elif parts[0] == "!markov": + self.handle_command(conn, channel, who, parts) + elif line[0] != "!": + # ignore other commands + self.chains[channel][who.nick].add(line) + + def handle_command( + self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str] + ): + # handle markov commands + match parts[1:]: + case ["force"]: + chain = self.chains[channel][who.nick] + message = chain.generate() + if message: + self.send_to(conn, channel, f"{who.nick}: {message}") + case ["force", nick] | ["emulate", nick]: + if nick not in self.chains[channel]: + return + chain = self.chains[channel][nick] + message = chain.generate() + if message: + self.send_to(conn, channel, f"{who.nick}: {message}") + case _: + # command not recognized + pass + + def save(self): + if self.db_path.exists(): + log.info("Copying backup of markov chain to %s") + shutil.copyfile(self.db_path, self.backup_db_path) + log.info("Done") + + log.info("Saving markov chain to %s", self.db_path) + with open(self.db_path, "wb") as fp: + pickle.dump(self.chains, fp) + log.info("Done") + + async def on_unload(self, conn: IrcProtocol): + self.save()