Separating strings by spaces is more memory-friendly than using tuples. Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
246 lines
7.8 KiB
Python
246 lines
7.8 KiB
Python
import asyncio
|
|
from collections import defaultdict
|
|
import dataclasses
|
|
import json
|
|
import logging
|
|
from pathlib import Path
|
|
import random
|
|
from typing import Any, List, Mapping, Sequence
|
|
|
|
from asyncirc.protocol import IrcProtocol
|
|
from irclib.parser import Prefix
|
|
|
|
from omnibot.plugin import Plugin
|
|
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
|
# Pseudo-nick under which every learned line of a channel is recorded, in
# addition to the speaker's own chain; it is also used as that chain's file
# name below the channel's data directory.
ALLCHAIN = "ALL!CHAIN"
|
|
|
|
|
|
def chain_inner_default() -> defaultdict[str | None, int]:
|
|
return defaultdict(int)
|
|
|
|
|
|
def chain_default() -> defaultdict[str, defaultdict[str | None, int]]:
    """Build an empty chain table.

    The table maps a key string to a follower table; looking up an unseen
    key creates a fresh follower table via ``chain_inner_default``.
    """
    table: defaultdict[str, defaultdict[str | None, int]] = defaultdict(
        chain_inner_default
    )
    return table
|
|
|
|
|
|
def windows(items: Sequence[Any], size: int):
    """Yield every contiguous slice of ``items`` that is ``size`` long.

    Slices are produced left to right, advancing one position at a time.
    When ``items`` is shorter than ``size`` the whole sequence is yielded
    once, unchanged, instead of nothing.
    """
    count = len(items)
    if count < size:
        # Too short for even one full window: hand back what there is.
        yield items
        return
    yield from (items[start : start + size] for start in range(count - size + 1))
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class Chain:
|
|
def __init__(self, order: int, path: Path):
|
|
self.order = order
|
|
self.path = path
|
|
self.__cache = chain_default()
|
|
self.__last_access = 0.0
|
|
|
|
@property
|
|
def last_access(self) -> float:
|
|
return self.__last_access
|
|
|
|
def add(self, text: str):
|
|
parts: List[Any] = text.strip().split()
|
|
if not parts:
|
|
return
|
|
self.__load()
|
|
for fragment in windows(parts + [None], self.order + 1):
|
|
head = fragment[0:-1]
|
|
tail = fragment[-1]
|
|
self.__cache[" ".join(head)][tail] += 1
|
|
|
|
def get(self, key: str) -> dict[str | None, int]:
|
|
self.__last_access = asyncio.get_running_loop().time()
|
|
if self.__cache:
|
|
if key in self.__cache:
|
|
return self.__cache[key]
|
|
else:
|
|
raise KeyError(key)
|
|
else:
|
|
# Load cache, then return key
|
|
self.__load()
|
|
return self.__cache[key]
|
|
|
|
def set(self, key: str, value: Mapping[str | None, int]):
|
|
self.__last_access = asyncio.get_running_loop().time()
|
|
if not self.__cache:
|
|
# Attempt the cache before writing to it
|
|
self.__load()
|
|
self.__cache[key] = defaultdict(int, value)
|
|
|
|
def __load(self):
|
|
self.__last_access = asyncio.get_running_loop().time()
|
|
if self.__cache:
|
|
return
|
|
if not self.path.exists():
|
|
return
|
|
with open(self.path) as fp:
|
|
log.info("Loading markov chain from %s", self.path)
|
|
self.__cache = defaultdict(
|
|
chain_inner_default,
|
|
{
|
|
key: defaultdict(
|
|
int,
|
|
{
|
|
(None if not word else word): weight
|
|
for word, weight in value.items()
|
|
},
|
|
)
|
|
for key, value in json.load(fp).items()
|
|
},
|
|
)
|
|
|
|
def save(self, retain: bool = True):
|
|
if not self.__cache:
|
|
return
|
|
if retain:
|
|
log.info("Saving markov chain to %s", self.path)
|
|
else:
|
|
log.info("Saving markov chain to %s (not retaining)", self.path)
|
|
self.path.parent.mkdir(parents=True, exist_ok=True)
|
|
with open(self.path, "w") as fp:
|
|
json.dump(
|
|
{
|
|
key: {
|
|
("" if word is None else word): weight
|
|
for word, weight in value.items()
|
|
}
|
|
for key, value in self.__cache.items()
|
|
},
|
|
fp,
|
|
)
|
|
if not retain:
|
|
self.clear_cache()
|
|
|
|
def clear_cache(self):
|
|
self.__cache.clear()
|
|
|
|
def __bool__(self) -> bool:
|
|
return self.path.exists() or bool(self.__cache)
|
|
|
|
def generate(self) -> str:
|
|
self.__load()
|
|
if not self.__cache:
|
|
return ""
|
|
|
|
words: List[str] = []
|
|
|
|
node = random.choice(list(self.__cache.keys())).split(" ")
|
|
words += node
|
|
next: str | None = self.choose_next(" ".join(node))
|
|
while next:
|
|
words += [next]
|
|
node = [*node[1:], next]
|
|
next = self.choose_next(" ".join(node))
|
|
return " ".join(words)
|
|
|
|
def choose_next(self, head: str) -> str | None:
|
|
self.__load()
|
|
choices = self.__cache[head]
|
|
words = list(choices.keys())
|
|
weights = list(choices.values())
|
|
if not words:
|
|
return None
|
|
return random.choices(words, weights)[0]
|
|
|
|
|
|
class Markov(Plugin):
|
|
def __init__(self, *args, **kwargs):
|
|
super(Markov, self).__init__(*args, **kwargs)
|
|
|
|
self.order = self.plugin_config.get("order", 1)
|
|
self.data_path = Path(self.plugin_config.get("data_path", "data/markov"))
|
|
self.save_every = self.plugin_config.get("save_every", 300)
|
|
self.__chains = {}
|
|
self.__save_loop_task = None
|
|
self.__saving = asyncio.Lock()
|
|
|
|
async def on_load(self):
|
|
loop = asyncio.get_running_loop()
|
|
self.__save_loop_task = loop.create_task(self.__save_loop())
|
|
|
|
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
|
|
line = line.strip()
|
|
parts = line.split()
|
|
if not parts:
|
|
return
|
|
elif parts[0] == "!markov":
|
|
self.handle_command(conn, channel, who, parts)
|
|
elif line[0] != "!":
|
|
# ignore other commands
|
|
self.add(channel, who.nick, line)
|
|
|
|
def get_chain(self, channel: str, who: str) -> Chain:
|
|
if channel not in self.__chains:
|
|
self.__chains[channel] = {}
|
|
if who not in self.__chains[channel]:
|
|
path = self.data_path / channel / who
|
|
self.__chains[channel][who] = Chain(self.order, path)
|
|
return self.__chains[channel][who]
|
|
|
|
def add(self, channel: str, who: str, line: str):
|
|
if who == self.server_config.nick == who:
|
|
return
|
|
chain = self.get_chain(channel, who)
|
|
chain.add(line)
|
|
allchain = self.get_chain(channel, ALLCHAIN)
|
|
allchain.add(line)
|
|
|
|
def handle_command(
|
|
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
|
|
):
|
|
# handle markov commands
|
|
match parts[1:]:
|
|
case ["force"]:
|
|
chain = self.get_chain(channel, who.nick)
|
|
message = chain.generate()
|
|
if message:
|
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
|
case ["force", nick] | ["emulate", nick]:
|
|
chain = self.get_chain(channel, who.nick)
|
|
if not chain:
|
|
return
|
|
message = chain.generate()
|
|
if message:
|
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
|
case ["all"]:
|
|
chain = self.get_chain(channel, ALLCHAIN)
|
|
message = chain.generate()
|
|
if message:
|
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
|
case _:
|
|
# command not recognized
|
|
pass
|
|
|
|
async def __save_loop(self):
|
|
while True:
|
|
log.debug("Pruning inactive markov chains in %s seconds", self.save_every)
|
|
await asyncio.sleep(self.save_every)
|
|
retain_after = asyncio.get_running_loop().time() - self.save_every
|
|
await self.save(retain_after=retain_after)
|
|
|
|
async def save(self, retain_after: float | None = None):
|
|
async with self.__saving:
|
|
log.info("Saving markov chains")
|
|
for chains in self.__chains.values():
|
|
for chain in chains.values():
|
|
if retain_after is not None:
|
|
retain = chain.last_access > retain_after
|
|
else:
|
|
retain = True
|
|
chain.save(retain=retain)
|
|
log.info("Done")
|
|
|
|
async def on_unload(self, conn: IrcProtocol):
|
|
self.__save_loop_task.cancel()
|
|
await self.save()
|
|
|
|
|
|
# Names the plugin class defined by this module.
# NOTE(review): presumably read by the omnibot plugin loader — confirm.
PLUGIN_TYPE = Markov
|