diff --git a/plugins/markov.py b/plugins/markov.py
index c93e69d..a8e285d 100644
--- a/plugins/markov.py
+++ b/plugins/markov.py
@@ -3,6 +3,7 @@ from collections import defaultdict
 import dataclasses
 import json
 import logging
+import math
 from pathlib import Path
 import random
 from typing import Any, List, Mapping, Sequence
@@ -35,8 +36,9 @@ def windows(items: Sequence[Any], size: int):
 
 @dataclasses.dataclass
 class Chain:
-    def __init__(self, order: int, path: Path):
+    def __init__(self, order: int, chance: float, path: Path):
         self.order = order
+        self.reply_chance = chance
         self.path = path
         self.__cache = chain_default()
         self.__last_access = 0.0
@@ -85,19 +87,23 @@ class Chain:
             return
         with open(self.path) as fp:
             log.info("Loading markov chain from %s", self.path)
-            self.__cache = defaultdict(
-                chain_inner_default,
-                {
-                    key: defaultdict(
-                        int,
-                        {
-                            (None if not word else word): weight
-                            for word, weight in value.items()
-                        },
-                    )
-                    for key, value in json.load(fp).items()
-                },
-            )
+            obj = json.load(fp)
+
+        # Load the save object
+        self.reply_chance = obj["reply_chance"]
+        self.__cache = defaultdict(
+            chain_inner_default,
+            {
+                key: defaultdict(
+                    int,
+                    {
+                        (None if not word else word): weight
+                        for word, weight in value.items()
+                    },
+                )
+                for key, value in obj["chain"].items()
+            },
+        )
         self.__dirty = False
 
     def save(self, retain: bool = True):
@@ -106,18 +112,21 @@ class Chain:
         if self.__dirty:
             log.info("Saving markov chain to %s", self.path)
             self.path.parent.mkdir(parents=True, exist_ok=True)
+            # Build the save object
+            obj = {
+                "reply_chance": self.reply_chance,
+                "chain": {
+                    key: {
+                        ("" if word is None else word): weight
+                        for word, weight in value.items()
+                    }
+                    for key, value in self.__cache.items()
+                },
+            }
             with open(self.path, "w") as fp:
-                json.dump(
-                    {
-                        key: {
-                            ("" if word is None else word): weight
-                            for word, weight in value.items()
-                        }
-                        for key, value in self.__cache.items()
-                    },
-                    fp,
-                )
+                json.dump(obj, fp)
             self.__dirty = False
+
         if not retain:
             log.debug("Pruning markov chain %s from memory", self.path)
             self.clear_cache()
@@ -162,6 +171,7 @@ class Markov(Plugin):
         self.order = int(self.plugin_config.get("order", 1))
         self.data_path = Path(self.plugin_config.get("data_path", "data/markov"))
         self.save_every = int(self.plugin_config.get("save_every", 300))
+        self.reply_chance = float(self.plugin_config.get("reply_chance", 0.01))
         self.__chains = {}
         self.__save_loop_task = None
         self.__saving = asyncio.Lock()
@@ -180,17 +190,22 @@ class Markov(Plugin):
         elif line[0] != "!":
             # ignore other commands
             self.add(channel, who.nick, line)
+            # also, maybe generate a sentence
+            chosen = random.random()
+            chain = self.get_chain(channel, who.nick)
+            if chosen <= chain.reply_chance:
+                pass
 
     def get_chain(self, channel: str, who: str) -> Chain:
         if channel not in self.__chains:
             self.__chains[channel] = {}
         if who not in self.__chains[channel]:
             path = self.data_path / channel / who
-            self.__chains[channel][who] = Chain(self.order, path)
+            self.__chains[channel][who] = Chain(self.order, self.reply_chance, path)
         return self.__chains[channel][who]
 
     def add(self, channel: str, who: str, line: str):
-        if who == self.server_config.nick == who:
+        if who == self.server_config.nick:
             return
         chain = self.get_chain(channel, who)
         chain.add(line)
@@ -208,7 +223,7 @@ class Markov(Plugin):
                 if message:
                     self.send_to(conn, channel, f"{who.nick}: {message}")
             case ["force", nick] | ["emulate", nick]:
-                chain = self.get_chain(channel, who.nick)
+                chain = self.get_chain(channel, nick)
                 if not chain:
                     return
                 message = chain.generate()
@@ -219,6 +234,13 @@ class Markov(Plugin):
                 message = chain.generate()
                 if message:
                     self.send_to(conn, channel, f"{who.nick}: {message}")
+            case ["chance", chance]:
+                chain = self.get_chain(channel, who.nick)
+                reply_chance = float(chance)
+                if not math.isnan(reply_chance):
+                    chain.reply_chance = min(
+                        max(float(reply_chance), 0.0), self.reply_chance
+                    )
             case _:
                 # command not recognized
                 pass
diff --git a/tools/markov_import.py b/tools/markov_import.py
index 43af4fd..5270f00 100644
--- a/tools/markov_import.py
+++ b/tools/markov_import.py
@@ -38,8 +38,9 @@ async def main():
         for line in lines:
             if mat := LINE_RE.search(line):
                 name = mat["name"]
-                message = mat["message"]
-                plugin.add(channel, name, message)
+                message = mat["message"].strip()
+                if name != server_config.nick and message and message[0] != "!":
+                    plugin.add(channel, name, message)
 
     await plugin.save()
 