Add initial markov bot plugin
Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
This commit is contained in:
142
plugins/markov.py
Normal file
142
plugins/markov.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from collections import defaultdict
|
||||
import dataclasses
|
||||
from functools import partial
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import random
|
||||
import shutil
|
||||
from typing import Any, List, MutableMapping, Sequence
|
||||
|
||||
from asyncirc.protocol import IrcProtocol
|
||||
from irclib.parser import Prefix
|
||||
|
||||
from omnibot.plugin import Plugin
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def chain_inner_default() -> defaultdict[str | None, int]:
|
||||
return defaultdict(int)
|
||||
|
||||
|
||||
def chain_default() -> defaultdict[tuple[str, ...], defaultdict[str | None, int]]:
|
||||
return defaultdict(chain_inner_default)
|
||||
|
||||
|
||||
def windows(items: Sequence[Any], size: int):
|
||||
if len(items) < size:
|
||||
yield items
|
||||
else:
|
||||
for i in range(0, len(items) - size + 1):
|
||||
yield items[i : i + size]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Chain:
|
||||
order: int
|
||||
chain: defaultdict[
|
||||
tuple[str, ...], defaultdict[str | None, int]
|
||||
] = dataclasses.field(default_factory=chain_default)
|
||||
|
||||
def add(self, text: str):
|
||||
parts: List[Any] = text.strip().split()
|
||||
if not parts:
|
||||
return
|
||||
if len(parts) < self.order:
|
||||
self.chain[tuple(parts)][None] += 1
|
||||
else:
|
||||
for fragment in windows(parts + [None], self.order + 1):
|
||||
head = tuple(fragment[0:-1])
|
||||
tail = fragment[-1]
|
||||
self.chain[head][tail] += 1
|
||||
|
||||
def generate(self) -> str:
|
||||
words: List[str] = []
|
||||
if not self.chain:
|
||||
return ""
|
||||
|
||||
node = random.choice(list(self.chain.keys()))
|
||||
words += node
|
||||
next: str | None = self.choose_next(node)
|
||||
while next:
|
||||
words += [next]
|
||||
node = (*node[1:], next)
|
||||
next = self.choose_next(node)
|
||||
|
||||
return " ".join(words)
|
||||
|
||||
def choose_next(self, head: tuple[str, ...]) -> str | None:
|
||||
if head not in self.chain:
|
||||
return None
|
||||
choices = self.chain[head]
|
||||
words = list(choices.keys())
|
||||
weights = list(choices.values())
|
||||
if not words:
|
||||
return None
|
||||
return random.choices(words, weights)[0]
|
||||
|
||||
|
||||
class Markov(Plugin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Markov, self).__init__(*args, **kwargs)
|
||||
|
||||
self.order = self.plugin_config.get("order", 1)
|
||||
self.db_path = Path(self.plugin_config.get("db_path", "data/markov.pkl"))
|
||||
self.backup_db_path = Path(str(self.db_path) + ".backup")
|
||||
|
||||
# Load chain from data dir
|
||||
if self.db_path.exists():
|
||||
with open(self.db_path, "rb") as fp:
|
||||
self.chains = pickle.load(fp)
|
||||
else:
|
||||
self.chains = defaultdict(
|
||||
partial(defaultdict, partial(Chain, order=self.order))
|
||||
)
|
||||
|
||||
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
|
||||
line = line.strip()
|
||||
parts = line.split()
|
||||
if not parts:
|
||||
return
|
||||
elif parts[0] == "!markov":
|
||||
self.handle_command(conn, channel, who, parts)
|
||||
elif line[0] != "!":
|
||||
# ignore other commands
|
||||
self.chains[channel][who.nick].add(line)
|
||||
|
||||
def handle_command(
|
||||
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
|
||||
):
|
||||
# handle markov commands
|
||||
match parts[1:]:
|
||||
case ["force"]:
|
||||
chain = self.chains[channel][who.nick]
|
||||
message = chain.generate()
|
||||
if message:
|
||||
self.send_to(conn, channel, f"{who.nick}: {message}")
|
||||
case ["force", nick] | ["emulate", nick]:
|
||||
if nick not in self.chains[channel]:
|
||||
return
|
||||
chain = self.chains[channel][nick]
|
||||
message = chain.generate()
|
||||
if message:
|
||||
self.send_to(conn, channel, f"{who.nick}: {message}")
|
||||
case _:
|
||||
# command not recognized
|
||||
pass
|
||||
|
||||
def save(self):
|
||||
if self.db_path.exists():
|
||||
log.info("Copying backup of markov chain to %s")
|
||||
shutil.copyfile(self.db_path, self.backup_db_path)
|
||||
log.info("Done")
|
||||
|
||||
log.info("Saving markov chain to %s", self.db_path)
|
||||
with open(self.db_path, "wb") as fp:
|
||||
pickle.dump(self.chains, fp)
|
||||
log.info("Done")
|
||||
|
||||
async def on_unload(self, conn: IrcProtocol):
|
||||
self.save()
|
||||
Reference in New Issue
Block a user