Files
omnibot22/plugins/markov.py
Alek Ratzloff 85d48d368c Add Markov.add function
This is basically just a shorthand, but also abstracts away adding a
line to a markov chain

Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
2022-05-25 19:18:37 -07:00

149 lines
4.4 KiB
Python

from collections import defaultdict
import dataclasses
from functools import partial
import logging
from pathlib import Path
import pickle
import random
import shutil
from typing import Any, List, Sequence
from asyncirc.protocol import IrcProtocol
from irclib.parser import Prefix
from omnibot.plugin import Plugin
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
def chain_inner_default() -> defaultdict[str | None, int]:
return defaultdict(int)
def chain_default() -> defaultdict[tuple[str, ...], defaultdict[str | None, int]]:
    """Build an empty table mapping a word tuple to its successor counts.

    Looking up a key that is not stored yet creates a fresh table for it
    by calling ``chain_inner_default``.
    """
    transitions: defaultdict[
        tuple[str, ...], defaultdict[str | None, int]
    ] = defaultdict(chain_inner_default)
    return transitions
def windows(items: Sequence[Any], size: int):
    """Yield every contiguous slice of ``items`` that is ``size`` long.

    Slices are produced left to right, advancing one position at a time.
    When ``items`` holds fewer than ``size`` elements, ``items`` itself is
    yielded once and nothing else.
    """
    count = len(items)
    if count < size:
        yield items
        return
    yield from (items[start : start + size] for start in range(count - size + 1))
@dataclasses.dataclass
class Chain:
    """Markov chain of a fixed order built from whitespace-separated text.

    ``chain`` maps a tuple of words to a table counting each word seen
    directly after that tuple; a ``None`` entry in that table counts the
    times a line ended there.
    """

    # Number of words that make up a lookup key.
    order: int
    # word tuple -> {following word or None: times observed}
    chain: defaultdict[
        tuple[str, ...], defaultdict[str | None, int]
    ] = dataclasses.field(default_factory=chain_default)

    def add(self, text: str):
        """Record the transitions found in one line of text.

        Lines containing no words are ignored. A line with fewer than
        ``order`` words is stored whole as a key whose successor is ``None``.
        """
        tokens: List[Any] = text.strip().split()
        if not tokens:
            return
        if len(tokens) < self.order:
            self.chain[tuple(tokens)][None] += 1
            return
        # The appended None marks the end of the line as a successor.
        for window in windows(tokens + [None], self.order + 1):
            prefix, successor = tuple(window[:-1]), window[-1]
            self.chain[prefix][successor] += 1

    def generate(self) -> str:
        """Produce a space-joined string by randomly walking the chain.

        The walk starts at a randomly chosen stored key and keeps appending
        whatever ``choose_next`` returns until it returns a falsy value.
        Returns ``""`` when the chain is empty.
        """
        if not self.chain:
            return ""
        state = random.choice(list(self.chain.keys()))
        output: List[str] = list(state)
        following: str | None = self.choose_next(state)
        while following:
            output.append(following)
            # Slide the key forward: drop its first word, append the new one.
            state = state[1:] + (following,)
            following = self.choose_next(state)
        return " ".join(output)

    def choose_next(self, head: tuple[str, ...]) -> str | None:
        """Pick a successor of ``head`` at random, weighted by its count.

        Returns ``None`` when ``head`` is not a stored key or has no
        successors; the weighted pick may itself be the ``None`` entry.
        """
        # Test membership first: indexing the defaultdict would insert the key.
        if head not in self.chain:
            return None
        successors = self.chain[head]
        if not successors:
            return None
        candidates = list(successors.keys())
        counts = list(successors.values())
        return random.choices(candidates, counts)[0]
class Markov(Plugin):
def __init__(self, *args, **kwargs):
super(Markov, self).__init__(*args, **kwargs)
self.order = self.plugin_config.get("order", 1)
self.db_path = Path(self.plugin_config.get("db_path", "data/markov.pkl"))
self.backup_db_path = Path(str(self.db_path) + ".backup")
# Load chain from data dir
if self.db_path.exists():
with open(self.db_path, "rb") as fp:
self.chains = pickle.load(fp)
else:
self.chains = defaultdict(
partial(defaultdict, partial(Chain, order=self.order))
)
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
line = line.strip()
parts = line.split()
if not parts:
return
elif parts[0] == "!markov":
self.handle_command(conn, channel, who, parts)
elif line[0] != "!":
# ignore other commands
self.add(channel, who.nick, line)
def add(self, channel: str, who: str, line: str):
self.chains[channel][who].add(line)
def handle_command(
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
):
# handle markov commands
match parts[1:]:
case ["force"]:
chain = self.chains[channel][who.nick]
message = chain.generate()
if message:
self.send_to(conn, channel, f"{who.nick}: {message}")
case ["force", nick] | ["emulate", nick]:
if nick not in self.chains[channel]:
return
chain = self.chains[channel][nick]
message = chain.generate()
if message:
self.send_to(conn, channel, f"{who.nick}: {message}")
case _:
# command not recognized
pass
def save(self):
if self.db_path.exists():
log.info("Copying backup of markov chain to %s")
shutil.copyfile(self.db_path, self.backup_db_path)
log.info("Done")
log.info("Saving markov chain to %s", self.db_path)
with open(self.db_path, "wb") as fp:
pickle.dump(self.chains, fp)
log.info("Done")
async def on_unload(self, conn: IrcProtocol):
self.save()
# Module-level alias for the plugin class defined in this file.
# NOTE(review): presumably how omnibot's loader finds the class to instantiate — confirm.
PLUGIN_TYPE = Markov