From 086ba7706e4f7814817a10f29bac3747acafa2d3 Mon Sep 17 00:00:00 2001 From: Alek Ratzloff Date: Thu, 23 Jun 2022 12:00:33 -0700 Subject: [PATCH] markov: Trying out a sqlite3-based model This is a lot simpler from a concurrency perspective. Training values can get committed to the database immediately, rather than in long-running flat file batches. Signed-off-by: Alek Ratzloff --- plugins/markov.py | 134 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) diff --git a/plugins/markov.py b/plugins/markov.py index 8a84bfc..f24ca7b 100644 --- a/plugins/markov.py +++ b/plugins/markov.py @@ -6,6 +6,7 @@ import logging import math from pathlib import Path import random +import sqlite3 from typing import Any, List, Mapping, Sequence from asyncirc.protocol import IrcProtocol @@ -34,6 +35,139 @@ def windows(items: Sequence[Any], size: int): yield items[i : i + size] +class DbChain: + def __init__(self, order: int, path: Path): + self.order = order + self.path = path + self.db = sqlite3.connect(self.path) + + def commit(self): + self.db.commit() + + def execute(self, *args, **kwargs): + cursor = self.db.cursor() + cursor.execute(*args, **kwargs) + return list(cursor.fetchall()) + + def get_user_id(self, channel: str, nick: str) -> int | None: + if result := self.execute( + "SELECT id FROM user WHERE channel = ? AND nick = ?", (channel, nick) + ): + return result[0][0] + else: + return None + + def ensure_user(self, channel: str, nick: str): + if self.get_user_id(channel, nick): + return + self.execute("INSERT INTO user (channel, nick) VALUES (?, ?)", (channel, nick)) + + def ensure_key(self, channel: str, nick: str, key: str, next: str): + assert next is not None + + self.ensure_user(channel, nick) + if next in self.get(channel, nick, key): + return + self.execute( + """ + INSERT INTO chain (user, value, weight, next) + VALUES ( + (SELECT id FROM user WHERE channel = ? AND nick = ?), + ?, 0, ? + ) + """, + (channel, nick, key, next), + ) + + def add(self, channel: str, nick: str, text: str, commit=True): + parts: List[Any] = text.strip().split() + if not parts: + return + for fragment in windows(parts + [""], self.order + 1): + head = fragment[0:-1] + tail = fragment[-1] + key = " ".join(head) + self.update_chain(channel, nick, key, tail) + if commit: + self.commit() + + def update_chain( + self, + channel: str, + nick: str, + key: str, + next: str, + weight: int = 1, + ): + self.ensure_key(channel, nick, key, next) + # Get if the key exists + self.execute( + """ + UPDATE chain + SET weight = weight + :weight + WHERE user = (SELECT id FROM user WHERE channel = :channel AND nick = :nick) + AND value = :key + AND next = :next + """, + { + "channel": channel, + "nick": nick, + "key": key, + "next": next, + "weight": weight, + }, + ) + + def get(self, channel: str, nick: str, key: str) -> dict[str, int]: + cursor = self.db.cursor() + cursor.execute( + """ + SELECT next, weight + FROM chain + WHERE + user = (SELECT id FROM user WHERE channel = ? AND nick = ?) + AND value = ? + """, + (channel, nick, key), + ) + return {next: weight for next, weight in cursor.fetchall()} + + def generate(self, channel: str, nick: str) -> str | None: + user_id = self.get_user_id(channel, nick) + if not user_id: + return None + words: List[str] = [] + + cursor = self.db.cursor() + cursor.execute( + """ + SELECT value + FROM chain + WHERE user = ? + ORDER BY RANDOM() + LIMIT 1 + """, + (user_id,), + ) + node = cursor.fetchone()[0].split(" ") + words += node + + next: str = self.choose_next(channel, nick, " ".join(node)) + while next: + words += [next] + node = [*node[1:], next] + next = self.choose_next(channel, nick, " ".join(node)) + return " ".join(words) + + def choose_next(self, channel: str, nick: str, head: str) -> str: + choices = self.get(channel, nick, head) + words = list(choices.keys()) + weights = list(choices.values()) + if not words: + return "" + return random.choices(words, weights)[0] + + @dataclasses.dataclass class Chain: def __init__(self, order: int, chance: float, path: Path):