markov: Trying out a sqlite3-based model
This is a lot simpler from a concurrency perspective. Training values can get committed to the database immediately, rather than in long-running flat file batches. Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
This commit is contained in:
@@ -6,6 +6,7 @@ import logging
|
|||||||
import math
|
import math
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import random
|
import random
|
||||||
|
import sqlite3
|
||||||
from typing import Any, List, Mapping, Sequence
|
from typing import Any, List, Mapping, Sequence
|
||||||
|
|
||||||
from asyncirc.protocol import IrcProtocol
|
from asyncirc.protocol import IrcProtocol
|
||||||
@@ -34,6 +35,139 @@ def windows(items: Sequence[Any], size: int):
|
|||||||
yield items[i : i + size]
|
yield items[i : i + size]
|
||||||
|
|
||||||
|
|
||||||
|
class DbChain:
|
||||||
|
def __init__(self, order: int, path: Path):
|
||||||
|
self.order = order
|
||||||
|
self.path = path
|
||||||
|
self.db = sqlite3.connect(self.path)
|
||||||
|
|
||||||
|
def commit(self):
|
||||||
|
self.db.commit()
|
||||||
|
|
||||||
|
def execute(self, *args, **kwargs):
|
||||||
|
cursor = self.db.cursor()
|
||||||
|
cursor.execute(*args, **kwargs)
|
||||||
|
return list(cursor.fetchall())
|
||||||
|
|
||||||
|
def get_user_id(self, channel: str, nick: str) -> int | None:
|
||||||
|
if result := self.execute(
|
||||||
|
"SELECT id FROM user WHERE channel = ? AND nick = ?", (channel, nick)
|
||||||
|
):
|
||||||
|
return result[0][0]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def ensure_user(self, channel: str, nick: str):
|
||||||
|
if self.get_user_id(channel, nick):
|
||||||
|
return
|
||||||
|
self.execute("INSERT INTO user (channel, nick) VALUES (?, ?)", (channel, nick))
|
||||||
|
|
||||||
|
def ensure_key(self, channel: str, nick: str, key: str, next: str):
|
||||||
|
assert next is not None
|
||||||
|
|
||||||
|
self.ensure_user(channel, nick)
|
||||||
|
if next in self.get(channel, nick, key):
|
||||||
|
return
|
||||||
|
self.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO chain (user, value, weight, next)
|
||||||
|
VALUES (
|
||||||
|
(SELECT id FROM user WHERE channel = ? AND nick = ?),
|
||||||
|
?, 0, ?
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
(channel, nick, key, next),
|
||||||
|
)
|
||||||
|
|
||||||
|
def add(self, channel: str, nick: str, text: str, commit=True):
|
||||||
|
parts: List[Any] = text.strip().split()
|
||||||
|
if not parts:
|
||||||
|
return
|
||||||
|
for fragment in windows(parts + [""], self.order + 1):
|
||||||
|
head = fragment[0:-1]
|
||||||
|
tail = fragment[-1]
|
||||||
|
key = " ".join(head)
|
||||||
|
self.update_chain(channel, nick, key, tail)
|
||||||
|
if commit:
|
||||||
|
self.commit()
|
||||||
|
|
||||||
|
def update_chain(
|
||||||
|
self,
|
||||||
|
channel: str,
|
||||||
|
nick: str,
|
||||||
|
key: str,
|
||||||
|
next: str,
|
||||||
|
weight: int = 1,
|
||||||
|
):
|
||||||
|
self.ensure_key(channel, nick, key, next)
|
||||||
|
# Get if the key exists
|
||||||
|
self.execute(
|
||||||
|
"""
|
||||||
|
UPDATE chain
|
||||||
|
SET weight = weight + :weight
|
||||||
|
WHERE user = (SELECT id FROM user WHERE channel = :channel AND nick = :nick)
|
||||||
|
AND value = :key
|
||||||
|
AND next = :next
|
||||||
|
""",
|
||||||
|
{
|
||||||
|
"channel": channel,
|
||||||
|
"nick": nick,
|
||||||
|
"key": key,
|
||||||
|
"next": next,
|
||||||
|
"weight": weight,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def get(self, channel: str, nick: str, key: str) -> dict[str, int]:
|
||||||
|
cursor = self.db.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT next, weight
|
||||||
|
FROM chain
|
||||||
|
WHERE
|
||||||
|
user = (SELECT id FROM user WHERE channel = ? AND nick = ?)
|
||||||
|
AND value = ?
|
||||||
|
""",
|
||||||
|
(channel, nick, key),
|
||||||
|
)
|
||||||
|
return {next: weight for next, weight in cursor.fetchall()}
|
||||||
|
|
||||||
|
def generate(self, channel: str, nick: str) -> str | None:
|
||||||
|
user_id = self.get_user_id(channel, nick)
|
||||||
|
if not user_id:
|
||||||
|
return None
|
||||||
|
words: List[str] = []
|
||||||
|
|
||||||
|
cursor = self.db.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT value
|
||||||
|
FROM chain
|
||||||
|
WHERE user = ?
|
||||||
|
ORDER BY RANDOM()
|
||||||
|
LIMIT 1
|
||||||
|
""",
|
||||||
|
(user_id,),
|
||||||
|
)
|
||||||
|
node = cursor.fetchone()[0].split(" ")
|
||||||
|
words += node
|
||||||
|
|
||||||
|
next: str = self.choose_next(channel, nick, " ".join(node))
|
||||||
|
while next:
|
||||||
|
words += [next]
|
||||||
|
node = [*node[1:], next]
|
||||||
|
next = self.choose_next(channel, nick, " ".join(node))
|
||||||
|
return " ".join(words)
|
||||||
|
|
||||||
|
def choose_next(self, channel: str, nick: str, head: str) -> str:
|
||||||
|
choices = self.get(channel, nick, head)
|
||||||
|
words = list(choices.keys())
|
||||||
|
weights = list(choices.values())
|
||||||
|
if not words:
|
||||||
|
return ""
|
||||||
|
return random.choices(words, weights)[0]
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class Chain:
|
class Chain:
|
||||||
def __init__(self, order: int, chance: float, path: Path):
|
def __init__(self, order: int, chance: float, path: Path):
|
||||||
|
|||||||
Reference in New Issue
Block a user