markov: Trying out a sqlite3-based model

This is a lot simpler from a concurrency perspective. Training values
can get committed to the database immediately, rather than in
long-running flat file batches.

Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
This commit is contained in:
2022-06-23 12:00:33 -07:00
parent 737e032783
commit 086ba7706e

View File

@@ -6,6 +6,7 @@ import logging
import math
from pathlib import Path
import random
import sqlite3
from typing import Any, List, Mapping, Sequence
from asyncirc.protocol import IrcProtocol
@@ -34,6 +35,139 @@ def windows(items: Sequence[Any], size: int):
yield items[i : i + size]
class DbChain:
def __init__(self, order: int, path: Path):
self.order = order
self.path = path
self.db = sqlite3.connect(self.path)
def commit(self):
self.db.commit()
def execute(self, *args, **kwargs):
cursor = self.db.cursor()
cursor.execute(*args, **kwargs)
return list(cursor.fetchall())
def get_user_id(self, channel: str, nick: str) -> int | None:
if result := self.execute(
"SELECT id FROM user WHERE channel = ? AND nick = ?", (channel, nick)
):
return result[0][0]
else:
return None
def ensure_user(self, channel: str, nick: str):
if self.get_user_id(channel, nick):
return
self.execute("INSERT INTO user (channel, nick) VALUES (?, ?)", (channel, nick))
def ensure_key(self, channel: str, nick: str, key: str, next: str):
assert next is not None
self.ensure_user(channel, nick)
if next in self.get(channel, nick, key):
return
self.execute(
"""
INSERT INTO chain (user, value, weight, next)
VALUES (
(SELECT id FROM user WHERE channel = ? AND nick = ?),
?, 0, ?
)
""",
(channel, nick, key, next),
)
def add(self, channel: str, nick: str, text: str, commit=True):
parts: List[Any] = text.strip().split()
if not parts:
return
for fragment in windows(parts + [""], self.order + 1):
head = fragment[0:-1]
tail = fragment[-1]
key = " ".join(head)
self.update_chain(channel, nick, key, tail)
if commit:
self.commit()
def update_chain(
self,
channel: str,
nick: str,
key: str,
next: str,
weight: int = 1,
):
self.ensure_key(channel, nick, key, next)
# Get if the key exists
self.execute(
"""
UPDATE chain
SET weight = weight + :weight
WHERE user = (SELECT id FROM user WHERE channel = :channel AND nick = :nick)
AND value = :key
AND next = :next
""",
{
"channel": channel,
"nick": nick,
"key": key,
"next": next,
"weight": weight,
},
)
def get(self, channel: str, nick: str, key: str) -> dict[str, int]:
cursor = self.db.cursor()
cursor.execute(
"""
SELECT next, weight
FROM chain
WHERE
user = (SELECT id FROM user WHERE channel = ? AND nick = ?)
AND value = ?
""",
(channel, nick, key),
)
return {next: weight for next, weight in cursor.fetchall()}
def generate(self, channel: str, nick: str) -> str | None:
user_id = self.get_user_id(channel, nick)
if not user_id:
return None
words: List[str] = []
cursor = self.db.cursor()
cursor.execute(
"""
SELECT value
FROM chain
WHERE user = ?
ORDER BY RANDOM()
LIMIT 1
""",
(user_id,),
)
node = cursor.fetchone()[0].split(" ")
words += node
next: str = self.choose_next(channel, nick, " ".join(node))
while next:
words += [next]
node = [*node[1:], next]
next = self.choose_next(channel, nick, " ".join(node))
return " ".join(words)
def choose_next(self, channel: str, nick: str, head: str) -> str:
choices = self.get(channel, nick, head)
words = list(choices.keys())
weights = list(choices.values())
if not words:
return ""
return random.choices(words, weights)[0]
@dataclasses.dataclass
class Chain:
def __init__(self, order: int, chance: float, path: Path):