This is basically just a shorthand, but it also abstracts away adding a line to a Markov chain. Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
149 lines
4.4 KiB
Python
149 lines
4.4 KiB
Python
from collections import defaultdict
|
|
import dataclasses
|
|
from functools import partial
|
|
import logging
|
|
from pathlib import Path
|
|
import pickle
|
|
import random
|
|
import shutil
|
|
from typing import Any, List, Sequence
|
|
|
|
from asyncirc.protocol import IrcProtocol
|
|
from irclib.parser import Prefix
|
|
|
|
from omnibot.plugin import Plugin
|
|
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
|
|
|
|
|
def chain_inner_default() -> defaultdict[str | None, int]:
|
|
return defaultdict(int)
|
|
|
|
|
|
def chain_default() -> defaultdict[tuple[str, ...], defaultdict[str | None, int]]:
    """Create an empty transition table.

    Looking up a missing head creates a fresh follower table for it via
    ``chain_inner_default``.
    """
    transitions: defaultdict[
        tuple[str, ...], defaultdict[str | None, int]
    ] = defaultdict(chain_inner_default)
    return transitions
|
|
|
|
|
|
def windows(items: Sequence[Any], size: int):
    """Yield every contiguous slice of ``items`` that is ``size`` long.

    Slices are produced left to right, advancing one position each time.
    When ``items`` is shorter than ``size`` the whole sequence is yielded
    once, unchanged.
    """
    total = len(items)
    if total >= size:
        last_start = total - size
        yield from (
            items[start : start + size] for start in range(last_start + 1)
        )
        return
    # Too short to fill even one window: hand back the input as-is.
    yield items
|
|
|
|
|
|
@dataclasses.dataclass
class Chain:
    """Word-level Markov chain with a fixed head length.

    ``chain`` maps a head (a tuple of consecutive words) to a table counting
    how often each word was seen directly after that head.  A follower of
    ``None`` marks the end of a line.
    """

    # Number of words that make up a head.
    order: int
    # head -> {follower (None = end of line) -> number of times observed}
    chain: defaultdict[
        tuple[str, ...], defaultdict[str | None, int]
    ] = dataclasses.field(default_factory=chain_default)

    def add(self, text: str) -> None:
        """Record the word transitions found in one line of text.

        The line is split on whitespace; a line without any words is
        ignored.  A line with fewer than ``order`` words is stored whole as
        a head whose only follower is the end-of-line marker.
        """
        # str.split() with no separator already discards surrounding
        # whitespace, so no separate strip() is required.
        tokens = text.split()
        if not tokens:
            return

        if len(tokens) < self.order:
            self.chain[tuple(tokens)][None] += 1
            return

        # Append the end-of-line marker so the final head gets a None
        # follower, then slide a window of (head words..., follower).
        for fragment in windows([*tokens, None], self.order + 1):
            head = tuple(fragment[:-1])
            follower = fragment[-1]
            self.chain[head][follower] += 1

    def generate(self) -> str:
        """Produce a random line by walking the chain.

        The walk starts at a head picked uniformly from all known heads and
        keeps appending weighted-random followers until the end-of-line
        marker is drawn or the current head is unknown.

        Returns:
            The generated words joined by single spaces, or ``""`` when the
            chain is empty.
        """
        if not self.chain:
            return ""

        node = random.choice(list(self.chain))
        words: list[str] = list(node)

        # Only None terminates the walk; the name avoids shadowing the
        # builtin next().
        following = self.choose_next(node)
        while following is not None:
            words.append(following)
            # Shift the head one word to the right.
            node = (*node[1:], following)
            following = self.choose_next(node)

        return " ".join(words)

    def choose_next(self, head: tuple[str, ...]) -> str | None:
        """Pick a follower for ``head``, weighted by observed counts.

        Returns:
            The chosen follower, or ``None`` when ``head`` is unknown, has
            no recorded followers, or the end-of-line marker was drawn.
        """
        # .get() does not trigger the defaultdict factory, so probing an
        # unknown head never inserts an empty entry into the chain.
        followers = self.chain.get(head)
        if not followers:
            return None
        candidates = list(followers)
        weights = list(followers.values())
        return random.choices(candidates, weights)[0]
|
|
|
|
|
|
class Markov(Plugin):
|
|
def __init__(self, *args, **kwargs):
|
|
super(Markov, self).__init__(*args, **kwargs)
|
|
|
|
self.order = self.plugin_config.get("order", 1)
|
|
self.db_path = Path(self.plugin_config.get("db_path", "data/markov.pkl"))
|
|
self.backup_db_path = Path(str(self.db_path) + ".backup")
|
|
|
|
# Load chain from data dir
|
|
if self.db_path.exists():
|
|
with open(self.db_path, "rb") as fp:
|
|
self.chains = pickle.load(fp)
|
|
else:
|
|
self.chains = defaultdict(
|
|
partial(defaultdict, partial(Chain, order=self.order))
|
|
)
|
|
|
|
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
|
|
line = line.strip()
|
|
parts = line.split()
|
|
if not parts:
|
|
return
|
|
elif parts[0] == "!markov":
|
|
self.handle_command(conn, channel, who, parts)
|
|
elif line[0] != "!":
|
|
# ignore other commands
|
|
self.add(channel, who.nick, line)
|
|
|
|
def add(self, channel: str, who: str, line: str):
|
|
self.chains[channel][who].add(line)
|
|
|
|
def handle_command(
|
|
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
|
|
):
|
|
# handle markov commands
|
|
match parts[1:]:
|
|
case ["force"]:
|
|
chain = self.chains[channel][who.nick]
|
|
message = chain.generate()
|
|
if message:
|
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
|
case ["force", nick] | ["emulate", nick]:
|
|
if nick not in self.chains[channel]:
|
|
return
|
|
chain = self.chains[channel][nick]
|
|
message = chain.generate()
|
|
if message:
|
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
|
case _:
|
|
# command not recognized
|
|
pass
|
|
|
|
def save(self):
|
|
if self.db_path.exists():
|
|
log.info("Copying backup of markov chain to %s")
|
|
shutil.copyfile(self.db_path, self.backup_db_path)
|
|
log.info("Done")
|
|
|
|
log.info("Saving markov chain to %s", self.db_path)
|
|
with open(self.db_path, "wb") as fp:
|
|
pickle.dump(self.chains, fp)
|
|
log.info("Done")
|
|
|
|
async def on_unload(self, conn: IrcProtocol):
|
|
self.save()
|
|
|
|
|
|
# Expose the plugin class under the module-level name PLUGIN_TYPE
# (presumably the name the omnibot plugin loader looks up — verify).
PLUGIN_TYPE = Markov
|