markov: Trying out a sqlite3-based model

This is a lot simpler from a concurrency perspective. Training values
can get committed to the database immediately, rather than in
long-running flat file batches.

Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
This commit is contained in:
2022-06-23 12:00:33 -07:00
parent 737e032783
commit 086ba7706e

View File

@@ -6,6 +6,7 @@ import logging
import math import math
from pathlib import Path from pathlib import Path
import random import random
import sqlite3
from typing import Any, List, Mapping, Sequence from typing import Any, List, Mapping, Sequence
from asyncirc.protocol import IrcProtocol from asyncirc.protocol import IrcProtocol
@@ -34,6 +35,139 @@ def windows(items: Sequence[Any], size: int):
yield items[i : i + size] yield items[i : i + size]
class DbChain:
def __init__(self, order: int, path: Path):
self.order = order
self.path = path
self.db = sqlite3.connect(self.path)
def commit(self):
self.db.commit()
def execute(self, *args, **kwargs):
cursor = self.db.cursor()
cursor.execute(*args, **kwargs)
return list(cursor.fetchall())
def get_user_id(self, channel: str, nick: str) -> int | None:
if result := self.execute(
"SELECT id FROM user WHERE channel = ? AND nick = ?", (channel, nick)
):
return result[0][0]
else:
return None
def ensure_user(self, channel: str, nick: str):
if self.get_user_id(channel, nick):
return
self.execute("INSERT INTO user (channel, nick) VALUES (?, ?)", (channel, nick))
def ensure_key(self, channel: str, nick: str, key: str, next: str):
assert next is not None
self.ensure_user(channel, nick)
if next in self.get(channel, nick, key):
return
self.execute(
"""
INSERT INTO chain (user, value, weight, next)
VALUES (
(SELECT id FROM user WHERE channel = ? AND nick = ?),
?, 0, ?
)
""",
(channel, nick, key, next),
)
def add(self, channel: str, nick: str, text: str, commit=True):
parts: List[Any] = text.strip().split()
if not parts:
return
for fragment in windows(parts + [""], self.order + 1):
head = fragment[0:-1]
tail = fragment[-1]
key = " ".join(head)
self.update_chain(channel, nick, key, tail)
if commit:
self.commit()
def update_chain(
self,
channel: str,
nick: str,
key: str,
next: str,
weight: int = 1,
):
self.ensure_key(channel, nick, key, next)
# Get if the key exists
self.execute(
"""
UPDATE chain
SET weight = weight + :weight
WHERE user = (SELECT id FROM user WHERE channel = :channel AND nick = :nick)
AND value = :key
AND next = :next
""",
{
"channel": channel,
"nick": nick,
"key": key,
"next": next,
"weight": weight,
},
)
def get(self, channel: str, nick: str, key: str) -> dict[str, int]:
cursor = self.db.cursor()
cursor.execute(
"""
SELECT next, weight
FROM chain
WHERE
user = (SELECT id FROM user WHERE channel = ? AND nick = ?)
AND value = ?
""",
(channel, nick, key),
)
return {next: weight for next, weight in cursor.fetchall()}
def generate(self, channel: str, nick: str) -> str | None:
user_id = self.get_user_id(channel, nick)
if not user_id:
return None
words: List[str] = []
cursor = self.db.cursor()
cursor.execute(
"""
SELECT value
FROM chain
WHERE user = ?
ORDER BY RANDOM()
LIMIT 1
""",
(user_id,),
)
node = cursor.fetchone()[0].split(" ")
words += node
next: str = self.choose_next(channel, nick, " ".join(node))
while next:
words += [next]
node = [*node[1:], next]
next = self.choose_next(channel, nick, " ".join(node))
return " ".join(words)
def choose_next(self, channel: str, nick: str, head: str) -> str:
choices = self.get(channel, nick, head)
words = list(choices.keys())
weights = list(choices.values())
if not words:
return ""
return random.choices(words, weights)[0]
@dataclasses.dataclass @dataclasses.dataclass
class Chain: class Chain:
def __init__(self, order: int, chance: float, path: Path): def __init__(self, order: int, chance: float, path: Path):