2022-05-24 19:30:42 -07:00
|
|
|
from collections import defaultdict
|
|
|
|
|
import logging
|
2022-05-30 14:29:37 -07:00
|
|
|
import math
|
2022-05-24 19:30:42 -07:00
|
|
|
from pathlib import Path
|
|
|
|
|
import random
|
2022-06-23 12:00:33 -07:00
|
|
|
import sqlite3
|
2022-06-23 14:29:02 -07:00
|
|
|
from typing import Any, DefaultDict, List, Sequence
|
2022-05-24 19:30:42 -07:00
|
|
|
|
|
|
|
|
from asyncirc.protocol import IrcProtocol
|
|
|
|
|
from irclib.parser import Prefix
|
|
|
|
|
|
|
|
|
|
from omnibot.plugin import Plugin
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def chain_inner_default() -> defaultdict[str | None, int]:
|
|
|
|
|
return defaultdict(int)
|
|
|
|
|
|
|
|
|
|
|
2022-05-27 18:41:17 -07:00
|
|
|
def chain_default() -> defaultdict[str, defaultdict[str | None, int]]:
    """Build an empty two-level mapping.

    Missing outer keys are filled in by ``chain_inner_default``, so every
    inner mapping is itself a counter that starts missing keys at ``0``.
    """
    table: defaultdict[str, defaultdict[str | None, int]] = defaultdict(
        chain_inner_default
    )
    return table
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def windows(items: Sequence[Any], size: int):
    """Yield every contiguous slice of ``size`` elements from ``items``, in order.

    When ``items`` holds fewer than ``size`` elements it is yielded once,
    unsliced, so the caller always receives at least one window.
    """
    count = len(items)
    if count < size:
        yield items
        return

    start = 0
    last_start = count - size
    while start <= last_start:
        yield items[start : start + size]
        start += 1
|
|
|
|
|
|
|
|
|
|
|
2022-06-23 12:29:19 -07:00
|
|
|
class Chain:
|
|
|
|
|
def __init__(self, order: int, reply_chance: float, path: Path, sql_path: Path):
|
2022-06-23 12:00:33 -07:00
|
|
|
self.order = order
|
2022-06-23 12:29:19 -07:00
|
|
|
self.reply_chance = reply_chance
|
2022-06-23 12:00:33 -07:00
|
|
|
self.path = path
|
|
|
|
|
self.db = sqlite3.connect(self.path)
|
|
|
|
|
|
2022-06-23 12:29:19 -07:00
|
|
|
# Run the initial database creation script
|
|
|
|
|
cursor = self.db.cursor()
|
|
|
|
|
with open(sql_path) as fp:
|
|
|
|
|
cursor.executescript(fp.read())
|
|
|
|
|
cursor.close()
|
|
|
|
|
self.db.commit()
|
|
|
|
|
|
2022-06-23 12:00:33 -07:00
|
|
|
def commit(self):
|
|
|
|
|
self.db.commit()
|
|
|
|
|
|
|
|
|
|
def execute(self, *args, **kwargs):
|
|
|
|
|
cursor = self.db.cursor()
|
|
|
|
|
cursor.execute(*args, **kwargs)
|
|
|
|
|
return list(cursor.fetchall())
|
|
|
|
|
|
|
|
|
|
def get_user_id(self, channel: str, nick: str) -> int | None:
|
|
|
|
|
if result := self.execute(
|
|
|
|
|
"SELECT id FROM user WHERE channel = ? AND nick = ?", (channel, nick)
|
|
|
|
|
):
|
|
|
|
|
return result[0][0]
|
|
|
|
|
else:
|
|
|
|
|
return None
|
|
|
|
|
|
2022-06-23 12:29:19 -07:00
|
|
|
def get_reply_chance(self, channel: str, nick: str) -> float:
|
|
|
|
|
if result := self.execute(
|
|
|
|
|
"SELECT reply_chance FROM user WHERE channel = ? AND nick = ?",
|
|
|
|
|
(channel, nick),
|
|
|
|
|
):
|
|
|
|
|
return result[0][0]
|
|
|
|
|
else:
|
|
|
|
|
return 0.0
|
|
|
|
|
|
|
|
|
|
def set_reply_chance(self, channel: str, nick: str, chance: float):
|
|
|
|
|
self.ensure_user(channel, nick)
|
|
|
|
|
self.execute(
|
|
|
|
|
"UPDATE user SET reply_chance = ? WHERE channel = ? AND nick = ?",
|
|
|
|
|
(chance, channel, nick),
|
|
|
|
|
)
|
|
|
|
|
|
2022-06-23 12:00:33 -07:00
|
|
|
def ensure_user(self, channel: str, nick: str):
|
|
|
|
|
if self.get_user_id(channel, nick):
|
|
|
|
|
return
|
2022-06-23 12:29:19 -07:00
|
|
|
self.execute(
|
|
|
|
|
"INSERT INTO user (channel, nick, reply_chance) VALUES (?, ?, ?)",
|
|
|
|
|
(channel, nick, self.reply_chance),
|
|
|
|
|
)
|
2022-06-23 12:00:33 -07:00
|
|
|
|
|
|
|
|
def ensure_key(self, channel: str, nick: str, key: str, next: str):
|
|
|
|
|
assert next is not None
|
|
|
|
|
|
|
|
|
|
self.ensure_user(channel, nick)
|
|
|
|
|
if next in self.get(channel, nick, key):
|
|
|
|
|
return
|
|
|
|
|
self.execute(
|
|
|
|
|
"""
|
|
|
|
|
INSERT INTO chain (user, value, weight, next)
|
|
|
|
|
VALUES (
|
|
|
|
|
(SELECT id FROM user WHERE channel = ? AND nick = ?),
|
|
|
|
|
?, 0, ?
|
|
|
|
|
)
|
|
|
|
|
""",
|
|
|
|
|
(channel, nick, key, next),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def add(self, channel: str, nick: str, text: str, commit=True):
|
|
|
|
|
parts: List[Any] = text.strip().split()
|
|
|
|
|
if not parts:
|
|
|
|
|
return
|
|
|
|
|
for fragment in windows(parts + [""], self.order + 1):
|
|
|
|
|
head = fragment[0:-1]
|
|
|
|
|
tail = fragment[-1]
|
|
|
|
|
key = " ".join(head)
|
|
|
|
|
self.update_chain(channel, nick, key, tail)
|
|
|
|
|
if commit:
|
|
|
|
|
self.commit()
|
|
|
|
|
|
|
|
|
|
def update_chain(
|
|
|
|
|
self,
|
|
|
|
|
channel: str,
|
|
|
|
|
nick: str,
|
|
|
|
|
key: str,
|
|
|
|
|
next: str,
|
|
|
|
|
weight: int = 1,
|
|
|
|
|
):
|
|
|
|
|
self.ensure_key(channel, nick, key, next)
|
|
|
|
|
# Get if the key exists
|
|
|
|
|
self.execute(
|
|
|
|
|
"""
|
|
|
|
|
UPDATE chain
|
|
|
|
|
SET weight = weight + :weight
|
|
|
|
|
WHERE user = (SELECT id FROM user WHERE channel = :channel AND nick = :nick)
|
|
|
|
|
AND value = :key
|
|
|
|
|
AND next = :next
|
|
|
|
|
""",
|
|
|
|
|
{
|
|
|
|
|
"channel": channel,
|
|
|
|
|
"nick": nick,
|
|
|
|
|
"key": key,
|
|
|
|
|
"next": next,
|
|
|
|
|
"weight": weight,
|
|
|
|
|
},
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def get(self, channel: str, nick: str, key: str) -> dict[str, int]:
|
|
|
|
|
cursor = self.db.cursor()
|
|
|
|
|
cursor.execute(
|
|
|
|
|
"""
|
|
|
|
|
SELECT next, weight
|
|
|
|
|
FROM chain
|
|
|
|
|
WHERE
|
|
|
|
|
user = (SELECT id FROM user WHERE channel = ? AND nick = ?)
|
|
|
|
|
AND value = ?
|
|
|
|
|
""",
|
|
|
|
|
(channel, nick, key),
|
|
|
|
|
)
|
|
|
|
|
return {next: weight for next, weight in cursor.fetchall()}
|
|
|
|
|
|
|
|
|
|
def generate(self, channel: str, nick: str) -> str | None:
|
|
|
|
|
user_id = self.get_user_id(channel, nick)
|
|
|
|
|
if not user_id:
|
|
|
|
|
return None
|
|
|
|
|
words: List[str] = []
|
|
|
|
|
|
|
|
|
|
cursor = self.db.cursor()
|
|
|
|
|
cursor.execute(
|
|
|
|
|
"""
|
|
|
|
|
SELECT value
|
|
|
|
|
FROM chain
|
|
|
|
|
WHERE user = ?
|
|
|
|
|
ORDER BY RANDOM()
|
|
|
|
|
LIMIT 1
|
|
|
|
|
""",
|
|
|
|
|
(user_id,),
|
|
|
|
|
)
|
|
|
|
|
node = cursor.fetchone()[0].split(" ")
|
|
|
|
|
words += node
|
|
|
|
|
|
|
|
|
|
next: str = self.choose_next(channel, nick, " ".join(node))
|
|
|
|
|
while next:
|
|
|
|
|
words += [next]
|
|
|
|
|
node = [*node[1:], next]
|
|
|
|
|
next = self.choose_next(channel, nick, " ".join(node))
|
|
|
|
|
return " ".join(words)
|
|
|
|
|
|
|
|
|
|
def choose_next(self, channel: str, nick: str, head: str) -> str:
|
|
|
|
|
choices = self.get(channel, nick, head)
|
|
|
|
|
words = list(choices.keys())
|
|
|
|
|
weights = list(choices.values())
|
|
|
|
|
if not words:
|
|
|
|
|
return ""
|
|
|
|
|
return random.choices(words, weights)[0]
|
|
|
|
|
|
2022-06-23 14:29:02 -07:00
|
|
|
def generate_all(self) -> str | None:
|
|
|
|
|
words: List[str] = []
|
|
|
|
|
cursor = self.db.cursor()
|
|
|
|
|
cursor.execute(
|
|
|
|
|
"""
|
|
|
|
|
SELECT value
|
|
|
|
|
FROM chain
|
|
|
|
|
ORDER BY RANDOM()
|
|
|
|
|
LIMIT 1
|
|
|
|
|
""",
|
|
|
|
|
)
|
|
|
|
|
node = cursor.fetchone()[0].split(" ")
|
|
|
|
|
words += node
|
|
|
|
|
next: str = self.all_choose_next(" ".join(node))
|
|
|
|
|
while next:
|
|
|
|
|
words += [next]
|
|
|
|
|
node = [*node[1:], next]
|
|
|
|
|
next = self.all_choose_next(" ".join(node))
|
|
|
|
|
return " ".join(words)
|
|
|
|
|
|
|
|
|
|
def all_choose_next(self, head: str) -> str:
|
|
|
|
|
cursor = self.db.cursor()
|
|
|
|
|
cursor.execute(
|
|
|
|
|
"""
|
|
|
|
|
SELECT next, weight
|
|
|
|
|
FROM chain
|
|
|
|
|
WHERE value = ?
|
|
|
|
|
ORDER BY RANDOM()
|
|
|
|
|
LIMIT 1
|
|
|
|
|
""",
|
|
|
|
|
(head,),
|
|
|
|
|
)
|
|
|
|
|
choices: DefaultDict[str, int] = defaultdict(lambda: 0)
|
|
|
|
|
# Collapse all choices by weight
|
|
|
|
|
for next, weight in cursor.fetchall():
|
|
|
|
|
choices[next] += weight
|
|
|
|
|
if not choices:
|
|
|
|
|
return ""
|
|
|
|
|
words = list(choices.keys())
|
|
|
|
|
weights = list(choices.values())
|
|
|
|
|
return random.choices(words, weights)[0]
|
|
|
|
|
|
2022-06-23 12:00:33 -07:00
|
|
|
|
2022-05-24 19:30:42 -07:00
|
|
|
class Markov(Plugin):
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
super(Markov, self).__init__(*args, **kwargs)
|
|
|
|
|
|
2022-05-29 21:37:34 -07:00
|
|
|
self.order = int(self.plugin_config.get("order", 1))
|
2022-06-23 12:29:19 -07:00
|
|
|
self.data_path = Path(
|
|
|
|
|
self.plugin_config.get("data_path", "data/markov/markov.db")
|
|
|
|
|
)
|
|
|
|
|
self.sql_path = Path(self.plugin_config.get("sql_path", "data/markov/db.sql"))
|
2022-05-30 14:29:37 -07:00
|
|
|
self.reply_chance = float(self.plugin_config.get("reply_chance", 0.01))
|
2022-06-23 12:29:19 -07:00
|
|
|
self.chain = Chain(self.order, self.reply_chance, self.data_path, self.sql_path)
|
2022-05-24 19:30:42 -07:00
|
|
|
|
|
|
|
|
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
|
|
|
|
|
line = line.strip()
|
|
|
|
|
parts = line.split()
|
2022-06-23 12:29:19 -07:00
|
|
|
if not parts or who.nick == self.nick:
|
2022-05-24 19:30:42 -07:00
|
|
|
return
|
|
|
|
|
elif parts[0] == "!markov":
|
|
|
|
|
self.handle_command(conn, channel, who, parts)
|
|
|
|
|
elif line[0] != "!":
|
|
|
|
|
# ignore other commands
|
2022-05-25 19:18:37 -07:00
|
|
|
self.add(channel, who.nick, line)
|
2022-05-30 14:29:37 -07:00
|
|
|
# also, maybe generate a sentence
|
|
|
|
|
chosen = random.random()
|
2022-06-23 12:29:19 -07:00
|
|
|
reply_chance = self.chain.get_reply_chance(channel, who.nick)
|
|
|
|
|
if chosen <= reply_chance:
|
|
|
|
|
message = self.chain.generate(channel, who.nick)
|
2022-06-04 15:55:48 -07:00
|
|
|
if message:
|
|
|
|
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
2022-05-25 19:18:37 -07:00
|
|
|
|
2022-06-23 12:29:19 -07:00
|
|
|
def add(self, channel: str, who: str, line: str, commit=True):
|
2022-05-30 14:29:37 -07:00
|
|
|
if who == self.server_config.nick:
|
2022-05-27 17:56:52 -07:00
|
|
|
return
|
2022-06-23 12:29:19 -07:00
|
|
|
self.chain.add(channel, who, line, commit)
|
2022-05-24 19:30:42 -07:00
|
|
|
|
|
|
|
|
def handle_command(
|
|
|
|
|
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
|
|
|
|
|
):
|
|
|
|
|
# handle markov commands
|
|
|
|
|
match parts[1:]:
|
|
|
|
|
case ["force"]:
|
2022-06-23 12:29:19 -07:00
|
|
|
message = self.chain.generate(channel, who.nick)
|
2022-05-24 19:30:42 -07:00
|
|
|
if message:
|
|
|
|
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
|
|
|
|
case ["force", nick] | ["emulate", nick]:
|
2022-06-23 12:29:19 -07:00
|
|
|
if not self.chain.get_user_id(channel, nick):
|
2022-05-24 19:30:42 -07:00
|
|
|
return
|
2022-06-23 12:29:19 -07:00
|
|
|
message = self.chain.generate(channel, nick)
|
2022-05-24 19:30:42 -07:00
|
|
|
if message:
|
|
|
|
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
2022-05-27 17:56:52 -07:00
|
|
|
case ["all"]:
|
2022-06-23 14:29:02 -07:00
|
|
|
message = self.chain.generate_all()
|
2022-05-27 17:56:52 -07:00
|
|
|
if message:
|
|
|
|
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
2022-06-04 16:16:52 -07:00
|
|
|
case ["chance"]:
|
2022-06-23 12:29:19 -07:00
|
|
|
chance = self.chain.get_reply_chance(channel, who.nick)
|
2022-06-04 16:16:52 -07:00
|
|
|
self.send_to(
|
|
|
|
|
conn,
|
|
|
|
|
channel,
|
2022-06-23 12:29:19 -07:00
|
|
|
f"{who.nick}: current reply chance is {chance}",
|
2022-06-04 16:16:52 -07:00
|
|
|
)
|
2022-05-30 14:29:37 -07:00
|
|
|
case ["chance", chance]:
|
2022-05-30 14:36:28 -07:00
|
|
|
try:
|
|
|
|
|
reply_chance = float(chance)
|
|
|
|
|
except ValueError:
|
|
|
|
|
log.error("Couldn't parse %r as a float", chance)
|
|
|
|
|
return
|
2022-05-30 14:29:37 -07:00
|
|
|
if not math.isnan(reply_chance):
|
2022-06-23 12:29:19 -07:00
|
|
|
reply_chance = min(max(float(reply_chance), 0.0), self.reply_chance)
|
|
|
|
|
self.chain.set_reply_chance(channel, who.nick, reply_chance)
|
2022-05-30 14:36:28 -07:00
|
|
|
self.send_to(
|
|
|
|
|
conn,
|
|
|
|
|
channel,
|
2022-06-23 12:29:19 -07:00
|
|
|
f"{who.nick}: reply chance set to {self.chain.get_reply_chance(channel, who.nick)}",
|
2022-05-30 14:36:28 -07:00
|
|
|
)
|
2022-05-24 19:30:42 -07:00
|
|
|
case _:
|
|
|
|
|
# command not recognized
|
|
|
|
|
pass
|
|
|
|
|
|
2022-06-23 12:29:19 -07:00
|
|
|
async def save(self):
|
|
|
|
|
self.chain.commit()
|
2022-05-24 19:30:42 -07:00
|
|
|
|
|
|
|
|
async def on_unload(self, conn: IrcProtocol):
|
2022-05-26 20:59:06 -07:00
|
|
|
await self.save()
|
2022-05-24 20:40:53 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Exposes the plugin class under a fixed module-level name.
# NOTE(review): presumably looked up by the omnibot plugin loader -- confirm.
PLUGIN_TYPE = Markov
|