From 8e639d50fa7f4091e607c0861c0ba8d8149db69d Mon Sep 17 00:00:00 2001 From: Alek Ratzloff Date: Thu, 23 Jun 2022 14:29:02 -0700 Subject: [PATCH] markov: Remove allchain The allchain has been a source of headaches because it takes up a lot of memory and slows everything down. However, with the new database model, we can generate markov sentences using all of the rows since they are a flat collection. This helps reduce disk space and increases the import speed significantly. Signed-off-by: Alek Ratzloff --- plugins/markov.py | 48 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/plugins/markov.py b/plugins/markov.py index 69a2b89..4aa7c80 100644 --- a/plugins/markov.py +++ b/plugins/markov.py @@ -4,7 +4,7 @@ import math from pathlib import Path import random import sqlite3 -from typing import Any, List, Sequence +from typing import Any, DefaultDict, List, Sequence from asyncirc.protocol import IrcProtocol from irclib.parser import Prefix @@ -13,7 +13,6 @@ from omnibot.plugin import Plugin log = logging.getLogger(__name__) -ALLCHAIN = "ALL!CHAIN" def chain_inner_default() -> defaultdict[str | None, int]: @@ -191,6 +190,48 @@ class Chain: return "" return random.choices(words, weights)[0] + def generate_all(self) -> str | None: + words: List[str] = [] + cursor = self.db.cursor() + cursor.execute( + """ + SELECT value + FROM chain + ORDER BY RANDOM() + LIMIT 1 + """, + ) + node = cursor.fetchone()[0].split(" ") + words += node + next: str = self.all_choose_next(" ".join(node)) + while next: + words += [next] + node = [*node[1:], next] + next = self.all_choose_next(" ".join(node)) + return " ".join(words) + + def all_choose_next(self, head: str) -> str: + cursor = self.db.cursor() + cursor.execute( + """ + SELECT next, weight + FROM chain + WHERE value = ? + ORDER BY RANDOM() + LIMIT 1 + """, + (head,), + ) + choices: DefaultDict[str, int] = defaultdict(lambda: 0) + # Collapse all choices by weight + for next, weight in cursor.fetchall(): + choices[next] += weight + if not choices: + return "" + words = list(choices.keys()) + weights = list(choices.values()) + return random.choices(words, weights)[0] + class Markov(Plugin): def __init__(self, *args, **kwargs): @@ -226,7 +267,6 @@ class Markov(Plugin): if who == self.server_config.nick: return self.chain.add(channel, who, line, commit) - self.chain.add(channel, ALLCHAIN, line, commit) def handle_command( self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str] @@ -244,7 +284,7 @@ class Markov(Plugin): if message: self.send_to(conn, channel, f"{who.nick}: {message}") case ["all"]: - message = self.chain.generate(channel, ALLCHAIN) + message = self.chain.generate_all() if message: self.send_to(conn, channel, f"{who.nick}: {message}") case ["chance"]: