Add initial markov bot plugin

Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
This commit is contained in:
2022-05-24 19:30:42 -07:00
parent 82e50f86d6
commit 9c188e30b1

142
plugins/markov.py Normal file
View File

@@ -0,0 +1,142 @@
from collections import defaultdict
import dataclasses
from functools import partial
import logging
from pathlib import Path
import pickle
import random
import shutil
from typing import Any, List, MutableMapping, Sequence
from asyncirc.protocol import IrcProtocol
from irclib.parser import Prefix
from omnibot.plugin import Plugin
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
def chain_inner_default() -> defaultdict[str | None, int]:
return defaultdict(int)
def chain_default() -> defaultdict[tuple[str, ...], defaultdict[str | None, int]]:
    """Build an empty transition table: head tuple -> follower table.

    Missing heads are filled in on first access with a fresh table from
    ``chain_inner_default``. Like that function, this is a named
    module-level function so the resulting defaultdict can be pickled.
    """
    transitions: defaultdict[
        tuple[str, ...], defaultdict[str | None, int]
    ] = defaultdict(chain_inner_default)
    return transitions
def windows(items: Sequence[Any], size: int):
    """Yield consecutive, overlapping slices of ``items`` of length ``size``.

    Each slice starts one position after the previous one. When ``items``
    holds fewer than ``size`` elements, ``items`` itself is yielded once,
    unchanged, and no slices are produced.
    """
    count = len(items)
    if count >= size:
        # One window per valid starting offset.
        for start in range(count - size + 1):
            yield items[start : start + size]
        return
    # Too short for a full window: hand back the input as the only item.
    yield items
@dataclasses.dataclass
class Chain:
    """Word-level Markov chain built from whitespace-separated text.

    ``chain`` maps a head (a tuple of words) to a table that counts how
    often each following word was seen after that head. A follower of
    ``None`` marks the end of a text.
    """

    # Number of words that make up the head of each transition.
    order: int
    # head tuple -> {following word, or None for end of text: count}
    chain: defaultdict[
        tuple[str, ...], defaultdict[str | None, int]
    ] = dataclasses.field(default_factory=chain_default)

    def add(self, text: str):
        """Record the word transitions found in ``text``.

        Blank text is ignored. Text with fewer than ``order`` words is
        stored whole as a head whose only follower is the end marker.
        """
        tokens: List[Any] = text.strip().split()
        if not tokens:
            return
        if len(tokens) < self.order:
            self.chain[tuple(tokens)][None] += 1
            return
        # Append the end marker so the final head transitions to None.
        for fragment in windows(tokens + [None], self.order + 1):
            prefix, follower = tuple(fragment[:-1]), fragment[-1]
            self.chain[prefix][follower] += 1

    def generate(self) -> str:
        """Return a randomly generated sentence, or "" if nothing was learned.

        Generation starts at a head picked uniformly at random and keeps
        drawing weighted followers until ``choose_next`` returns ``None``.
        """
        if not self.chain:
            return ""
        state = random.choice(list(self.chain))
        output: List[str] = list(state)
        while following := self.choose_next(state):
            output.append(following)
            # Slide the head forward by one word.
            state = (*state[1:], following)
        return " ".join(output)

    def choose_next(self, head: tuple[str, ...]) -> str | None:
        """Draw a follower for ``head``, weighted by observed counts.

        Returns ``None`` when ``head`` is unknown or has no followers; the
        draw itself may also return ``None`` (the end-of-text marker).
        """
        # .get() avoids creating an entry in the defaultdict for unknown heads.
        table = self.chain.get(head)
        if not table:
            return None
        candidates = list(table)
        counts = [table[word] for word in candidates]
        return random.choices(candidates, counts)[0]
class Markov(Plugin):
def __init__(self, *args, **kwargs):
super(Markov, self).__init__(*args, **kwargs)
self.order = self.plugin_config.get("order", 1)
self.db_path = Path(self.plugin_config.get("db_path", "data/markov.pkl"))
self.backup_db_path = Path(str(self.db_path) + ".backup")
# Load chain from data dir
if self.db_path.exists():
with open(self.db_path, "rb") as fp:
self.chains = pickle.load(fp)
else:
self.chains = defaultdict(
partial(defaultdict, partial(Chain, order=self.order))
)
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
line = line.strip()
parts = line.split()
if not parts:
return
elif parts[0] == "!markov":
self.handle_command(conn, channel, who, parts)
elif line[0] != "!":
# ignore other commands
self.chains[channel][who.nick].add(line)
def handle_command(
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
):
# handle markov commands
match parts[1:]:
case ["force"]:
chain = self.chains[channel][who.nick]
message = chain.generate()
if message:
self.send_to(conn, channel, f"{who.nick}: {message}")
case ["force", nick] | ["emulate", nick]:
if nick not in self.chains[channel]:
return
chain = self.chains[channel][nick]
message = chain.generate()
if message:
self.send_to(conn, channel, f"{who.nick}: {message}")
case _:
# command not recognized
pass
def save(self):
if self.db_path.exists():
log.info("Copying backup of markov chain to %s")
shutil.copyfile(self.db_path, self.backup_db_path)
log.info("Done")
log.info("Saving markov chain to %s", self.db_path)
with open(self.db_path, "wb") as fp:
pickle.dump(self.chains, fp)
log.info("Done")
async def on_unload(self, conn: IrcProtocol):
self.save()