Finally settle on a good model for markov

If you don't use/access your chain every N seconds (300 by default), it
will unload your chain from memory and save it to disk.

Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
This commit is contained in:
2022-05-26 20:59:06 -07:00
parent a30588111b
commit a4958d371e
2 changed files with 142 additions and 48 deletions

View File

@@ -1,12 +1,11 @@
import asyncio
from collections import defaultdict from collections import defaultdict
import dataclasses import dataclasses
from functools import partial import json
import logging import logging
from pathlib import Path from pathlib import Path
import pickle
import random import random
import shutil from typing import Any, List, Mapping, Sequence
from typing import Any, List, Sequence
from asyncirc.protocol import IrcProtocol from asyncirc.protocol import IrcProtocol
from irclib.parser import Prefix from irclib.parser import Prefix
@@ -35,29 +34,103 @@ def windows(items: Sequence[Any], size: int):
@dataclasses.dataclass @dataclasses.dataclass
class Chain: class Chain:
order: int def __init__(self, order: int, path: Path):
chain: defaultdict[ self.order = order
tuple[str, ...], defaultdict[str | None, int] self.path = path
] = dataclasses.field(default_factory=chain_default) self.__cache = chain_default()
self.__last_access = 0.0
@property
def last_access(self) -> float:
return self.__last_access
def add(self, text: str): def add(self, text: str):
parts: List[Any] = text.strip().split() parts: List[Any] = text.strip().split()
if not parts: if not parts:
return return
if len(parts) < self.order: self.__load()
self.chain[tuple(parts)][None] += 1 for fragment in windows(parts + [None], self.order + 1):
head = tuple(fragment[0:-1])
tail = fragment[-1]
self.__cache[head][tail] += 1
def get(self, key: tuple[str, ...]) -> dict[str | None, int]:
self.__last_access = asyncio.get_running_loop().time()
if self.__cache:
if key in self.__cache:
return self.__cache[key]
else:
raise KeyError(key)
else: else:
for fragment in windows(parts + [None], self.order + 1): # Load cache, then return key
head = tuple(fragment[0:-1]) self.__load()
tail = fragment[-1] return self.__cache[key]
self.chain[head][tail] += 1
def set(self, key: tuple[str, ...], value: Mapping[str | None, int]):
self.__last_access = asyncio.get_running_loop().time()
if not self.__cache:
# Attempt the cache before writing to it
self.__load()
self.__cache[key] = defaultdict(int, value)
def __load(self):
self.__last_access = asyncio.get_running_loop().time()
if self.__cache:
return
if not self.path.exists():
return
with open(self.path) as fp:
log.info("Loading markov chain from %s", self.path)
self.__cache = defaultdict(
chain_inner_default,
{
tuple(key.split(" ")): defaultdict(
int,
{
(None if not word else word): weight
for word, weight in value.items()
},
)
for key, value in json.load(fp).items()
},
)
def save(self, retain: bool = True):
if not self.__cache:
return
if retain:
log.info("Saving markov chain to %s", self.path)
else:
log.info("Saving markov chain to %s (not retaining)", self.path)
self.path.parent.mkdir(parents=True, exist_ok=True)
with open(self.path, "w") as fp:
json.dump(
{
" ".join(key): {
("" if word is None else word): weight
for word, weight in value.items()
}
for key, value in self.__cache.items()
},
fp,
)
if not retain:
self.clear()
def clear(self):
self.__cache.clear()
def __bool__(self) -> bool:
return self.path.exists() or bool(self.__cache)
def generate(self) -> str: def generate(self) -> str:
words: List[str] = [] self.__load()
if not self.chain: if not self.__cache:
return "" return ""
node = random.choice(list(self.chain.keys())) words: List[str] = []
node = random.choice(list(self.__cache.keys()))
words += node words += node
next: str | None = self.choose_next(node) next: str | None = self.choose_next(node)
while next: while next:
@@ -68,9 +141,8 @@ class Chain:
return " ".join(words) return " ".join(words)
def choose_next(self, head: tuple[str, ...]) -> str | None: def choose_next(self, head: tuple[str, ...]) -> str | None:
if head not in self.chain: self.__load()
return None choices = self.__cache[head]
choices = self.chain[head]
words = list(choices.keys()) words = list(choices.keys())
weights = list(choices.values()) weights = list(choices.values())
if not words: if not words:
@@ -83,17 +155,15 @@ class Markov(Plugin):
super(Markov, self).__init__(*args, **kwargs) super(Markov, self).__init__(*args, **kwargs)
self.order = self.plugin_config.get("order", 1) self.order = self.plugin_config.get("order", 1)
self.db_path = Path(self.plugin_config.get("db_path", "data/markov.pkl")) self.data_path = Path(self.plugin_config.get("data_path", "data/markov"))
self.backup_db_path = Path(str(self.db_path) + ".backup") self.save_every = self.plugin_config.get("save_every", 300)
self.__chains = {}
self.__save_loop_task = None
self.__saving = asyncio.Lock()
# Load chain from data dir async def on_load(self):
if self.db_path.exists(): loop = asyncio.get_running_loop()
with open(self.db_path, "rb") as fp: self.__save_loop_task = loop.create_task(self.__save_loop())
self.chains = pickle.load(fp)
else:
self.chains = defaultdict(
partial(defaultdict, partial(Chain, order=self.order))
)
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str): async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
line = line.strip() line = line.strip()
@@ -106,8 +176,17 @@ class Markov(Plugin):
# ignore other commands # ignore other commands
self.add(channel, who.nick, line) self.add(channel, who.nick, line)
def get_chain(self, channel: str, who: str) -> Chain:
if channel not in self.__chains:
self.__chains[channel] = {}
if who not in self.__chains[channel]:
path = self.data_path / channel / who
self.__chains[channel][who] = Chain(self.order, path)
return self.__chains[channel][who]
def add(self, channel: str, who: str, line: str): def add(self, channel: str, who: str, line: str):
self.chains[channel][who].add(line) chain = self.get_chain(channel, who)
chain.add(line)
def handle_command( def handle_command(
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str] self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
@@ -115,14 +194,14 @@ class Markov(Plugin):
# handle markov commands # handle markov commands
match parts[1:]: match parts[1:]:
case ["force"]: case ["force"]:
chain = self.chains[channel][who.nick] chain = self.get_chain(channel, who.nick)
message = chain.generate() message = chain.generate()
if message: if message:
self.send_to(conn, channel, f"{who.nick}: {message}") self.send_to(conn, channel, f"{who.nick}: {message}")
case ["force", nick] | ["emulate", nick]: case ["force", nick] | ["emulate", nick]:
if nick not in self.chains[channel]: chain = self.get_chain(channel, who.nick)
if not chain:
return return
chain = self.chains[channel][nick]
message = chain.generate() message = chain.generate()
if message: if message:
self.send_to(conn, channel, f"{who.nick}: {message}") self.send_to(conn, channel, f"{who.nick}: {message}")
@@ -130,19 +209,28 @@ class Markov(Plugin):
# command not recognized # command not recognized
pass pass
def save(self): async def __save_loop(self):
if self.db_path.exists(): while True:
log.info("Copying backup of markov chain to %s") log.debug("Pruning inactive markov chains in %s seconds", self.save_every)
shutil.copyfile(self.db_path, self.backup_db_path) await asyncio.sleep(self.save_every)
retain_after = asyncio.get_running_loop().time() - self.save_every
await self.save(retain_after=retain_after)
async def save(self, retain_after: float | None = None):
async with self.__saving:
log.info("Saving markov chains")
for chains in self.__chains.values():
for chain in chains.values():
if retain_after is not None:
retain = chain.last_access > retain_after
else:
retain = True
chain.save(retain=retain)
log.info("Done") log.info("Done")
log.info("Saving markov chain to %s", self.db_path)
with open(self.db_path, "wb") as fp:
pickle.dump(self.chains, fp)
log.info("Done")
async def on_unload(self, conn: IrcProtocol): async def on_unload(self, conn: IrcProtocol):
self.save() self.__save_loop_task.cancel()
await self.save()
PLUGIN_TYPE = Markov PLUGIN_TYPE = Markov

View File

@@ -1,11 +1,13 @@
from omnibot.config import ServerConfig
from plugins.markov import Markov
import logging import logging
import asyncio
import sys import sys
import re import re
from omnibot.config import ServerConfig
from plugins.markov import Markov
if __name__ == "__main__":
async def main():
""" """
Hacky "load my IRC logs" script Hacky "load my IRC logs" script
""" """
@@ -38,4 +40,8 @@ if __name__ == "__main__":
name = mat["name"] name = mat["name"]
message = mat["message"] message = mat["message"]
plugin.add(channel, name, message) plugin.add(channel, name, message)
plugin.save() await plugin.save()
if __name__ == "__main__":
asyncio.run(main())