from collections import defaultdict import logging import math from pathlib import Path import random import sqlite3 from typing import Any, List, Sequence from asyncirc.protocol import IrcProtocol from irclib.parser import Prefix from omnibot.plugin import Plugin log = logging.getLogger(__name__) ALLCHAIN = "ALL!CHAIN" def chain_inner_default() -> defaultdict[str | None, int]: return defaultdict(int) def chain_default() -> defaultdict[str, defaultdict[str | None, int]]: return defaultdict(chain_inner_default) def windows(items: Sequence[Any], size: int): if len(items) < size: yield items else: for i in range(0, len(items) - size + 1): yield items[i : i + size] class Chain: def __init__(self, order: int, reply_chance: float, path: Path, sql_path: Path): self.order = order self.reply_chance = reply_chance self.path = path self.db = sqlite3.connect(self.path) # Run the initial database creation script cursor = self.db.cursor() with open(sql_path) as fp: cursor.executescript(fp.read()) cursor.close() self.db.commit() def commit(self): self.db.commit() def execute(self, *args, **kwargs): cursor = self.db.cursor() cursor.execute(*args, **kwargs) return list(cursor.fetchall()) def get_user_id(self, channel: str, nick: str) -> int | None: if result := self.execute( "SELECT id FROM user WHERE channel = ? AND nick = ?", (channel, nick) ): return result[0][0] else: return None def get_reply_chance(self, channel: str, nick: str) -> float: if result := self.execute( "SELECT reply_chance FROM user WHERE channel = ? AND nick = ?", (channel, nick), ): return result[0][0] else: return 0.0 def set_reply_chance(self, channel: str, nick: str, chance: float): self.ensure_user(channel, nick) self.execute( "UPDATE user SET reply_chance = ? WHERE channel = ? AND nick = ?", (chance, channel, nick), ) def ensure_user(self, channel: str, nick: str): if self.get_user_id(channel, nick): return self.execute( "INSERT INTO user (channel, nick, reply_chance) VALUES (?, ?, ?)", (channel, nick, self.reply_chance), ) def ensure_key(self, channel: str, nick: str, key: str, next: str): assert next is not None self.ensure_user(channel, nick) if next in self.get(channel, nick, key): return self.execute( """ INSERT INTO chain (user, value, weight, next) VALUES ( (SELECT id FROM user WHERE channel = ? AND nick = ?), ?, 0, ? ) """, (channel, nick, key, next), ) def add(self, channel: str, nick: str, text: str, commit=True): parts: List[Any] = text.strip().split() if not parts: return for fragment in windows(parts + [""], self.order + 1): head = fragment[0:-1] tail = fragment[-1] key = " ".join(head) self.update_chain(channel, nick, key, tail) if commit: self.commit() def update_chain( self, channel: str, nick: str, key: str, next: str, weight: int = 1, ): self.ensure_key(channel, nick, key, next) # Get if the key exists self.execute( """ UPDATE chain SET weight = weight + :weight WHERE user = (SELECT id FROM user WHERE channel = :channel AND nick = :nick) AND value = :key AND next = :next """, { "channel": channel, "nick": nick, "key": key, "next": next, "weight": weight, }, ) def get(self, channel: str, nick: str, key: str) -> dict[str, int]: cursor = self.db.cursor() cursor.execute( """ SELECT next, weight FROM chain WHERE user = (SELECT id FROM user WHERE channel = ? AND nick = ?) AND value = ? """, (channel, nick, key), ) return {next: weight for next, weight in cursor.fetchall()} def generate(self, channel: str, nick: str) -> str | None: user_id = self.get_user_id(channel, nick) if not user_id: return None words: List[str] = [] cursor = self.db.cursor() cursor.execute( """ SELECT value FROM chain WHERE user = ? ORDER BY RANDOM() LIMIT 1 """, (user_id,), ) node = cursor.fetchone()[0].split(" ") words += node next: str = self.choose_next(channel, nick, " ".join(node)) while next: words += [next] node = [*node[1:], next] next = self.choose_next(channel, nick, " ".join(node)) return " ".join(words) def choose_next(self, channel: str, nick: str, head: str) -> str: choices = self.get(channel, nick, head) words = list(choices.keys()) weights = list(choices.values()) if not words: return "" return random.choices(words, weights)[0] class Markov(Plugin): def __init__(self, *args, **kwargs): super(Markov, self).__init__(*args, **kwargs) self.order = int(self.plugin_config.get("order", 1)) self.data_path = Path( self.plugin_config.get("data_path", "data/markov/markov.db") ) self.sql_path = Path(self.plugin_config.get("sql_path", "data/markov/db.sql")) self.reply_chance = float(self.plugin_config.get("reply_chance", 0.01)) self.chain = Chain(self.order, self.reply_chance, self.data_path, self.sql_path) async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str): line = line.strip() parts = line.split() if not parts or who.nick == self.nick: return elif parts[0] == "!markov": self.handle_command(conn, channel, who, parts) elif line[0] != "!": # ignore other commands self.add(channel, who.nick, line) # also, maybe generate a sentence chosen = random.random() reply_chance = self.chain.get_reply_chance(channel, who.nick) if chosen <= reply_chance: message = self.chain.generate(channel, who.nick) if message: self.send_to(conn, channel, f"{who.nick}: {message}") def add(self, channel: str, who: str, line: str, commit=True): if who == self.server_config.nick: return self.chain.add(channel, who, line, commit) self.chain.add(channel, ALLCHAIN, line, commit) def handle_command( self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str] ): # handle markov commands match parts[1:]: case ["force"]: message = self.chain.generate(channel, who.nick) if message: self.send_to(conn, channel, f"{who.nick}: {message}") case ["force", nick] | ["emulate", nick]: if not self.chain.get_user_id(channel, nick): return message = self.chain.generate(channel, nick) if message: self.send_to(conn, channel, f"{who.nick}: {message}") case ["all"]: message = self.chain.generate(channel, ALLCHAIN) if message: self.send_to(conn, channel, f"{who.nick}: {message}") case ["chance"]: chance = self.chain.get_reply_chance(channel, who.nick) self.send_to( conn, channel, f"{who.nick}: current reply chance is {chance}", ) case ["chance", chance]: try: reply_chance = float(chance) except ValueError: log.error("Couldn't parse %r as a float", chance) return if not math.isnan(reply_chance): reply_chance = min(max(float(reply_chance), 0.0), self.reply_chance) self.chain.set_reply_chance(channel, who.nick, reply_chance) self.send_to( conn, channel, f"{who.nick}: reply chance set to {self.chain.get_reply_chance(channel, who.nick)}", ) case _: # command not recognized pass async def save(self): self.chain.commit() async def on_unload(self, conn: IrcProtocol): await self.save() PLUGIN_TYPE = Markov