Markov now uses a SQLite (sqlite3) database instead of flat JSON files. This should significantly speed up saving and reduce the amount of RAM that the plugin uses. Saving and loading large JSON files was very slow and caused issues with other plugins, especially when messages were received. Additionally, in order to save RAM, a cache was used and periodically flushed when not in use, which added some complications to the implementation. All of this has been removed, since changes are now committed on the fly by the database implementation. The main trade-off we have to make is the disk space used by the database. This is acceptable, though, because disk space is cheap while RAM is not.

Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
283 lines
9.0 KiB
Python
283 lines
9.0 KiB
Python
from collections import defaultdict
|
|
import logging
|
|
import math
|
|
from pathlib import Path
|
|
import random
|
|
import sqlite3
|
|
from typing import Any, List, Sequence
|
|
|
|
from asyncirc.protocol import IrcProtocol
|
|
from irclib.parser import Prefix
|
|
|
|
from omnibot.plugin import Plugin
|
|
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
|
# Pseudo-nick for the combined chain: Markov.add() records every learned line
# under this name in addition to the speaker's own nick, and "!markov all"
# generates from it.
ALLCHAIN = "ALL!CHAIN"
|
|
|
|
|
|
def chain_inner_default() -> defaultdict[str | None, int]:
|
|
return defaultdict(int)
|
|
|
|
|
|
def chain_default() -> defaultdict[str, defaultdict[str | None, int]]:
    """Build an empty two-level mapping.

    Each missing outer key is filled in with a fresh mapping produced by
    ``chain_inner_default``.
    """
    table: defaultdict[str, defaultdict[str | None, int]] = defaultdict(
        chain_inner_default
    )
    return table
|
|
|
|
|
|
def windows(items: Sequence[Any], size: int):
    """Yield every run of ``size`` consecutive elements of ``items``.

    When ``items`` holds fewer than ``size`` elements, the whole sequence is
    yielded once, unchanged, instead of yielding nothing.
    """
    count = len(items)
    if count < size:
        # Too short for even one full window: hand back what we have.
        yield items
        return

    start = 0
    while start + size <= count:
        yield items[start : start + size]
        start += 1
|
|
|
|
|
|
class Chain:
|
|
def __init__(self, order: int, reply_chance: float, path: Path, sql_path: Path):
|
|
self.order = order
|
|
self.reply_chance = reply_chance
|
|
self.path = path
|
|
self.db = sqlite3.connect(self.path)
|
|
|
|
# Run the initial database creation script
|
|
cursor = self.db.cursor()
|
|
with open(sql_path) as fp:
|
|
cursor.executescript(fp.read())
|
|
cursor.close()
|
|
self.db.commit()
|
|
|
|
def commit(self):
|
|
self.db.commit()
|
|
|
|
def execute(self, *args, **kwargs):
|
|
cursor = self.db.cursor()
|
|
cursor.execute(*args, **kwargs)
|
|
return list(cursor.fetchall())
|
|
|
|
def get_user_id(self, channel: str, nick: str) -> int | None:
|
|
if result := self.execute(
|
|
"SELECT id FROM user WHERE channel = ? AND nick = ?", (channel, nick)
|
|
):
|
|
return result[0][0]
|
|
else:
|
|
return None
|
|
|
|
def get_reply_chance(self, channel: str, nick: str) -> float:
|
|
if result := self.execute(
|
|
"SELECT reply_chance FROM user WHERE channel = ? AND nick = ?",
|
|
(channel, nick),
|
|
):
|
|
return result[0][0]
|
|
else:
|
|
return 0.0
|
|
|
|
def set_reply_chance(self, channel: str, nick: str, chance: float):
|
|
self.ensure_user(channel, nick)
|
|
self.execute(
|
|
"UPDATE user SET reply_chance = ? WHERE channel = ? AND nick = ?",
|
|
(chance, channel, nick),
|
|
)
|
|
|
|
def ensure_user(self, channel: str, nick: str):
|
|
if self.get_user_id(channel, nick):
|
|
return
|
|
self.execute(
|
|
"INSERT INTO user (channel, nick, reply_chance) VALUES (?, ?, ?)",
|
|
(channel, nick, self.reply_chance),
|
|
)
|
|
|
|
def ensure_key(self, channel: str, nick: str, key: str, next: str):
|
|
assert next is not None
|
|
|
|
self.ensure_user(channel, nick)
|
|
if next in self.get(channel, nick, key):
|
|
return
|
|
self.execute(
|
|
"""
|
|
INSERT INTO chain (user, value, weight, next)
|
|
VALUES (
|
|
(SELECT id FROM user WHERE channel = ? AND nick = ?),
|
|
?, 0, ?
|
|
)
|
|
""",
|
|
(channel, nick, key, next),
|
|
)
|
|
|
|
def add(self, channel: str, nick: str, text: str, commit=True):
|
|
parts: List[Any] = text.strip().split()
|
|
if not parts:
|
|
return
|
|
for fragment in windows(parts + [""], self.order + 1):
|
|
head = fragment[0:-1]
|
|
tail = fragment[-1]
|
|
key = " ".join(head)
|
|
self.update_chain(channel, nick, key, tail)
|
|
if commit:
|
|
self.commit()
|
|
|
|
def update_chain(
|
|
self,
|
|
channel: str,
|
|
nick: str,
|
|
key: str,
|
|
next: str,
|
|
weight: int = 1,
|
|
):
|
|
self.ensure_key(channel, nick, key, next)
|
|
# Get if the key exists
|
|
self.execute(
|
|
"""
|
|
UPDATE chain
|
|
SET weight = weight + :weight
|
|
WHERE user = (SELECT id FROM user WHERE channel = :channel AND nick = :nick)
|
|
AND value = :key
|
|
AND next = :next
|
|
""",
|
|
{
|
|
"channel": channel,
|
|
"nick": nick,
|
|
"key": key,
|
|
"next": next,
|
|
"weight": weight,
|
|
},
|
|
)
|
|
|
|
def get(self, channel: str, nick: str, key: str) -> dict[str, int]:
|
|
cursor = self.db.cursor()
|
|
cursor.execute(
|
|
"""
|
|
SELECT next, weight
|
|
FROM chain
|
|
WHERE
|
|
user = (SELECT id FROM user WHERE channel = ? AND nick = ?)
|
|
AND value = ?
|
|
""",
|
|
(channel, nick, key),
|
|
)
|
|
return {next: weight for next, weight in cursor.fetchall()}
|
|
|
|
def generate(self, channel: str, nick: str) -> str | None:
|
|
user_id = self.get_user_id(channel, nick)
|
|
if not user_id:
|
|
return None
|
|
words: List[str] = []
|
|
|
|
cursor = self.db.cursor()
|
|
cursor.execute(
|
|
"""
|
|
SELECT value
|
|
FROM chain
|
|
WHERE user = ?
|
|
ORDER BY RANDOM()
|
|
LIMIT 1
|
|
""",
|
|
(user_id,),
|
|
)
|
|
node = cursor.fetchone()[0].split(" ")
|
|
words += node
|
|
|
|
next: str = self.choose_next(channel, nick, " ".join(node))
|
|
while next:
|
|
words += [next]
|
|
node = [*node[1:], next]
|
|
next = self.choose_next(channel, nick, " ".join(node))
|
|
return " ".join(words)
|
|
|
|
def choose_next(self, channel: str, nick: str, head: str) -> str:
|
|
choices = self.get(channel, nick, head)
|
|
words = list(choices.keys())
|
|
weights = list(choices.values())
|
|
if not words:
|
|
return ""
|
|
return random.choices(words, weights)[0]
|
|
|
|
|
|
class Markov(Plugin):
|
|
def __init__(self, *args, **kwargs):
|
|
super(Markov, self).__init__(*args, **kwargs)
|
|
|
|
self.order = int(self.plugin_config.get("order", 1))
|
|
self.data_path = Path(
|
|
self.plugin_config.get("data_path", "data/markov/markov.db")
|
|
)
|
|
self.sql_path = Path(self.plugin_config.get("sql_path", "data/markov/db.sql"))
|
|
self.reply_chance = float(self.plugin_config.get("reply_chance", 0.01))
|
|
self.chain = Chain(self.order, self.reply_chance, self.data_path, self.sql_path)
|
|
|
|
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
|
|
line = line.strip()
|
|
parts = line.split()
|
|
if not parts or who.nick == self.nick:
|
|
return
|
|
elif parts[0] == "!markov":
|
|
self.handle_command(conn, channel, who, parts)
|
|
elif line[0] != "!":
|
|
# ignore other commands
|
|
self.add(channel, who.nick, line)
|
|
# also, maybe generate a sentence
|
|
chosen = random.random()
|
|
reply_chance = self.chain.get_reply_chance(channel, who.nick)
|
|
if chosen <= reply_chance:
|
|
message = self.chain.generate(channel, who.nick)
|
|
if message:
|
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
|
|
|
def add(self, channel: str, who: str, line: str, commit=True):
|
|
if who == self.server_config.nick:
|
|
return
|
|
self.chain.add(channel, who, line, commit)
|
|
self.chain.add(channel, ALLCHAIN, line, commit)
|
|
|
|
def handle_command(
|
|
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
|
|
):
|
|
# handle markov commands
|
|
match parts[1:]:
|
|
case ["force"]:
|
|
message = self.chain.generate(channel, who.nick)
|
|
if message:
|
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
|
case ["force", nick] | ["emulate", nick]:
|
|
if not self.chain.get_user_id(channel, nick):
|
|
return
|
|
message = self.chain.generate(channel, nick)
|
|
if message:
|
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
|
case ["all"]:
|
|
message = self.chain.generate(channel, ALLCHAIN)
|
|
if message:
|
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
|
case ["chance"]:
|
|
chance = self.chain.get_reply_chance(channel, who.nick)
|
|
self.send_to(
|
|
conn,
|
|
channel,
|
|
f"{who.nick}: current reply chance is {chance}",
|
|
)
|
|
case ["chance", chance]:
|
|
try:
|
|
reply_chance = float(chance)
|
|
except ValueError:
|
|
log.error("Couldn't parse %r as a float", chance)
|
|
return
|
|
if not math.isnan(reply_chance):
|
|
reply_chance = min(max(float(reply_chance), 0.0), self.reply_chance)
|
|
self.chain.set_reply_chance(channel, who.nick, reply_chance)
|
|
self.send_to(
|
|
conn,
|
|
channel,
|
|
f"{who.nick}: reply chance set to {self.chain.get_reply_chance(channel, who.nick)}",
|
|
)
|
|
case _:
|
|
# command not recognized
|
|
pass
|
|
|
|
async def save(self):
|
|
self.chain.commit()
|
|
|
|
async def on_unload(self, conn: IrcProtocol):
|
|
await self.save()
|
|
|
|
|
|
# Publishes the plugin class under a fixed module-level name; presumably this
# is what the omnibot plugin loader looks up -- TODO confirm against the loader.
PLUGIN_TYPE = Markov
|