markov: Trying out a sqlite3-based model
This is a lot simpler from a concurrency perspective. Training values can get committed to the database immediately, rather than in long-running flat file batches. Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
This commit is contained in:
@@ -6,6 +6,7 @@ import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
import random
|
||||
import sqlite3
|
||||
from typing import Any, List, Mapping, Sequence
|
||||
|
||||
from asyncirc.protocol import IrcProtocol
|
||||
@@ -34,6 +35,139 @@ def windows(items: Sequence[Any], size: int):
|
||||
yield items[i : i + size]
|
||||
|
||||
|
||||
class DbChain:
|
||||
def __init__(self, order: int, path: Path):
|
||||
self.order = order
|
||||
self.path = path
|
||||
self.db = sqlite3.connect(self.path)
|
||||
|
||||
def commit(self):
|
||||
self.db.commit()
|
||||
|
||||
def execute(self, *args, **kwargs):
|
||||
cursor = self.db.cursor()
|
||||
cursor.execute(*args, **kwargs)
|
||||
return list(cursor.fetchall())
|
||||
|
||||
def get_user_id(self, channel: str, nick: str) -> int | None:
|
||||
if result := self.execute(
|
||||
"SELECT id FROM user WHERE channel = ? AND nick = ?", (channel, nick)
|
||||
):
|
||||
return result[0][0]
|
||||
else:
|
||||
return None
|
||||
|
||||
def ensure_user(self, channel: str, nick: str):
|
||||
if self.get_user_id(channel, nick):
|
||||
return
|
||||
self.execute("INSERT INTO user (channel, nick) VALUES (?, ?)", (channel, nick))
|
||||
|
||||
def ensure_key(self, channel: str, nick: str, key: str, next: str):
|
||||
assert next is not None
|
||||
|
||||
self.ensure_user(channel, nick)
|
||||
if next in self.get(channel, nick, key):
|
||||
return
|
||||
self.execute(
|
||||
"""
|
||||
INSERT INTO chain (user, value, weight, next)
|
||||
VALUES (
|
||||
(SELECT id FROM user WHERE channel = ? AND nick = ?),
|
||||
?, 0, ?
|
||||
)
|
||||
""",
|
||||
(channel, nick, key, next),
|
||||
)
|
||||
|
||||
def add(self, channel: str, nick: str, text: str, commit=True):
|
||||
parts: List[Any] = text.strip().split()
|
||||
if not parts:
|
||||
return
|
||||
for fragment in windows(parts + [""], self.order + 1):
|
||||
head = fragment[0:-1]
|
||||
tail = fragment[-1]
|
||||
key = " ".join(head)
|
||||
self.update_chain(channel, nick, key, tail)
|
||||
if commit:
|
||||
self.commit()
|
||||
|
||||
def update_chain(
|
||||
self,
|
||||
channel: str,
|
||||
nick: str,
|
||||
key: str,
|
||||
next: str,
|
||||
weight: int = 1,
|
||||
):
|
||||
self.ensure_key(channel, nick, key, next)
|
||||
# Get if the key exists
|
||||
self.execute(
|
||||
"""
|
||||
UPDATE chain
|
||||
SET weight = weight + :weight
|
||||
WHERE user = (SELECT id FROM user WHERE channel = :channel AND nick = :nick)
|
||||
AND value = :key
|
||||
AND next = :next
|
||||
""",
|
||||
{
|
||||
"channel": channel,
|
||||
"nick": nick,
|
||||
"key": key,
|
||||
"next": next,
|
||||
"weight": weight,
|
||||
},
|
||||
)
|
||||
|
||||
def get(self, channel: str, nick: str, key: str) -> dict[str, int]:
|
||||
cursor = self.db.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT next, weight
|
||||
FROM chain
|
||||
WHERE
|
||||
user = (SELECT id FROM user WHERE channel = ? AND nick = ?)
|
||||
AND value = ?
|
||||
""",
|
||||
(channel, nick, key),
|
||||
)
|
||||
return {next: weight for next, weight in cursor.fetchall()}
|
||||
|
||||
def generate(self, channel: str, nick: str) -> str | None:
|
||||
user_id = self.get_user_id(channel, nick)
|
||||
if not user_id:
|
||||
return None
|
||||
words: List[str] = []
|
||||
|
||||
cursor = self.db.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT value
|
||||
FROM chain
|
||||
WHERE user = ?
|
||||
ORDER BY RANDOM()
|
||||
LIMIT 1
|
||||
""",
|
||||
(user_id,),
|
||||
)
|
||||
node = cursor.fetchone()[0].split(" ")
|
||||
words += node
|
||||
|
||||
next: str = self.choose_next(channel, nick, " ".join(node))
|
||||
while next:
|
||||
words += [next]
|
||||
node = [*node[1:], next]
|
||||
next = self.choose_next(channel, nick, " ".join(node))
|
||||
return " ".join(words)
|
||||
|
||||
def choose_next(self, channel: str, nick: str, head: str) -> str:
|
||||
choices = self.get(channel, nick, head)
|
||||
words = list(choices.keys())
|
||||
weights = list(choices.values())
|
||||
if not words:
|
||||
return ""
|
||||
return random.choices(words, weights)[0]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Chain:
|
||||
def __init__(self, order: int, chance: float, path: Path):
|
||||
|
||||
Reference in New Issue
Block a user