Whenever someone says something, there's a chance that markov will interject his opinion. Users can also set the chance between 0.0 and the default value (in the config) if they want to see markov replies less often. Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
273 lines · 8.8 KiB · Python
import asyncio
|
|
from collections import defaultdict
|
|
import dataclasses
|
|
import json
|
|
import logging
|
|
import math
|
|
from pathlib import Path
|
|
import random
|
|
from typing import Any, List, Mapping, Sequence
|
|
|
|
from asyncirc.protocol import IrcProtocol
|
|
from irclib.parser import Prefix
|
|
|
|
from omnibot.plugin import Plugin
|
|
|
|
|
|
log = logging.getLogger(__name__)

# Sentinel "nick" for the per-channel aggregate chain that is trained on
# every user's messages; the embedded "!" makes a collision with a real IRC
# nick unlikely (presumably "!" is not valid in nicks — verify for this net).
ALLCHAIN = "ALL!CHAIN"
|
|
|
|
|
|
def chain_inner_default() -> defaultdict[str | None, int]:
|
|
return defaultdict(int)
|
|
|
|
|
|
def chain_default() -> defaultdict[str, defaultdict[str | None, int]]:
|
|
return defaultdict(chain_inner_default)
|
|
|
|
|
|
def windows(items: Sequence[Any], size: int):
    """Yield every contiguous run of *size* items from *items*.

    A sequence shorter than *size* is yielded once, whole, so callers always
    receive at least one window.
    """
    total = len(items)
    if total < size:
        yield items
        return
    for start in range(total - size + 1):
        yield items[start : start + size]
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class Chain:
|
|
def __init__(self, order: int, chance: float, path: Path):
|
|
self.order = order
|
|
self.reply_chance = chance
|
|
self.path = path
|
|
self.__cache = chain_default()
|
|
self.__last_access = 0.0
|
|
self.__dirty = False
|
|
|
|
@property
|
|
def last_access(self) -> float:
|
|
return self.__last_access
|
|
|
|
def add(self, text: str):
|
|
parts: List[Any] = text.strip().split()
|
|
if not parts:
|
|
return
|
|
self.__load()
|
|
for fragment in windows(parts + [None], self.order + 1):
|
|
head = fragment[0:-1]
|
|
tail = fragment[-1]
|
|
self.__cache[" ".join(head)][tail] += 1
|
|
self.__dirty = True
|
|
|
|
def get(self, key: str) -> dict[str | None, int]:
|
|
self.__last_access = asyncio.get_running_loop().time()
|
|
if self.__cache:
|
|
if key in self.__cache:
|
|
return self.__cache[key]
|
|
else:
|
|
raise KeyError(key)
|
|
else:
|
|
# Load cache, then return key
|
|
self.__load()
|
|
return self.__cache[key]
|
|
|
|
def set(self, key: str, value: Mapping[str | None, int]):
|
|
self.__last_access = asyncio.get_running_loop().time()
|
|
if not self.__cache:
|
|
# Attempt the cache before writing to it
|
|
self.__load()
|
|
self.__cache[key] = defaultdict(int, value)
|
|
self.__dirty = True
|
|
|
|
def __load(self):
|
|
self.__last_access = asyncio.get_running_loop().time()
|
|
if self.__cache:
|
|
return
|
|
if not self.path.exists():
|
|
return
|
|
with open(self.path) as fp:
|
|
log.info("Loading markov chain from %s", self.path)
|
|
obj = json.load(fp)
|
|
|
|
# Load the save object
|
|
self.reply_chance = obj["reply_chance"]
|
|
self.__cache = defaultdict(
|
|
chain_inner_default,
|
|
{
|
|
key: defaultdict(
|
|
int,
|
|
{
|
|
(None if not word else word): weight
|
|
for word, weight in value.items()
|
|
},
|
|
)
|
|
for key, value in obj["chain"]
|
|
},
|
|
)
|
|
self.__dirty = False
|
|
|
|
def save(self, retain: bool = True):
|
|
if not self.__cache:
|
|
return
|
|
if self.__dirty:
|
|
log.info("Saving markov chain to %s", self.path)
|
|
self.path.parent.mkdir(parents=True, exist_ok=True)
|
|
# Build the save object
|
|
obj = {
|
|
"reply_chance": self.reply_chance,
|
|
"chain": {
|
|
key: {
|
|
("" if word is None else word): weight
|
|
for word, weight in value.items()
|
|
}
|
|
for key, value in self.__cache.items()
|
|
},
|
|
}
|
|
with open(self.path, "w") as fp:
|
|
json.dump(obj, fp)
|
|
self.__dirty = False
|
|
|
|
if not retain:
|
|
log.debug("Pruning markov chain %s from memory", self.path)
|
|
self.clear_cache()
|
|
|
|
def clear_cache(self):
|
|
self.__cache.clear()
|
|
self.__dirty = False
|
|
|
|
def __bool__(self) -> bool:
|
|
return self.path.exists() or bool(self.__cache)
|
|
|
|
def generate(self) -> str:
|
|
self.__load()
|
|
if not self.__cache:
|
|
return ""
|
|
|
|
words: List[str] = []
|
|
|
|
node = random.choice(list(self.__cache.keys())).split(" ")
|
|
words += node
|
|
next: str | None = self.choose_next(" ".join(node))
|
|
while next:
|
|
words += [next]
|
|
node = [*node[1:], next]
|
|
next = self.choose_next(" ".join(node))
|
|
return " ".join(words)
|
|
|
|
def choose_next(self, head: str) -> str | None:
|
|
self.__load()
|
|
choices = self.__cache[head]
|
|
words = list(choices.keys())
|
|
weights = list(choices.values())
|
|
if not words:
|
|
return None
|
|
return random.choices(words, weights)[0]
|
|
|
|
|
|
class Markov(Plugin):
|
|
def __init__(self, *args, **kwargs):
|
|
super(Markov, self).__init__(*args, **kwargs)
|
|
|
|
self.order = int(self.plugin_config.get("order", 1))
|
|
self.data_path = Path(self.plugin_config.get("data_path", "data/markov"))
|
|
self.save_every = int(self.plugin_config.get("save_every", 300))
|
|
self.reply_chance = float(self.plugin_config.get("reply_chance", 0.01))
|
|
self.__chains = {}
|
|
self.__save_loop_task = None
|
|
self.__saving = asyncio.Lock()
|
|
|
|
async def on_load(self):
|
|
loop = asyncio.get_running_loop()
|
|
self.__save_loop_task = loop.create_task(self.__save_loop())
|
|
|
|
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
|
|
line = line.strip()
|
|
parts = line.split()
|
|
if not parts:
|
|
return
|
|
elif parts[0] == "!markov":
|
|
self.handle_command(conn, channel, who, parts)
|
|
elif line[0] != "!":
|
|
# ignore other commands
|
|
self.add(channel, who.nick, line)
|
|
# also, maybe generate a sentence
|
|
chosen = random.random()
|
|
chain = self.get_chain(channel, who)
|
|
if chosen <= chain.reply_chance:
|
|
pass
|
|
|
|
def get_chain(self, channel: str, who: str) -> Chain:
|
|
if channel not in self.__chains:
|
|
self.__chains[channel] = {}
|
|
if who not in self.__chains[channel]:
|
|
path = self.data_path / channel / who
|
|
self.__chains[channel][who] = Chain(self.order, self.reply_chance, path)
|
|
return self.__chains[channel][who]
|
|
|
|
def add(self, channel: str, who: str, line: str):
|
|
if who == self.server_config.nick:
|
|
return
|
|
chain = self.get_chain(channel, who)
|
|
chain.add(line)
|
|
allchain = self.get_chain(channel, ALLCHAIN)
|
|
allchain.add(line)
|
|
|
|
def handle_command(
|
|
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
|
|
):
|
|
# handle markov commands
|
|
match parts[1:]:
|
|
case ["force"]:
|
|
chain = self.get_chain(channel, who.nick)
|
|
message = chain.generate()
|
|
if message:
|
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
|
case ["force", nick] | ["emulate", nick]:
|
|
chain = self.get_chain(channel, nick)
|
|
if not chain:
|
|
return
|
|
message = chain.generate()
|
|
if message:
|
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
|
case ["all"]:
|
|
chain = self.get_chain(channel, ALLCHAIN)
|
|
message = chain.generate()
|
|
if message:
|
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
|
case ["chance", chance]:
|
|
chain = self.get_chain(channel, who.nick)
|
|
reply_chance = float(chance)
|
|
if not math.isnan(reply_chance):
|
|
chain.reply_chance = min(
|
|
max(float(reply_chance), 0.0), self.reply_chance
|
|
)
|
|
case _:
|
|
# command not recognized
|
|
pass
|
|
|
|
async def __save_loop(self):
|
|
while True:
|
|
log.debug("Pruning inactive markov chains in %s seconds", self.save_every)
|
|
await asyncio.sleep(self.save_every)
|
|
retain_after = asyncio.get_running_loop().time() - self.save_every
|
|
await self.save(retain_after=retain_after)
|
|
|
|
async def save(self, retain_after: float | None = None):
|
|
async with self.__saving:
|
|
log.info("Saving markov chains")
|
|
for chains in self.__chains.values():
|
|
for chain in chains.values():
|
|
if retain_after is not None:
|
|
retain = chain.last_access > retain_after
|
|
else:
|
|
retain = True
|
|
chain.save(retain=retain)
|
|
log.info("Done")
|
|
|
|
async def on_unload(self, conn: IrcProtocol):
|
|
self.__save_loop_task.cancel()
|
|
await self.save()
|
|
|
|
|
|
# Exported hook: presumably the omnibot plugin loader reads PLUGIN_TYPE to
# locate this module's plugin class — verify against the loader.
PLUGIN_TYPE = Markov
|