markov: Finish up chain database implementation
Markov now uses a sqlite3 database instead of flat JSON files. This should significantly speed up saving and reduce the amount of RAM the plugin uses. Saving and loading large JSON files was very slow and caused issues with other plugins, especially when messages were received. Additionally, in order to save RAM, an in-memory cache was used and periodically flushed for chains that had not been accessed recently, which complicated the implementation. All of this has been removed, since changes are committed on the fly with the database implementation. The main trade-off we have to make is the disk space used by the database. This is acceptable, though, because disk space is cheap while RAM is not. Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
This commit is contained in:
@@ -1,13 +1,10 @@
|
|||||||
import asyncio
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import dataclasses
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import random
|
import random
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from typing import Any, List, Mapping, Sequence
|
from typing import Any, List, Sequence
|
||||||
|
|
||||||
from asyncirc.protocol import IrcProtocol
|
from asyncirc.protocol import IrcProtocol
|
||||||
from irclib.parser import Prefix
|
from irclib.parser import Prefix
|
||||||
@@ -35,12 +32,20 @@ def windows(items: Sequence[Any], size: int):
|
|||||||
yield items[i : i + size]
|
yield items[i : i + size]
|
||||||
|
|
||||||
|
|
||||||
class DbChain:
|
class Chain:
|
||||||
def __init__(self, order: int, path: Path):
|
def __init__(self, order: int, reply_chance: float, path: Path, sql_path: Path):
|
||||||
self.order = order
|
self.order = order
|
||||||
|
self.reply_chance = reply_chance
|
||||||
self.path = path
|
self.path = path
|
||||||
self.db = sqlite3.connect(self.path)
|
self.db = sqlite3.connect(self.path)
|
||||||
|
|
||||||
|
# Run the initial database creation script
|
||||||
|
cursor = self.db.cursor()
|
||||||
|
with open(sql_path) as fp:
|
||||||
|
cursor.executescript(fp.read())
|
||||||
|
cursor.close()
|
||||||
|
self.db.commit()
|
||||||
|
|
||||||
def commit(self):
|
def commit(self):
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
|
||||||
@@ -57,10 +62,29 @@ class DbChain:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_reply_chance(self, channel: str, nick: str) -> float:
|
||||||
|
if result := self.execute(
|
||||||
|
"SELECT reply_chance FROM user WHERE channel = ? AND nick = ?",
|
||||||
|
(channel, nick),
|
||||||
|
):
|
||||||
|
return result[0][0]
|
||||||
|
else:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
def set_reply_chance(self, channel: str, nick: str, chance: float):
|
||||||
|
self.ensure_user(channel, nick)
|
||||||
|
self.execute(
|
||||||
|
"UPDATE user SET reply_chance = ? WHERE channel = ? AND nick = ?",
|
||||||
|
(chance, channel, nick),
|
||||||
|
)
|
||||||
|
|
||||||
def ensure_user(self, channel: str, nick: str):
|
def ensure_user(self, channel: str, nick: str):
|
||||||
if self.get_user_id(channel, nick):
|
if self.get_user_id(channel, nick):
|
||||||
return
|
return
|
||||||
self.execute("INSERT INTO user (channel, nick) VALUES (?, ?)", (channel, nick))
|
self.execute(
|
||||||
|
"INSERT INTO user (channel, nick, reply_chance) VALUES (?, ?, ?)",
|
||||||
|
(channel, nick, self.reply_chance),
|
||||||
|
)
|
||||||
|
|
||||||
def ensure_key(self, channel: str, nick: str, key: str, next: str):
|
def ensure_key(self, channel: str, nick: str, key: str, next: str):
|
||||||
assert next is not None
|
assert next is not None
|
||||||
@@ -168,172 +192,22 @@ class DbChain:
|
|||||||
return random.choices(words, weights)[0]
|
return random.choices(words, weights)[0]
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class Chain:
|
|
||||||
def __init__(self, order: int, chance: float, path: Path):
|
|
||||||
self.order = order
|
|
||||||
self.__reply_chance = chance
|
|
||||||
self.path = path
|
|
||||||
self.__cache = chain_default()
|
|
||||||
self.__last_access = 0.0
|
|
||||||
self.__dirty = False
|
|
||||||
|
|
||||||
def __touch(self):
|
|
||||||
self.__last_access = asyncio.get_running_loop().time()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def reply_chance(self) -> float:
|
|
||||||
self.__load()
|
|
||||||
return self.__reply_chance
|
|
||||||
|
|
||||||
@reply_chance.setter
|
|
||||||
def reply_chance(self, val: float):
|
|
||||||
if not (isinstance(val, float) or isinstance(val, int)):
|
|
||||||
return NotImplemented
|
|
||||||
self.__load()
|
|
||||||
self.__reply_chance = val
|
|
||||||
self.__dirty = True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def last_access(self) -> float:
|
|
||||||
return self.__last_access
|
|
||||||
|
|
||||||
def add(self, text: str):
|
|
||||||
parts: List[Any] = text.strip().split()
|
|
||||||
if not parts:
|
|
||||||
return
|
|
||||||
self.__touch()
|
|
||||||
self.__load()
|
|
||||||
self.__dirty = True
|
|
||||||
for fragment in windows(parts + [None], self.order + 1):
|
|
||||||
head = fragment[0:-1]
|
|
||||||
tail = fragment[-1]
|
|
||||||
self.__cache[" ".join(head)][tail] += 1
|
|
||||||
|
|
||||||
def get(self, key: str) -> dict[str | None, int]:
|
|
||||||
if self.__cache:
|
|
||||||
if key in self.__cache:
|
|
||||||
self.__touch()
|
|
||||||
return self.__cache[key]
|
|
||||||
else:
|
|
||||||
raise KeyError(key)
|
|
||||||
else:
|
|
||||||
# Load cache, then return key
|
|
||||||
self.__load()
|
|
||||||
self.__touch()
|
|
||||||
return self.__cache[key]
|
|
||||||
|
|
||||||
def set(self, key: str, value: Mapping[str | None, int]):
|
|
||||||
self.__touch()
|
|
||||||
if not self.__cache:
|
|
||||||
# Attempt the cache before writing to it
|
|
||||||
self.__load()
|
|
||||||
self.__cache[key] = defaultdict(int, value)
|
|
||||||
self.__dirty = True
|
|
||||||
|
|
||||||
def __load(self):
|
|
||||||
self.__touch()
|
|
||||||
if self.__cache:
|
|
||||||
return
|
|
||||||
if not self.path.exists():
|
|
||||||
return
|
|
||||||
with open(self.path) as fp:
|
|
||||||
log.info("Loading markov chain from %s", self.path)
|
|
||||||
obj = json.load(fp)
|
|
||||||
|
|
||||||
# Load the save object
|
|
||||||
self.__reply_chance = obj["reply_chance"]
|
|
||||||
self.__cache = defaultdict(
|
|
||||||
chain_inner_default,
|
|
||||||
{
|
|
||||||
key: defaultdict(
|
|
||||||
int,
|
|
||||||
{
|
|
||||||
(None if not word else word): weight
|
|
||||||
for word, weight in value.items()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
for key, value in obj["chain"].items()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self.__dirty = False
|
|
||||||
|
|
||||||
def save(self):
|
|
||||||
if not self.__cache:
|
|
||||||
return
|
|
||||||
if self.__dirty:
|
|
||||||
log.info("Saving markov chain to %s", self.path)
|
|
||||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
# Build the save object
|
|
||||||
obj = {
|
|
||||||
"reply_chance": self.__reply_chance,
|
|
||||||
"chain": {
|
|
||||||
key: {
|
|
||||||
("" if word is None else word): weight
|
|
||||||
for word, weight in value.items()
|
|
||||||
}
|
|
||||||
for key, value in self.__cache.items()
|
|
||||||
},
|
|
||||||
}
|
|
||||||
with open(self.path, "w") as fp:
|
|
||||||
json.dump(obj, fp)
|
|
||||||
self.__dirty = False
|
|
||||||
else:
|
|
||||||
log.info("Chain %s is not dirty, not saving", self.path)
|
|
||||||
|
|
||||||
def clear_cache(self):
|
|
||||||
self.__cache.clear()
|
|
||||||
self.__dirty = False
|
|
||||||
|
|
||||||
def __bool__(self) -> bool:
|
|
||||||
return self.path.exists() or bool(self.__cache)
|
|
||||||
|
|
||||||
def generate(self) -> str:
|
|
||||||
self.__load()
|
|
||||||
if not self.__cache:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
words: List[str] = []
|
|
||||||
|
|
||||||
node = random.choice(list(self.__cache.keys())).split(" ")
|
|
||||||
words += node
|
|
||||||
next: str | None = self.choose_next(" ".join(node))
|
|
||||||
while next:
|
|
||||||
words += [next]
|
|
||||||
node = [*node[1:], next]
|
|
||||||
next = self.choose_next(" ".join(node))
|
|
||||||
return " ".join(words)
|
|
||||||
|
|
||||||
def choose_next(self, head: str) -> str | None:
|
|
||||||
self.__load()
|
|
||||||
choices = self.__cache[head]
|
|
||||||
words = list(choices.keys())
|
|
||||||
weights = list(choices.values())
|
|
||||||
if not words:
|
|
||||||
return None
|
|
||||||
return random.choices(words, weights)[0]
|
|
||||||
|
|
||||||
|
|
||||||
class Markov(Plugin):
|
class Markov(Plugin):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Markov, self).__init__(*args, **kwargs)
|
super(Markov, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
self.order = int(self.plugin_config.get("order", 1))
|
self.order = int(self.plugin_config.get("order", 1))
|
||||||
self.data_path = Path(self.plugin_config.get("data_path", "data/markov"))
|
self.data_path = Path(
|
||||||
self.save_every = int(self.plugin_config.get("save_every", 1800))
|
self.plugin_config.get("data_path", "data/markov/markov.db")
|
||||||
|
)
|
||||||
|
self.sql_path = Path(self.plugin_config.get("sql_path", "data/markov/db.sql"))
|
||||||
self.reply_chance = float(self.plugin_config.get("reply_chance", 0.01))
|
self.reply_chance = float(self.plugin_config.get("reply_chance", 0.01))
|
||||||
self.__chains = {}
|
self.chain = Chain(self.order, self.reply_chance, self.data_path, self.sql_path)
|
||||||
self.__save_loop_task = None
|
|
||||||
self.__saving = asyncio.Lock()
|
|
||||||
|
|
||||||
async def on_load(self):
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
self.__save_loop_task = loop.create_task(self.__save_loop())
|
|
||||||
|
|
||||||
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
|
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
parts = line.split()
|
parts = line.split()
|
||||||
if not parts:
|
if not parts or who.nick == self.nick:
|
||||||
return
|
return
|
||||||
elif parts[0] == "!markov":
|
elif parts[0] == "!markov":
|
||||||
self.handle_command(conn, channel, who, parts)
|
self.handle_command(conn, channel, who, parts)
|
||||||
@@ -342,27 +216,17 @@ class Markov(Plugin):
|
|||||||
self.add(channel, who.nick, line)
|
self.add(channel, who.nick, line)
|
||||||
# also, maybe generate a sentence
|
# also, maybe generate a sentence
|
||||||
chosen = random.random()
|
chosen = random.random()
|
||||||
chain = self.get_chain(channel, who.nick)
|
reply_chance = self.chain.get_reply_chance(channel, who.nick)
|
||||||
if chosen <= chain.reply_chance:
|
if chosen <= reply_chance:
|
||||||
message = chain.generate()
|
message = self.chain.generate(channel, who.nick)
|
||||||
if message:
|
if message:
|
||||||
self.send_to(conn, channel, f"{who.nick}: {message}")
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
||||||
|
|
||||||
def get_chain(self, channel: str, who: str) -> Chain:
|
def add(self, channel: str, who: str, line: str, commit=True):
|
||||||
if channel not in self.__chains:
|
|
||||||
self.__chains[channel] = {}
|
|
||||||
if who not in self.__chains[channel]:
|
|
||||||
path = self.data_path / channel / who
|
|
||||||
self.__chains[channel][who] = Chain(self.order, self.reply_chance, path)
|
|
||||||
return self.__chains[channel][who]
|
|
||||||
|
|
||||||
def add(self, channel: str, who: str, line: str):
|
|
||||||
if who == self.server_config.nick:
|
if who == self.server_config.nick:
|
||||||
return
|
return
|
||||||
chain = self.get_chain(channel, who)
|
self.chain.add(channel, who, line, commit)
|
||||||
chain.add(line)
|
self.chain.add(channel, ALLCHAIN, line, commit)
|
||||||
allchain = self.get_chain(channel, ALLCHAIN)
|
|
||||||
allchain.add(line)
|
|
||||||
|
|
||||||
def handle_command(
|
def handle_command(
|
||||||
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
|
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
|
||||||
@@ -370,87 +234,48 @@ class Markov(Plugin):
|
|||||||
# handle markov commands
|
# handle markov commands
|
||||||
match parts[1:]:
|
match parts[1:]:
|
||||||
case ["force"]:
|
case ["force"]:
|
||||||
chain = self.get_chain(channel, who.nick)
|
message = self.chain.generate(channel, who.nick)
|
||||||
message = chain.generate()
|
|
||||||
if message:
|
if message:
|
||||||
self.send_to(conn, channel, f"{who.nick}: {message}")
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
||||||
case ["force", nick] | ["emulate", nick]:
|
case ["force", nick] | ["emulate", nick]:
|
||||||
chain = self.get_chain(channel, nick)
|
if not self.chain.get_user_id(channel, nick):
|
||||||
if not chain:
|
|
||||||
return
|
return
|
||||||
message = chain.generate()
|
message = self.chain.generate(channel, nick)
|
||||||
if message:
|
if message:
|
||||||
self.send_to(conn, channel, f"{who.nick}: {message}")
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
||||||
case ["all"]:
|
case ["all"]:
|
||||||
chain = self.get_chain(channel, ALLCHAIN)
|
message = self.chain.generate(channel, ALLCHAIN)
|
||||||
message = chain.generate()
|
|
||||||
if message:
|
if message:
|
||||||
self.send_to(conn, channel, f"{who.nick}: {message}")
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
||||||
case ["chance"]:
|
case ["chance"]:
|
||||||
chain = self.get_chain(channel, who.nick)
|
chance = self.chain.get_reply_chance(channel, who.nick)
|
||||||
self.send_to(
|
self.send_to(
|
||||||
conn,
|
conn,
|
||||||
channel,
|
channel,
|
||||||
f"{who.nick}: current reply chance is {chain.reply_chance}",
|
f"{who.nick}: current reply chance is {chance}",
|
||||||
)
|
)
|
||||||
case ["chance", chance]:
|
case ["chance", chance]:
|
||||||
chain = self.get_chain(channel, who.nick)
|
|
||||||
try:
|
try:
|
||||||
reply_chance = float(chance)
|
reply_chance = float(chance)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
log.error("Couldn't parse %r as a float", chance)
|
log.error("Couldn't parse %r as a float", chance)
|
||||||
return
|
return
|
||||||
if not math.isnan(reply_chance):
|
if not math.isnan(reply_chance):
|
||||||
chain.reply_chance = min(
|
reply_chance = min(max(float(reply_chance), 0.0), self.reply_chance)
|
||||||
max(float(reply_chance), 0.0), self.reply_chance
|
self.chain.set_reply_chance(channel, who.nick, reply_chance)
|
||||||
)
|
|
||||||
self.send_to(
|
self.send_to(
|
||||||
conn,
|
conn,
|
||||||
channel,
|
channel,
|
||||||
f"{who.nick}: reply chance set to {chain.reply_chance}",
|
f"{who.nick}: reply chance set to {self.chain.get_reply_chance(channel, who.nick)}",
|
||||||
)
|
)
|
||||||
case _:
|
case _:
|
||||||
# command not recognized
|
# command not recognized
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def __save_loop(self):
|
async def save(self):
|
||||||
while True:
|
self.chain.commit()
|
||||||
log.debug("Saving markov chains in %s seconds", self.save_every)
|
|
||||||
await asyncio.sleep(self.save_every)
|
|
||||||
retain_after = asyncio.get_running_loop().time() - self.save_every
|
|
||||||
await self.save(retain_after=retain_after)
|
|
||||||
|
|
||||||
async def save(self, retain_after: float | None = None):
|
|
||||||
async with self.__saving:
|
|
||||||
from concurrent.futures import ProcessPoolExecutor
|
|
||||||
|
|
||||||
log.info("Saving markov chains")
|
|
||||||
coros = []
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
# ProcessPoolExecutor is an explicit decision I've made to use,
|
|
||||||
# because it allows us to save in a different process, with
|
|
||||||
# different memory, and simultaneously clear it if it needs to be
|
|
||||||
# cleared.
|
|
||||||
with ProcessPoolExecutor() as pool:
|
|
||||||
for chains in self.__chains.values():
|
|
||||||
for chain in chains.values():
|
|
||||||
# Start the save in a new process, in a new task.
|
|
||||||
log.debug("Starting process to save %s", chain.path)
|
|
||||||
coro = loop.run_in_executor(pool, chain.save)
|
|
||||||
coros += [coro]
|
|
||||||
# Prune
|
|
||||||
retain = True
|
|
||||||
if retain_after is not None:
|
|
||||||
retain = chain.last_access > retain_after
|
|
||||||
if not retain:
|
|
||||||
log.info("Pruning markov chain %s from memory", chain.path)
|
|
||||||
chain.clear_cache()
|
|
||||||
if coros:
|
|
||||||
await asyncio.gather(*coros)
|
|
||||||
log.info("Done")
|
|
||||||
|
|
||||||
async def on_unload(self, conn: IrcProtocol):
|
async def on_unload(self, conn: IrcProtocol):
|
||||||
self.__save_loop_task.cancel()
|
|
||||||
await self.save()
|
await self.save()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ async def main():
|
|||||||
name = mat["name"]
|
name = mat["name"]
|
||||||
message = mat["message"].strip()
|
message = mat["message"].strip()
|
||||||
if name != server_config.nick and message and message[0] != "!":
|
if name != server_config.nick and message and message[0] != "!":
|
||||||
plugin.add(channel, name, message)
|
plugin.add(channel, name, message, commit=False)
|
||||||
await plugin.save()
|
await plugin.save()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user