Files
omnibot22/plugins/markov.py

458 lines
15 KiB
Python
Raw Normal View History

import asyncio
from collections import defaultdict
import dataclasses
import json
import logging
import math
from pathlib import Path
import random
import sqlite3
from typing import Any, List, Mapping, Sequence
from asyncirc.protocol import IrcProtocol
from irclib.parser import Prefix
from omnibot.plugin import Plugin
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Pseudo-nick of the per-channel aggregate chain: Markov.add records every
# learned line under this name in addition to the speaker's own chain, and
# "!markov all" generates from it.
ALLCHAIN = "ALL!CHAIN"
def chain_inner_default() -> defaultdict[str | None, int]:
return defaultdict(int)
def chain_default() -> defaultdict[str, defaultdict[str | None, int]]:
    """Create an empty chain mapping a key string to its successor table.

    Looking up an unknown key inserts a fresh, empty successor table built by
    ``chain_inner_default``.
    """
    chain: defaultdict[str, defaultdict[str | None, int]] = defaultdict(
        chain_inner_default
    )
    return chain
def windows(items: Sequence[Any], size: int):
    """Yield each contiguous slice of ``items`` that is ``size`` elements long.

    If ``items`` has fewer than ``size`` elements it is yielded once, as-is,
    rather than producing nothing.
    """
    total = len(items)
    if total < size:
        yield items
        return
    # One window per valid start offset, left to right.
    yield from (items[start : start + size] for start in range(total - size + 1))
class DbChain:
def __init__(self, order: int, path: Path):
self.order = order
self.path = path
self.db = sqlite3.connect(self.path)
def commit(self):
self.db.commit()
def execute(self, *args, **kwargs):
cursor = self.db.cursor()
cursor.execute(*args, **kwargs)
return list(cursor.fetchall())
def get_user_id(self, channel: str, nick: str) -> int | None:
if result := self.execute(
"SELECT id FROM user WHERE channel = ? AND nick = ?", (channel, nick)
):
return result[0][0]
else:
return None
def ensure_user(self, channel: str, nick: str):
if self.get_user_id(channel, nick):
return
self.execute("INSERT INTO user (channel, nick) VALUES (?, ?)", (channel, nick))
def ensure_key(self, channel: str, nick: str, key: str, next: str):
assert next is not None
self.ensure_user(channel, nick)
if next in self.get(channel, nick, key):
return
self.execute(
"""
INSERT INTO chain (user, value, weight, next)
VALUES (
(SELECT id FROM user WHERE channel = ? AND nick = ?),
?, 0, ?
)
""",
(channel, nick, key, next),
)
def add(self, channel: str, nick: str, text: str, commit=True):
parts: List[Any] = text.strip().split()
if not parts:
return
for fragment in windows(parts + [""], self.order + 1):
head = fragment[0:-1]
tail = fragment[-1]
key = " ".join(head)
self.update_chain(channel, nick, key, tail)
if commit:
self.commit()
def update_chain(
self,
channel: str,
nick: str,
key: str,
next: str,
weight: int = 1,
):
self.ensure_key(channel, nick, key, next)
# Get if the key exists
self.execute(
"""
UPDATE chain
SET weight = weight + :weight
WHERE user = (SELECT id FROM user WHERE channel = :channel AND nick = :nick)
AND value = :key
AND next = :next
""",
{
"channel": channel,
"nick": nick,
"key": key,
"next": next,
"weight": weight,
},
)
def get(self, channel: str, nick: str, key: str) -> dict[str, int]:
cursor = self.db.cursor()
cursor.execute(
"""
SELECT next, weight
FROM chain
WHERE
user = (SELECT id FROM user WHERE channel = ? AND nick = ?)
AND value = ?
""",
(channel, nick, key),
)
return {next: weight for next, weight in cursor.fetchall()}
def generate(self, channel: str, nick: str) -> str | None:
user_id = self.get_user_id(channel, nick)
if not user_id:
return None
words: List[str] = []
cursor = self.db.cursor()
cursor.execute(
"""
SELECT value
FROM chain
WHERE user = ?
ORDER BY RANDOM()
LIMIT 1
""",
(user_id,),
)
node = cursor.fetchone()[0].split(" ")
words += node
next: str = self.choose_next(channel, nick, " ".join(node))
while next:
words += [next]
node = [*node[1:], next]
next = self.choose_next(channel, nick, " ".join(node))
return " ".join(words)
def choose_next(self, channel: str, nick: str, head: str) -> str:
choices = self.get(channel, nick, head)
words = list(choices.keys())
weights = list(choices.values())
if not words:
return ""
return random.choices(words, weights)[0]
@dataclasses.dataclass
class Chain:
def __init__(self, order: int, chance: float, path: Path):
self.order = order
self.__reply_chance = chance
self.path = path
self.__cache = chain_default()
self.__last_access = 0.0
self.__dirty = False
def __touch(self):
self.__last_access = asyncio.get_running_loop().time()
@property
def reply_chance(self) -> float:
self.__load()
return self.__reply_chance
@reply_chance.setter
def reply_chance(self, val: float):
if not (isinstance(val, float) or isinstance(val, int)):
return NotImplemented
self.__load()
self.__reply_chance = val
self.__dirty = True
@property
def last_access(self) -> float:
return self.__last_access
def add(self, text: str):
parts: List[Any] = text.strip().split()
if not parts:
return
self.__touch()
self.__load()
self.__dirty = True
for fragment in windows(parts + [None], self.order + 1):
head = fragment[0:-1]
tail = fragment[-1]
self.__cache[" ".join(head)][tail] += 1
def get(self, key: str) -> dict[str | None, int]:
if self.__cache:
if key in self.__cache:
self.__touch()
return self.__cache[key]
else:
raise KeyError(key)
else:
# Load cache, then return key
self.__load()
self.__touch()
return self.__cache[key]
def set(self, key: str, value: Mapping[str | None, int]):
self.__touch()
if not self.__cache:
# Attempt the cache before writing to it
self.__load()
self.__cache[key] = defaultdict(int, value)
self.__dirty = True
def __load(self):
self.__touch()
if self.__cache:
return
if not self.path.exists():
return
with open(self.path) as fp:
log.info("Loading markov chain from %s", self.path)
obj = json.load(fp)
# Load the save object
self.__reply_chance = obj["reply_chance"]
self.__cache = defaultdict(
chain_inner_default,
{
key: defaultdict(
int,
{
(None if not word else word): weight
for word, weight in value.items()
},
)
for key, value in obj["chain"].items()
},
)
self.__dirty = False
def save(self):
if not self.__cache:
return
if self.__dirty:
log.info("Saving markov chain to %s", self.path)
self.path.parent.mkdir(parents=True, exist_ok=True)
# Build the save object
obj = {
"reply_chance": self.__reply_chance,
"chain": {
key: {
("" if word is None else word): weight
for word, weight in value.items()
}
for key, value in self.__cache.items()
},
}
with open(self.path, "w") as fp:
json.dump(obj, fp)
self.__dirty = False
else:
log.info("Chain %s is not dirty, not saving", self.path)
def clear_cache(self):
self.__cache.clear()
self.__dirty = False
def __bool__(self) -> bool:
return self.path.exists() or bool(self.__cache)
def generate(self) -> str:
self.__load()
if not self.__cache:
return ""
words: List[str] = []
node = random.choice(list(self.__cache.keys())).split(" ")
words += node
next: str | None = self.choose_next(" ".join(node))
while next:
words += [next]
node = [*node[1:], next]
next = self.choose_next(" ".join(node))
return " ".join(words)
def choose_next(self, head: str) -> str | None:
self.__load()
choices = self.__cache[head]
words = list(choices.keys())
weights = list(choices.values())
if not words:
return None
return random.choices(words, weights)[0]
class Markov(Plugin):
def __init__(self, *args, **kwargs):
super(Markov, self).__init__(*args, **kwargs)
self.order = int(self.plugin_config.get("order", 1))
self.data_path = Path(self.plugin_config.get("data_path", "data/markov"))
self.save_every = int(self.plugin_config.get("save_every", 1800))
self.reply_chance = float(self.plugin_config.get("reply_chance", 0.01))
self.__chains = {}
self.__save_loop_task = None
self.__saving = asyncio.Lock()
async def on_load(self):
loop = asyncio.get_running_loop()
self.__save_loop_task = loop.create_task(self.__save_loop())
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
line = line.strip()
parts = line.split()
if not parts:
return
elif parts[0] == "!markov":
self.handle_command(conn, channel, who, parts)
elif line[0] != "!":
# ignore other commands
self.add(channel, who.nick, line)
# also, maybe generate a sentence
chosen = random.random()
chain = self.get_chain(channel, who.nick)
if chosen <= chain.reply_chance:
message = chain.generate()
if message:
self.send_to(conn, channel, f"{who.nick}: {message}")
def get_chain(self, channel: str, who: str) -> Chain:
if channel not in self.__chains:
self.__chains[channel] = {}
if who not in self.__chains[channel]:
path = self.data_path / channel / who
self.__chains[channel][who] = Chain(self.order, self.reply_chance, path)
return self.__chains[channel][who]
def add(self, channel: str, who: str, line: str):
if who == self.server_config.nick:
return
chain = self.get_chain(channel, who)
chain.add(line)
allchain = self.get_chain(channel, ALLCHAIN)
allchain.add(line)
def handle_command(
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
):
# handle markov commands
match parts[1:]:
case ["force"]:
chain = self.get_chain(channel, who.nick)
message = chain.generate()
if message:
self.send_to(conn, channel, f"{who.nick}: {message}")
case ["force", nick] | ["emulate", nick]:
chain = self.get_chain(channel, nick)
if not chain:
return
message = chain.generate()
if message:
self.send_to(conn, channel, f"{who.nick}: {message}")
case ["all"]:
chain = self.get_chain(channel, ALLCHAIN)
message = chain.generate()
if message:
self.send_to(conn, channel, f"{who.nick}: {message}")
case ["chance"]:
chain = self.get_chain(channel, who.nick)
self.send_to(
conn,
channel,
f"{who.nick}: current reply chance is {chain.reply_chance}",
)
case ["chance", chance]:
chain = self.get_chain(channel, who.nick)
try:
reply_chance = float(chance)
except ValueError:
log.error("Couldn't parse %r as a float", chance)
return
if not math.isnan(reply_chance):
chain.reply_chance = min(
max(float(reply_chance), 0.0), self.reply_chance
)
self.send_to(
conn,
channel,
f"{who.nick}: reply chance set to {chain.reply_chance}",
)
case _:
# command not recognized
pass
async def __save_loop(self):
while True:
log.debug("Saving markov chains in %s seconds", self.save_every)
await asyncio.sleep(self.save_every)
retain_after = asyncio.get_running_loop().time() - self.save_every
await self.save(retain_after=retain_after)
async def save(self, retain_after: float | None = None):
async with self.__saving:
from concurrent.futures import ProcessPoolExecutor
log.info("Saving markov chains")
coros = []
loop = asyncio.get_running_loop()
# ProcessPoolExecutor is an explicit decision I've made to use,
# because it allows us to save in a different process, with
# different memory, and simultaneously clear it if it needs to be
# cleared.
with ProcessPoolExecutor() as pool:
for chains in self.__chains.values():
for chain in chains.values():
# Start the save in a new process, in a new task.
log.debug("Starting process to save %s", chain.path)
coro = loop.run_in_executor(pool, chain.save)
coros += [coro]
# Prune
retain = True
if retain_after is not None:
retain = chain.last_access > retain_after
if not retain:
log.info("Pruning markov chain %s from memory", chain.path)
chain.clear_cache()
if coros:
await asyncio.gather(*coros)
log.info("Done")
async def on_unload(self, conn: IrcProtocol):
self.__save_loop_task.cancel()
await self.save()
# Expose the plugin class under a fixed module-level name; presumably this is
# what the omnibot plugin loader looks up -- TODO confirm against the loader.
PLUGIN_TYPE = Markov