Files
omnibot22/plugins/markov.py
Alek Ratzloff cc30df8706 Move markov nodes to be single strings
Joining the words of each node into a single space-separated string is more memory-friendly than storing them as tuples.

Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
2022-05-27 18:41:17 -07:00

246 lines
7.8 KiB
Python

import asyncio
from collections import defaultdict
import dataclasses
import json
import logging
from pathlib import Path
import random
from typing import Any, List, Mapping, Sequence
from asyncirc.protocol import IrcProtocol
from irclib.parser import Prefix
from omnibot.plugin import Plugin
# Module-level logger for this plugin.
log = logging.getLogger(__name__)
# Pseudo-user key under which the channel-wide chain (fed by every user's
# messages) is stored alongside the per-nick chains.
ALLCHAIN = "ALL!CHAIN"
def chain_inner_default() -> defaultdict[str | None, int]:
return defaultdict(int)
def chain_default() -> defaultdict[str, defaultdict[str | None, int]]:
    """Create an empty chain table: node string -> successor-count table."""
    table: defaultdict[str, defaultdict[str | None, int]] = defaultdict(
        chain_inner_default
    )
    return table
def windows(items: Sequence[Any], size: int):
    """Yield every contiguous run of `size` items from `items`.

    If `items` is shorter than `size`, the whole sequence is yielded once
    instead of yielding nothing.
    """
    total = len(items)
    if total < size:
        yield items
        return
    for start in range(total - size + 1):
        yield items[start : start + size]
class Chain:
    """A fixed-order Markov chain, lazily backed by a JSON file on disk.

    Nodes are `order` words joined by single spaces; each node maps successor
    words (or None, marking end-of-message) to occurrence counts.

    NOTE: the former @dataclasses.dataclass decorator is removed.  The class
    defines its own __init__ and declares no fields, so the decorator's only
    effect was to inject a zero-field __eq__ (and __repr__) that made *every*
    Chain instance compare equal — a latent bug.
    """

    def __init__(self, order: int, path: Path):
        self.order = order  # number of words per node
        self.path = path  # JSON file backing this chain
        self.__cache = chain_default()  # empty until first __load()
        self.__last_access = 0.0  # event-loop timestamp of last touch

    @property
    def last_access(self) -> float:
        """Event-loop time (loop.time()) of the most recent access."""
        return self.__last_access

    def add(self, text: str):
        """Tokenize `text` and fold its word transitions into the chain.

        A trailing None is appended so generation knows where messages end.
        Empty/whitespace-only input is ignored.
        """
        parts: List[Any] = text.strip().split()
        if not parts:
            return
        self.__load()
        for fragment in windows(parts + [None], self.order + 1):
            head = fragment[0:-1]
            tail = fragment[-1]
            self.__cache[" ".join(head)][tail] += 1

    def get(self, key: str) -> dict[str | None, int]:
        """Return the successor table for `key`, loading from disk if needed.

        Raises:
            KeyError: if the node is unknown.  (Previously the cold-load path
            used defaultdict indexing and silently materialized an empty
            entry, while the warm path raised; both paths now agree.)
        """
        self.__last_access = asyncio.get_running_loop().time()
        if not self.__cache:
            # Load cache, then look up the key.
            self.__load()
        if key in self.__cache:
            return self.__cache[key]
        raise KeyError(key)

    def set(self, key: str, value: Mapping[str | None, int]):
        """Replace the successor table for `key`.

        Loads the cache first so a not-yet-loaded file is not clobbered by a
        later save.
        """
        self.__last_access = asyncio.get_running_loop().time()
        if not self.__cache:
            # Attempt the cache before writing to it
            self.__load()
        self.__cache[key] = defaultdict(int, value)

    def __load(self):
        """Populate the in-memory cache from self.path, at most once.

        Empty-string successor keys round-trip back to None (JSON object keys
        must be strings, so save() stores None as "").
        """
        self.__last_access = asyncio.get_running_loop().time()
        if self.__cache:
            return  # already loaded
        if not self.path.exists():
            return  # nothing persisted yet; start empty
        with open(self.path) as fp:
            log.info("Loading markov chain from %s", self.path)
            self.__cache = defaultdict(
                chain_inner_default,
                {
                    key: defaultdict(
                        int,
                        {
                            (None if not word else word): weight
                            for word, weight in value.items()
                        },
                    )
                    for key, value in json.load(fp).items()
                },
            )

    def save(self, retain: bool = True):
        """Persist the cache to self.path as JSON.

        None successor keys are stored as "" (JSON keys must be strings).
        When `retain` is false, the in-memory cache is dropped after writing.
        A never-loaded (empty) cache is not written, so an existing file is
        never overwritten with nothing.
        """
        if not self.__cache:
            return
        if retain:
            log.info("Saving markov chain to %s", self.path)
        else:
            log.info("Saving markov chain to %s (not retaining)", self.path)
        self.path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.path, "w") as fp:
            json.dump(
                {
                    key: {
                        ("" if word is None else word): weight
                        for word, weight in value.items()
                    }
                    for key, value in self.__cache.items()
                },
                fp,
            )
        if not retain:
            self.clear_cache()

    def clear_cache(self):
        """Drop the in-memory cache; the next access reloads from disk."""
        self.__cache.clear()

    def __bool__(self) -> bool:
        """True if this chain has any data, on disk or in memory."""
        return self.path.exists() or bool(self.__cache)

    def generate(self) -> str:
        """Produce one random message by walking the chain from a random node."""
        self.__load()
        if not self.__cache:
            return ""
        words: List[str] = []
        node = random.choice(list(self.__cache.keys())).split(" ")
        words += node
        # `successor` (renamed from `next`, which shadowed the builtin).
        successor: str | None = self.choose_next(" ".join(node))
        while successor:
            words.append(successor)
            node = [*node[1:], successor]
            successor = self.choose_next(" ".join(node))
        return " ".join(words)

    def choose_next(self, head: str) -> str | None:
        """Pick a weighted-random successor for `head`, or None at a dead end.

        Uses .get() rather than defaultdict indexing: probing an unknown head
        previously inserted an empty entry, which was then persisted by save()
        and could even be picked as a start node by generate().
        """
        self.__load()
        choices = self.__cache.get(head)
        if not choices:
            return None
        words = list(choices.keys())
        weights = list(choices.values())
        return random.choices(words, weights)[0]
class Markov(Plugin):
def __init__(self, *args, **kwargs):
super(Markov, self).__init__(*args, **kwargs)
self.order = self.plugin_config.get("order", 1)
self.data_path = Path(self.plugin_config.get("data_path", "data/markov"))
self.save_every = self.plugin_config.get("save_every", 300)
self.__chains = {}
self.__save_loop_task = None
self.__saving = asyncio.Lock()
async def on_load(self):
loop = asyncio.get_running_loop()
self.__save_loop_task = loop.create_task(self.__save_loop())
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
line = line.strip()
parts = line.split()
if not parts:
return
elif parts[0] == "!markov":
self.handle_command(conn, channel, who, parts)
elif line[0] != "!":
# ignore other commands
self.add(channel, who.nick, line)
def get_chain(self, channel: str, who: str) -> Chain:
if channel not in self.__chains:
self.__chains[channel] = {}
if who not in self.__chains[channel]:
path = self.data_path / channel / who
self.__chains[channel][who] = Chain(self.order, path)
return self.__chains[channel][who]
def add(self, channel: str, who: str, line: str):
if who == self.server_config.nick == who:
return
chain = self.get_chain(channel, who)
chain.add(line)
allchain = self.get_chain(channel, ALLCHAIN)
allchain.add(line)
def handle_command(
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
):
# handle markov commands
match parts[1:]:
case ["force"]:
chain = self.get_chain(channel, who.nick)
message = chain.generate()
if message:
self.send_to(conn, channel, f"{who.nick}: {message}")
case ["force", nick] | ["emulate", nick]:
chain = self.get_chain(channel, who.nick)
if not chain:
return
message = chain.generate()
if message:
self.send_to(conn, channel, f"{who.nick}: {message}")
case ["all"]:
chain = self.get_chain(channel, ALLCHAIN)
message = chain.generate()
if message:
self.send_to(conn, channel, f"{who.nick}: {message}")
case _:
# command not recognized
pass
async def __save_loop(self):
while True:
log.debug("Pruning inactive markov chains in %s seconds", self.save_every)
await asyncio.sleep(self.save_every)
retain_after = asyncio.get_running_loop().time() - self.save_every
await self.save(retain_after=retain_after)
async def save(self, retain_after: float | None = None):
async with self.__saving:
log.info("Saving markov chains")
for chains in self.__chains.values():
for chain in chains.values():
if retain_after is not None:
retain = chain.last_access > retain_after
else:
retain = True
chain.save(retain=retain)
log.info("Done")
async def on_unload(self, conn: IrcProtocol):
self.__save_loop_task.cancel()
await self.save()
# Hook for the plugin loader: the class this module exports as its plugin.
PLUGIN_TYPE = Markov