Files
omnibot22/plugins/markov.py

251 lines
8.0 KiB
Python
Raw Normal View History

import asyncio
from collections import defaultdict
import dataclasses
import json
import logging
from pathlib import Path
import random
from typing import Any, List, Mapping, Sequence
from asyncirc.protocol import IrcProtocol
from irclib.parser import Prefix
from omnibot.plugin import Plugin
# Module-level logger for this plugin.
log = logging.getLogger(__name__)
# Reserved "nick" under which each channel's combined chain (every learned
# line, regardless of speaker) is stored and read by "!markov all".
# NOTE(review): '!' is presumably used because it cannot occur in a real IRC
# nick, so this can never collide with a user's chain — confirm.
ALLCHAIN = "ALL!CHAIN"
def chain_inner_default() -> defaultdict[str | None, int]:
return defaultdict(int)
def chain_default() -> defaultdict[str, defaultdict[str | None, int]]:
    """Return a new, empty chain table: head string -> (next word -> count).

    Unknown heads are created on first access with :func:`chain_inner_default`.
    """
    table: defaultdict[str, defaultdict[str | None, int]] = defaultdict(
        chain_inner_default
    )
    return table
def windows(items: Sequence[Any], size: int):
    """Yield every contiguous slice of *items* that is *size* long.

    When *items* is shorter than *size*, the whole of *items* (the same
    object, not a copy) is yielded once instead, so short inputs still
    produce a single window.
    """
    count = len(items) - size + 1
    if count <= 0:
        yield items
        return
    for start in range(count):
        yield items[start : start + size]
@dataclasses.dataclass
class Chain:
def __init__(self, order: int, path: Path):
self.order = order
self.path = path
self.__cache = chain_default()
self.__last_access = 0.0
self.__dirty = False
@property
def last_access(self) -> float:
return self.__last_access
def add(self, text: str):
parts: List[Any] = text.strip().split()
if not parts:
return
self.__load()
for fragment in windows(parts + [None], self.order + 1):
head = fragment[0:-1]
tail = fragment[-1]
self.__cache[" ".join(head)][tail] += 1
self.__dirty = True
def get(self, key: str) -> dict[str | None, int]:
self.__last_access = asyncio.get_running_loop().time()
if self.__cache:
if key in self.__cache:
return self.__cache[key]
else:
raise KeyError(key)
else:
# Load cache, then return key
self.__load()
return self.__cache[key]
def set(self, key: str, value: Mapping[str | None, int]):
self.__last_access = asyncio.get_running_loop().time()
if not self.__cache:
# Attempt the cache before writing to it
self.__load()
self.__cache[key] = defaultdict(int, value)
self.__dirty = True
def __load(self):
self.__last_access = asyncio.get_running_loop().time()
if self.__cache:
return
if not self.path.exists():
return
with open(self.path) as fp:
log.info("Loading markov chain from %s", self.path)
self.__cache = defaultdict(
chain_inner_default,
{
key: defaultdict(
int,
{
(None if not word else word): weight
for word, weight in value.items()
},
)
for key, value in json.load(fp).items()
},
)
self.__dirty = False
def save(self, retain: bool = True):
if not self.__cache:
return
if self.__dirty:
log.info("Saving markov chain to %s", self.path)
self.path.parent.mkdir(parents=True, exist_ok=True)
with open(self.path, "w") as fp:
json.dump(
{
key: {
("" if word is None else word): weight
for word, weight in value.items()
}
for key, value in self.__cache.items()
},
fp,
)
self.__dirty = False
if not retain:
log.debug("Pruning markov chain %s from memory", self.path)
self.clear_cache()
def clear_cache(self):
self.__cache.clear()
self.__dirty = False
def __bool__(self) -> bool:
return self.path.exists() or bool(self.__cache)
def generate(self) -> str:
self.__load()
if not self.__cache:
return ""
words: List[str] = []
node = random.choice(list(self.__cache.keys())).split(" ")
words += node
next: str | None = self.choose_next(" ".join(node))
while next:
words += [next]
node = [*node[1:], next]
next = self.choose_next(" ".join(node))
return " ".join(words)
def choose_next(self, head: str) -> str | None:
self.__load()
choices = self.__cache[head]
words = list(choices.keys())
weights = list(choices.values())
if not words:
return None
return random.choices(words, weights)[0]
class Markov(Plugin):
def __init__(self, *args, **kwargs):
super(Markov, self).__init__(*args, **kwargs)
self.order = self.plugin_config.get("order", 1)
self.data_path = Path(self.plugin_config.get("data_path", "data/markov"))
self.save_every = self.plugin_config.get("save_every", 300)
self.__chains = {}
self.__save_loop_task = None
self.__saving = asyncio.Lock()
async def on_load(self):
loop = asyncio.get_running_loop()
self.__save_loop_task = loop.create_task(self.__save_loop())
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
line = line.strip()
parts = line.split()
if not parts:
return
elif parts[0] == "!markov":
self.handle_command(conn, channel, who, parts)
elif line[0] != "!":
# ignore other commands
self.add(channel, who.nick, line)
def get_chain(self, channel: str, who: str) -> Chain:
if channel not in self.__chains:
self.__chains[channel] = {}
if who not in self.__chains[channel]:
path = self.data_path / channel / who
self.__chains[channel][who] = Chain(self.order, path)
return self.__chains[channel][who]
def add(self, channel: str, who: str, line: str):
if who == self.server_config.nick == who:
return
chain = self.get_chain(channel, who)
chain.add(line)
allchain = self.get_chain(channel, ALLCHAIN)
allchain.add(line)
def handle_command(
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
):
# handle markov commands
match parts[1:]:
case ["force"]:
chain = self.get_chain(channel, who.nick)
message = chain.generate()
if message:
self.send_to(conn, channel, f"{who.nick}: {message}")
case ["force", nick] | ["emulate", nick]:
chain = self.get_chain(channel, who.nick)
if not chain:
return
message = chain.generate()
if message:
self.send_to(conn, channel, f"{who.nick}: {message}")
case ["all"]:
chain = self.get_chain(channel, ALLCHAIN)
message = chain.generate()
if message:
self.send_to(conn, channel, f"{who.nick}: {message}")
case _:
# command not recognized
pass
async def __save_loop(self):
while True:
log.debug("Pruning inactive markov chains in %s seconds", self.save_every)
await asyncio.sleep(self.save_every)
retain_after = asyncio.get_running_loop().time() - self.save_every
await self.save(retain_after=retain_after)
async def save(self, retain_after: float | None = None):
async with self.__saving:
log.info("Saving markov chains")
for chains in self.__chains.values():
for chain in chains.values():
if retain_after is not None:
retain = chain.last_access > retain_after
else:
retain = True
chain.save(retain=retain)
log.info("Done")
async def on_unload(self, conn: IrcProtocol):
self.__save_loop_task.cancel()
await self.save()
# Plugin class exported by this module.
# NOTE(review): presumably the plugin loader looks this name up — confirm.
PLUGIN_TYPE = Markov