Files
omnibot22/plugins/markov.py
Alek Ratzloff 8c4bb5ac60 markov: Finish up chain database implementation
Markov now uses a sqlite3 database instead of flat JSON files. This
should significantly speed up saving time, plus reduce the amount of RAM
that it uses. Saving and loading large JSON files was very slow and
caused issues with other plugins, especially when messages were
received. Additionally, in order to save RAM, a cache was used and
periodically flushed when not used, adding some complications to the
implementation. This has all been removed since things get committed on
the fly with the database implementation. The main trade-off we have to
make is the disk space used by the database. This is OK though, because
disk space is cheap while RAM is not.

Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
2022-06-23 12:29:19 -07:00

283 lines
9.0 KiB
Python

from collections import defaultdict
import logging
import math
from pathlib import Path
import random
import sqlite3
from typing import Any, List, Sequence
from asyncirc.protocol import IrcProtocol
from irclib.parser import Prefix
from omnibot.plugin import Plugin
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Pseudo-nick under which Markov.add stores every learned line in addition to
# the speaker's own chain; the "!markov all" command generates from it.
ALLCHAIN = "ALL!CHAIN"
def chain_inner_default() -> defaultdict[str | None, int]:
return defaultdict(int)
def chain_default() -> defaultdict[str, defaultdict[str | None, int]]:
    """Build an empty chain table.

    Each missing key is filled on first access with a fresh table produced
    by ``chain_inner_default``.
    """
    table: defaultdict[str, defaultdict[str | None, int]] = defaultdict(
        chain_inner_default
    )
    return table
def windows(items: Sequence[Any], size: int):
    """Yield every contiguous slice of ``items`` that is ``size`` long.

    When ``items`` is shorter than ``size`` the whole sequence is yielded
    once, unchanged, instead of yielding nothing.
    """
    count = len(items)
    if count >= size:
        for start in range(count - size + 1):
            yield items[start : start + size]
        return
    # Too short for even one full window: hand back what there is.
    yield items
class Chain:
def __init__(self, order: int, reply_chance: float, path: Path, sql_path: Path):
self.order = order
self.reply_chance = reply_chance
self.path = path
self.db = sqlite3.connect(self.path)
# Run the initial database creation script
cursor = self.db.cursor()
with open(sql_path) as fp:
cursor.executescript(fp.read())
cursor.close()
self.db.commit()
def commit(self):
self.db.commit()
def execute(self, *args, **kwargs):
cursor = self.db.cursor()
cursor.execute(*args, **kwargs)
return list(cursor.fetchall())
def get_user_id(self, channel: str, nick: str) -> int | None:
if result := self.execute(
"SELECT id FROM user WHERE channel = ? AND nick = ?", (channel, nick)
):
return result[0][0]
else:
return None
def get_reply_chance(self, channel: str, nick: str) -> float:
if result := self.execute(
"SELECT reply_chance FROM user WHERE channel = ? AND nick = ?",
(channel, nick),
):
return result[0][0]
else:
return 0.0
def set_reply_chance(self, channel: str, nick: str, chance: float):
self.ensure_user(channel, nick)
self.execute(
"UPDATE user SET reply_chance = ? WHERE channel = ? AND nick = ?",
(chance, channel, nick),
)
def ensure_user(self, channel: str, nick: str):
if self.get_user_id(channel, nick):
return
self.execute(
"INSERT INTO user (channel, nick, reply_chance) VALUES (?, ?, ?)",
(channel, nick, self.reply_chance),
)
def ensure_key(self, channel: str, nick: str, key: str, next: str):
assert next is not None
self.ensure_user(channel, nick)
if next in self.get(channel, nick, key):
return
self.execute(
"""
INSERT INTO chain (user, value, weight, next)
VALUES (
(SELECT id FROM user WHERE channel = ? AND nick = ?),
?, 0, ?
)
""",
(channel, nick, key, next),
)
def add(self, channel: str, nick: str, text: str, commit=True):
parts: List[Any] = text.strip().split()
if not parts:
return
for fragment in windows(parts + [""], self.order + 1):
head = fragment[0:-1]
tail = fragment[-1]
key = " ".join(head)
self.update_chain(channel, nick, key, tail)
if commit:
self.commit()
def update_chain(
self,
channel: str,
nick: str,
key: str,
next: str,
weight: int = 1,
):
self.ensure_key(channel, nick, key, next)
# Get if the key exists
self.execute(
"""
UPDATE chain
SET weight = weight + :weight
WHERE user = (SELECT id FROM user WHERE channel = :channel AND nick = :nick)
AND value = :key
AND next = :next
""",
{
"channel": channel,
"nick": nick,
"key": key,
"next": next,
"weight": weight,
},
)
def get(self, channel: str, nick: str, key: str) -> dict[str, int]:
cursor = self.db.cursor()
cursor.execute(
"""
SELECT next, weight
FROM chain
WHERE
user = (SELECT id FROM user WHERE channel = ? AND nick = ?)
AND value = ?
""",
(channel, nick, key),
)
return {next: weight for next, weight in cursor.fetchall()}
def generate(self, channel: str, nick: str) -> str | None:
user_id = self.get_user_id(channel, nick)
if not user_id:
return None
words: List[str] = []
cursor = self.db.cursor()
cursor.execute(
"""
SELECT value
FROM chain
WHERE user = ?
ORDER BY RANDOM()
LIMIT 1
""",
(user_id,),
)
node = cursor.fetchone()[0].split(" ")
words += node
next: str = self.choose_next(channel, nick, " ".join(node))
while next:
words += [next]
node = [*node[1:], next]
next = self.choose_next(channel, nick, " ".join(node))
return " ".join(words)
def choose_next(self, channel: str, nick: str, head: str) -> str:
choices = self.get(channel, nick, head)
words = list(choices.keys())
weights = list(choices.values())
if not words:
return ""
return random.choices(words, weights)[0]
class Markov(Plugin):
def __init__(self, *args, **kwargs):
super(Markov, self).__init__(*args, **kwargs)
self.order = int(self.plugin_config.get("order", 1))
self.data_path = Path(
self.plugin_config.get("data_path", "data/markov/markov.db")
)
self.sql_path = Path(self.plugin_config.get("sql_path", "data/markov/db.sql"))
self.reply_chance = float(self.plugin_config.get("reply_chance", 0.01))
self.chain = Chain(self.order, self.reply_chance, self.data_path, self.sql_path)
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
line = line.strip()
parts = line.split()
if not parts or who.nick == self.nick:
return
elif parts[0] == "!markov":
self.handle_command(conn, channel, who, parts)
elif line[0] != "!":
# ignore other commands
self.add(channel, who.nick, line)
# also, maybe generate a sentence
chosen = random.random()
reply_chance = self.chain.get_reply_chance(channel, who.nick)
if chosen <= reply_chance:
message = self.chain.generate(channel, who.nick)
if message:
self.send_to(conn, channel, f"{who.nick}: {message}")
def add(self, channel: str, who: str, line: str, commit=True):
if who == self.server_config.nick:
return
self.chain.add(channel, who, line, commit)
self.chain.add(channel, ALLCHAIN, line, commit)
def handle_command(
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
):
# handle markov commands
match parts[1:]:
case ["force"]:
message = self.chain.generate(channel, who.nick)
if message:
self.send_to(conn, channel, f"{who.nick}: {message}")
case ["force", nick] | ["emulate", nick]:
if not self.chain.get_user_id(channel, nick):
return
message = self.chain.generate(channel, nick)
if message:
self.send_to(conn, channel, f"{who.nick}: {message}")
case ["all"]:
message = self.chain.generate(channel, ALLCHAIN)
if message:
self.send_to(conn, channel, f"{who.nick}: {message}")
case ["chance"]:
chance = self.chain.get_reply_chance(channel, who.nick)
self.send_to(
conn,
channel,
f"{who.nick}: current reply chance is {chance}",
)
case ["chance", chance]:
try:
reply_chance = float(chance)
except ValueError:
log.error("Couldn't parse %r as a float", chance)
return
if not math.isnan(reply_chance):
reply_chance = min(max(float(reply_chance), 0.0), self.reply_chance)
self.chain.set_reply_chance(channel, who.nick, reply_chance)
self.send_to(
conn,
channel,
f"{who.nick}: reply chance set to {self.chain.get_reply_chance(channel, who.nick)}",
)
case _:
# command not recognized
pass
async def save(self):
self.chain.commit()
async def on_unload(self, conn: IrcProtocol):
await self.save()
# Exposes the plugin class under a fixed module-level name; presumably the
# omnibot plugin loader looks this up -- TODO confirm against the loader.
PLUGIN_TYPE = Markov