Finally settle on a good model for markov
If you don't use/access your chain every N seconds (300 by default), it will unload your chain from memory and save it to disk. Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
This commit is contained in:
@@ -1,12 +1,11 @@
|
|||||||
|
import asyncio
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from functools import partial
|
import json
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import pickle
|
|
||||||
import random
|
import random
|
||||||
import shutil
|
from typing import Any, List, Mapping, Sequence
|
||||||
from typing import Any, List, Sequence
|
|
||||||
|
|
||||||
from asyncirc.protocol import IrcProtocol
|
from asyncirc.protocol import IrcProtocol
|
||||||
from irclib.parser import Prefix
|
from irclib.parser import Prefix
|
||||||
@@ -35,29 +34,103 @@ def windows(items: Sequence[Any], size: int):
|
|||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class Chain:
|
class Chain:
|
||||||
order: int
|
def __init__(self, order: int, path: Path):
|
||||||
chain: defaultdict[
|
self.order = order
|
||||||
tuple[str, ...], defaultdict[str | None, int]
|
self.path = path
|
||||||
] = dataclasses.field(default_factory=chain_default)
|
self.__cache = chain_default()
|
||||||
|
self.__last_access = 0.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def last_access(self) -> float:
|
||||||
|
return self.__last_access
|
||||||
|
|
||||||
def add(self, text: str):
|
def add(self, text: str):
|
||||||
parts: List[Any] = text.strip().split()
|
parts: List[Any] = text.strip().split()
|
||||||
if not parts:
|
if not parts:
|
||||||
return
|
return
|
||||||
if len(parts) < self.order:
|
self.__load()
|
||||||
self.chain[tuple(parts)][None] += 1
|
for fragment in windows(parts + [None], self.order + 1):
|
||||||
|
head = tuple(fragment[0:-1])
|
||||||
|
tail = fragment[-1]
|
||||||
|
self.__cache[head][tail] += 1
|
||||||
|
|
||||||
|
def get(self, key: tuple[str, ...]) -> dict[str | None, int]:
|
||||||
|
self.__last_access = asyncio.get_running_loop().time()
|
||||||
|
if self.__cache:
|
||||||
|
if key in self.__cache:
|
||||||
|
return self.__cache[key]
|
||||||
|
else:
|
||||||
|
raise KeyError(key)
|
||||||
else:
|
else:
|
||||||
for fragment in windows(parts + [None], self.order + 1):
|
# Load cache, then return key
|
||||||
head = tuple(fragment[0:-1])
|
self.__load()
|
||||||
tail = fragment[-1]
|
return self.__cache[key]
|
||||||
self.chain[head][tail] += 1
|
|
||||||
|
def set(self, key: tuple[str, ...], value: Mapping[str | None, int]):
|
||||||
|
self.__last_access = asyncio.get_running_loop().time()
|
||||||
|
if not self.__cache:
|
||||||
|
# Attempt the cache before writing to it
|
||||||
|
self.__load()
|
||||||
|
self.__cache[key] = defaultdict(int, value)
|
||||||
|
|
||||||
|
def __load(self):
|
||||||
|
self.__last_access = asyncio.get_running_loop().time()
|
||||||
|
if self.__cache:
|
||||||
|
return
|
||||||
|
if not self.path.exists():
|
||||||
|
return
|
||||||
|
with open(self.path) as fp:
|
||||||
|
log.info("Loading markov chain from %s", self.path)
|
||||||
|
self.__cache = defaultdict(
|
||||||
|
chain_inner_default,
|
||||||
|
{
|
||||||
|
tuple(key.split(" ")): defaultdict(
|
||||||
|
int,
|
||||||
|
{
|
||||||
|
(None if not word else word): weight
|
||||||
|
for word, weight in value.items()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for key, value in json.load(fp).items()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def save(self, retain: bool = True):
|
||||||
|
if not self.__cache:
|
||||||
|
return
|
||||||
|
if retain:
|
||||||
|
log.info("Saving markov chain to %s", self.path)
|
||||||
|
else:
|
||||||
|
log.info("Saving markov chain to %s (not retaining)", self.path)
|
||||||
|
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(self.path, "w") as fp:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
" ".join(key): {
|
||||||
|
("" if word is None else word): weight
|
||||||
|
for word, weight in value.items()
|
||||||
|
}
|
||||||
|
for key, value in self.__cache.items()
|
||||||
|
},
|
||||||
|
fp,
|
||||||
|
)
|
||||||
|
if not retain:
|
||||||
|
self.clear()
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self.__cache.clear()
|
||||||
|
|
||||||
|
def __bool__(self) -> bool:
|
||||||
|
return self.path.exists() or bool(self.__cache)
|
||||||
|
|
||||||
def generate(self) -> str:
|
def generate(self) -> str:
|
||||||
words: List[str] = []
|
self.__load()
|
||||||
if not self.chain:
|
if not self.__cache:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
node = random.choice(list(self.chain.keys()))
|
words: List[str] = []
|
||||||
|
|
||||||
|
node = random.choice(list(self.__cache.keys()))
|
||||||
words += node
|
words += node
|
||||||
next: str | None = self.choose_next(node)
|
next: str | None = self.choose_next(node)
|
||||||
while next:
|
while next:
|
||||||
@@ -68,9 +141,8 @@ class Chain:
|
|||||||
return " ".join(words)
|
return " ".join(words)
|
||||||
|
|
||||||
def choose_next(self, head: tuple[str, ...]) -> str | None:
|
def choose_next(self, head: tuple[str, ...]) -> str | None:
|
||||||
if head not in self.chain:
|
self.__load()
|
||||||
return None
|
choices = self.__cache[head]
|
||||||
choices = self.chain[head]
|
|
||||||
words = list(choices.keys())
|
words = list(choices.keys())
|
||||||
weights = list(choices.values())
|
weights = list(choices.values())
|
||||||
if not words:
|
if not words:
|
||||||
@@ -83,17 +155,15 @@ class Markov(Plugin):
|
|||||||
super(Markov, self).__init__(*args, **kwargs)
|
super(Markov, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
self.order = self.plugin_config.get("order", 1)
|
self.order = self.plugin_config.get("order", 1)
|
||||||
self.db_path = Path(self.plugin_config.get("db_path", "data/markov.pkl"))
|
self.data_path = Path(self.plugin_config.get("data_path", "data/markov"))
|
||||||
self.backup_db_path = Path(str(self.db_path) + ".backup")
|
self.save_every = self.plugin_config.get("save_every", 300)
|
||||||
|
self.__chains = {}
|
||||||
|
self.__save_loop_task = None
|
||||||
|
self.__saving = asyncio.Lock()
|
||||||
|
|
||||||
# Load chain from data dir
|
async def on_load(self):
|
||||||
if self.db_path.exists():
|
loop = asyncio.get_running_loop()
|
||||||
with open(self.db_path, "rb") as fp:
|
self.__save_loop_task = loop.create_task(self.__save_loop())
|
||||||
self.chains = pickle.load(fp)
|
|
||||||
else:
|
|
||||||
self.chains = defaultdict(
|
|
||||||
partial(defaultdict, partial(Chain, order=self.order))
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
|
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
@@ -106,8 +176,17 @@ class Markov(Plugin):
|
|||||||
# ignore other commands
|
# ignore other commands
|
||||||
self.add(channel, who.nick, line)
|
self.add(channel, who.nick, line)
|
||||||
|
|
||||||
|
def get_chain(self, channel: str, who: str) -> Chain:
|
||||||
|
if channel not in self.__chains:
|
||||||
|
self.__chains[channel] = {}
|
||||||
|
if who not in self.__chains[channel]:
|
||||||
|
path = self.data_path / channel / who
|
||||||
|
self.__chains[channel][who] = Chain(self.order, path)
|
||||||
|
return self.__chains[channel][who]
|
||||||
|
|
||||||
def add(self, channel: str, who: str, line: str):
|
def add(self, channel: str, who: str, line: str):
|
||||||
self.chains[channel][who].add(line)
|
chain = self.get_chain(channel, who)
|
||||||
|
chain.add(line)
|
||||||
|
|
||||||
def handle_command(
|
def handle_command(
|
||||||
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
|
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
|
||||||
@@ -115,14 +194,14 @@ class Markov(Plugin):
|
|||||||
# handle markov commands
|
# handle markov commands
|
||||||
match parts[1:]:
|
match parts[1:]:
|
||||||
case ["force"]:
|
case ["force"]:
|
||||||
chain = self.chains[channel][who.nick]
|
chain = self.get_chain(channel, who.nick)
|
||||||
message = chain.generate()
|
message = chain.generate()
|
||||||
if message:
|
if message:
|
||||||
self.send_to(conn, channel, f"{who.nick}: {message}")
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
||||||
case ["force", nick] | ["emulate", nick]:
|
case ["force", nick] | ["emulate", nick]:
|
||||||
if nick not in self.chains[channel]:
|
chain = self.get_chain(channel, who.nick)
|
||||||
|
if not chain:
|
||||||
return
|
return
|
||||||
chain = self.chains[channel][nick]
|
|
||||||
message = chain.generate()
|
message = chain.generate()
|
||||||
if message:
|
if message:
|
||||||
self.send_to(conn, channel, f"{who.nick}: {message}")
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
||||||
@@ -130,19 +209,28 @@ class Markov(Plugin):
|
|||||||
# command not recognized
|
# command not recognized
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def save(self):
|
async def __save_loop(self):
|
||||||
if self.db_path.exists():
|
while True:
|
||||||
log.info("Copying backup of markov chain to %s")
|
log.debug("Pruning inactive markov chains in %s seconds", self.save_every)
|
||||||
shutil.copyfile(self.db_path, self.backup_db_path)
|
await asyncio.sleep(self.save_every)
|
||||||
|
retain_after = asyncio.get_running_loop().time() - self.save_every
|
||||||
|
await self.save(retain_after=retain_after)
|
||||||
|
|
||||||
|
async def save(self, retain_after: float | None = None):
|
||||||
|
async with self.__saving:
|
||||||
|
log.info("Saving markov chains")
|
||||||
|
for chains in self.__chains.values():
|
||||||
|
for chain in chains.values():
|
||||||
|
if retain_after is not None:
|
||||||
|
retain = chain.last_access > retain_after
|
||||||
|
else:
|
||||||
|
retain = True
|
||||||
|
chain.save(retain=retain)
|
||||||
log.info("Done")
|
log.info("Done")
|
||||||
|
|
||||||
log.info("Saving markov chain to %s", self.db_path)
|
|
||||||
with open(self.db_path, "wb") as fp:
|
|
||||||
pickle.dump(self.chains, fp)
|
|
||||||
log.info("Done")
|
|
||||||
|
|
||||||
async def on_unload(self, conn: IrcProtocol):
|
async def on_unload(self, conn: IrcProtocol):
|
||||||
self.save()
|
self.__save_loop_task.cancel()
|
||||||
|
await self.save()
|
||||||
|
|
||||||
|
|
||||||
PLUGIN_TYPE = Markov
|
PLUGIN_TYPE = Markov
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
from omnibot.config import ServerConfig
|
|
||||||
from plugins.markov import Markov
|
|
||||||
import logging
|
import logging
|
||||||
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from omnibot.config import ServerConfig
|
||||||
|
from plugins.markov import Markov
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
|
async def main():
|
||||||
"""
|
"""
|
||||||
Hacky "load my IRC logs" script
|
Hacky "load my IRC logs" script
|
||||||
"""
|
"""
|
||||||
@@ -38,4 +40,8 @@ if __name__ == "__main__":
|
|||||||
name = mat["name"]
|
name = mat["name"]
|
||||||
message = mat["message"]
|
message = mat["message"]
|
||||||
plugin.add(channel, name, message)
|
plugin.add(channel, name, message)
|
||||||
plugin.save()
|
await plugin.save()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
|
|||||||
Reference in New Issue
Block a user