Finally settle on a good model for markov
If a chain is not used or accessed for N seconds (300 by default), it is saved to disk and unloaded from memory.
Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
This commit is contained in:
@@ -1,12 +1,11 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
import dataclasses
|
||||
from functools import partial
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import random
|
||||
import shutil
|
||||
from typing import Any, List, Sequence
|
||||
from typing import Any, List, Mapping, Sequence
|
||||
|
||||
from asyncirc.protocol import IrcProtocol
|
||||
from irclib.parser import Prefix
|
||||
@@ -35,29 +34,103 @@ def windows(items: Sequence[Any], size: int):
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Chain:
|
||||
order: int
|
||||
chain: defaultdict[
|
||||
tuple[str, ...], defaultdict[str | None, int]
|
||||
] = dataclasses.field(default_factory=chain_default)
|
||||
def __init__(self, order: int, path: Path):
    """Create a chain of the given order that is persisted at ``path``.

    Nothing is read from disk here: the in-memory table starts as whatever
    ``chain_default()`` returns and the last-access timestamp starts at
    ``0.0``.
    """
    self.path = path
    self.order = order
    # No access has been recorded yet.
    self.__last_access = 0.0
    self.__cache = chain_default()
|
||||
|
||||
@property
|
||||
def last_access(self) -> float:
|
||||
return self.__last_access
|
||||
|
||||
def add(self, text: str):
    """Record the words of ``text`` in the chain.

    The text is split on whitespace; blank input is ignored. Each window of
    ``order + 1`` items over the words (followed by a ``None`` terminator)
    increments the count of its last item under the tuple of the preceding
    items.
    """
    parts: List[Any] = text.strip().split()
    if not parts:
        return
    # Bring any counts stored on disk into memory before incrementing, so
    # they are not overwritten by a later save.
    self.__load()
    # The trailing None marks "end of message" as a possible successor.
    for fragment in windows(parts + [None], self.order + 1):
        head = tuple(fragment[:-1])
        tail = fragment[-1]
        self.__cache[head][tail] += 1
|
||||
|
||||
def get(self, key: tuple[str, ...]) -> dict[str | None, int]:
|
||||
self.__last_access = asyncio.get_running_loop().time()
|
||||
if self.__cache:
|
||||
if key in self.__cache:
|
||||
return self.__cache[key]
|
||||
else:
|
||||
raise KeyError(key)
|
||||
else:
|
||||
for fragment in windows(parts + [None], self.order + 1):
|
||||
head = tuple(fragment[0:-1])
|
||||
tail = fragment[-1]
|
||||
self.chain[head][tail] += 1
|
||||
# Load cache, then return key
|
||||
self.__load()
|
||||
return self.__cache[key]
|
||||
|
||||
def set(self, key: tuple[str, ...], value: Mapping[str | None, int]):
|
||||
self.__last_access = asyncio.get_running_loop().time()
|
||||
if not self.__cache:
|
||||
# Attempt the cache before writing to it
|
||||
self.__load()
|
||||
self.__cache[key] = defaultdict(int, value)
|
||||
|
||||
def __load(self):
|
||||
self.__last_access = asyncio.get_running_loop().time()
|
||||
if self.__cache:
|
||||
return
|
||||
if not self.path.exists():
|
||||
return
|
||||
with open(self.path) as fp:
|
||||
log.info("Loading markov chain from %s", self.path)
|
||||
self.__cache = defaultdict(
|
||||
chain_inner_default,
|
||||
{
|
||||
tuple(key.split(" ")): defaultdict(
|
||||
int,
|
||||
{
|
||||
(None if not word else word): weight
|
||||
for word, weight in value.items()
|
||||
},
|
||||
)
|
||||
for key, value in json.load(fp).items()
|
||||
},
|
||||
)
|
||||
|
||||
def save(self, retain: bool = True):
|
||||
if not self.__cache:
|
||||
return
|
||||
if retain:
|
||||
log.info("Saving markov chain to %s", self.path)
|
||||
else:
|
||||
log.info("Saving markov chain to %s (not retaining)", self.path)
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(self.path, "w") as fp:
|
||||
json.dump(
|
||||
{
|
||||
" ".join(key): {
|
||||
("" if word is None else word): weight
|
||||
for word, weight in value.items()
|
||||
}
|
||||
for key, value in self.__cache.items()
|
||||
},
|
||||
fp,
|
||||
)
|
||||
if not retain:
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
self.__cache.clear()
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return self.path.exists() or bool(self.__cache)
|
||||
|
||||
def generate(self) -> str:
|
||||
words: List[str] = []
|
||||
if not self.chain:
|
||||
self.__load()
|
||||
if not self.__cache:
|
||||
return ""
|
||||
|
||||
node = random.choice(list(self.chain.keys()))
|
||||
words: List[str] = []
|
||||
|
||||
node = random.choice(list(self.__cache.keys()))
|
||||
words += node
|
||||
next: str | None = self.choose_next(node)
|
||||
while next:
|
||||
@@ -68,9 +141,8 @@ class Chain:
|
||||
return " ".join(words)
|
||||
|
||||
def choose_next(self, head: tuple[str, ...]) -> str | None:
|
||||
if head not in self.chain:
|
||||
return None
|
||||
choices = self.chain[head]
|
||||
self.__load()
|
||||
choices = self.__cache[head]
|
||||
words = list(choices.keys())
|
||||
weights = list(choices.values())
|
||||
if not words:
|
||||
@@ -83,17 +155,15 @@ class Markov(Plugin):
|
||||
super(Markov, self).__init__(*args, **kwargs)
|
||||
|
||||
self.order = self.plugin_config.get("order", 1)
|
||||
self.db_path = Path(self.plugin_config.get("db_path", "data/markov.pkl"))
|
||||
self.backup_db_path = Path(str(self.db_path) + ".backup")
|
||||
self.data_path = Path(self.plugin_config.get("data_path", "data/markov"))
|
||||
self.save_every = self.plugin_config.get("save_every", 300)
|
||||
self.__chains = {}
|
||||
self.__save_loop_task = None
|
||||
self.__saving = asyncio.Lock()
|
||||
|
||||
# Load chain from data dir
|
||||
if self.db_path.exists():
|
||||
with open(self.db_path, "rb") as fp:
|
||||
self.chains = pickle.load(fp)
|
||||
else:
|
||||
self.chains = defaultdict(
|
||||
partial(defaultdict, partial(Chain, order=self.order))
|
||||
)
|
||||
async def on_load(self):
    """Schedule the periodic save loop on the running event loop and keep its task."""
    self.__save_loop_task = asyncio.get_running_loop().create_task(
        self.__save_loop()
    )
|
||||
|
||||
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
|
||||
line = line.strip()
|
||||
@@ -106,8 +176,17 @@ class Markov(Plugin):
|
||||
# ignore other commands
|
||||
self.add(channel, who.nick, line)
|
||||
|
||||
def get_chain(self, channel: str, who: str) -> Chain:
    """Return the chain for ``who`` in ``channel``, creating it on first use.

    A new chain is backed by ``data_path / channel / who`` and uses the
    plugin's configured order.
    """
    # NOTE(review): channel and who are used verbatim as path components;
    # confirm callers cannot pass values containing path separators.
    per_channel = self.__chains.setdefault(channel, {})
    if who not in per_channel:
        per_channel[who] = Chain(self.order, self.data_path / channel / who)
    return per_channel[who]
|
||||
|
||||
def add(self, channel: str, who: str, line: str):
    """Feed ``line`` into the chain kept for ``who`` in ``channel``.

    The chain is looked up (and created if needed) through ``get_chain``,
    so the line is recorded exactly once.
    """
    self.get_chain(channel, who).add(line)
|
||||
|
||||
def handle_command(
|
||||
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
|
||||
@@ -115,14 +194,14 @@ class Markov(Plugin):
|
||||
# handle markov commands
|
||||
match parts[1:]:
|
||||
case ["force"]:
|
||||
chain = self.chains[channel][who.nick]
|
||||
chain = self.get_chain(channel, who.nick)
|
||||
message = chain.generate()
|
||||
if message:
|
||||
self.send_to(conn, channel, f"{who.nick}: {message}")
|
||||
case ["force", nick] | ["emulate", nick]:
|
||||
if nick not in self.chains[channel]:
|
||||
chain = self.get_chain(channel, who.nick)
|
||||
if not chain:
|
||||
return
|
||||
chain = self.chains[channel][nick]
|
||||
message = chain.generate()
|
||||
if message:
|
||||
self.send_to(conn, channel, f"{who.nick}: {message}")
|
||||
@@ -130,19 +209,28 @@ class Markov(Plugin):
|
||||
# command not recognized
|
||||
pass
|
||||
|
||||
def save(self):
|
||||
if self.db_path.exists():
|
||||
log.info("Copying backup of markov chain to %s")
|
||||
shutil.copyfile(self.db_path, self.backup_db_path)
|
||||
async def __save_loop(self):
    """Periodically save every chain and unload the ones that were not used.

    Runs until cancelled. Each round sleeps ``save_every`` seconds, then
    saves with a cutoff of "now minus ``save_every``"; chains last accessed
    at or before the cutoff are not kept in memory.
    """
    while True:
        log.debug("Pruning inactive markov chains in %s seconds", self.save_every)
        await asyncio.sleep(self.save_every)
        now = asyncio.get_running_loop().time()
        cutoff = now - self.save_every
        await self.save(retain_after=cutoff)
|
||||
|
||||
async def save(self, retain_after: float | None = None):
    """Save every known chain to its own file.

    Saves are serialized through the plugin's ``__saving`` lock.

    Args:
        retain_after: when given, a chain stays in memory only if its
            ``last_access`` is strictly greater than this value; all other
            chains are saved without being retained. When ``None``, every
            chain is retained.
    """
    async with self.__saving:
        log.info("Saving markov chains")
        for chains in self.__chains.values():
            for chain in chains.values():
                retain = retain_after is None or chain.last_access > retain_after
                chain.save(retain=retain)
        log.info("Done")
|
||||
|
||||
async def on_unload(self, conn: IrcProtocol):
    """Stop the periodic save loop and write every chain to disk.

    Args:
        conn: the connection the plugin is being unloaded from (unused here).
    """
    # The task is only created in on_load; tolerate unloading a plugin
    # whose on_load never ran.
    if self.__save_loop_task is not None:
        self.__save_loop_task.cancel()
        self.__save_loop_task = None
    # save() is a coroutine and must be awaited for the write to happen.
    await self.save()
|
||||
|
||||
|
||||
PLUGIN_TYPE = Markov
|
||||
|
||||
Reference in New Issue
Block a user