from collections import defaultdict
import logging
import random
import sqlite3
from typing import Any, Callable, DefaultDict, List, Optional, Sequence

log = logging.getLogger(__name__)


def chain_inner_default() -> defaultdict[Optional[str], int]:
    """Return an empty ``next word -> weight`` counter."""
    return defaultdict(int)


def chain_default() -> defaultdict[str, defaultdict[Optional[str], int]]:
    """Return an empty ``key -> {next word -> weight}`` mapping.

    Missing keys are populated on first access via :func:`chain_inner_default`.
    """
    return defaultdict(chain_inner_default)


def windows(items: Sequence[Any], size: int):
    """Yield every contiguous slice of ``items`` of length ``size``, in order.

    When ``items`` is shorter than ``size`` it is yielded once, whole, so that
    short inputs still produce a single (shorter) window.
    """
    if len(items) < size:
        yield items
        return
    for start in range(len(items) - size + 1):
        yield items[start : start + size]


class Chain:
    """Markov-chain text model stored in SQLite, scoped per (guild, member).

    Tables used (columns as referenced by the queries in this class):

    * ``user(id, guild_id, member_id, listen, reply_chance)`` -- rows identifying
      a (guild_id, member_id) pair; created lazily by :meth:`ensure_user`.
    * ``chain(user, value, weight, next)`` -- ``value`` is a key of up to
      ``order`` words joined by single spaces, ``next`` is a word observed
      after that key, and ``weight`` accumulates observations. An empty-string
      ``next`` marks the end of a message.
    """

    def __init__(self, order: int, reply_chance: float, db: sqlite3.Connection):
        """
        :param order: Number of preceding words that make up a chain key.
        :param reply_chance: Value written to ``reply_chance`` for newly created users.
        :param db: Open SQLite connection containing the ``user`` and ``chain`` tables.
        """
        self.order = order
        self.reply_chance = reply_chance
        self.db = db

    def commit(self):
        """Commit the current transaction on the underlying connection."""
        self.db.commit()

    def execute(self, *args, **kwargs):
        """Run one statement on a fresh cursor and return all result rows as a list.

        Does not commit; callers decide when to call :meth:`commit`.
        """
        cursor = self.db.cursor()
        cursor.execute(*args, **kwargs)
        return list(cursor.fetchall())

    def get_user_listen(self, guild_id: int, member_id: int) -> bool:
        """Return the user's stored ``listen`` flag, creating the user row if missing.

        Falls back to ``True`` if no row can be read. A freshly created row's
        value comes from the schema's column default (not set here).
        """
        self.ensure_user(guild_id, member_id)
        if result := self.execute(
            "SELECT listen FROM user WHERE guild_id = ? AND member_id = ?",
            (guild_id, member_id),
        ):
            return bool(result[0][0])
        return True

    def set_user_listen(self, guild_id: int, member_id: int, listen: bool):
        """Store the user's ``listen`` flag (creating the user row if needed) and commit."""
        self.ensure_user(guild_id, member_id)
        self.execute(
            "UPDATE user SET listen = ? WHERE guild_id = ? AND member_id = ?",
            (listen, guild_id, member_id),
        )
        self.commit()

    def get_reply_chance(self, guild_id: int, member_id: int) -> float:
        """Return the user's stored ``reply_chance``, creating the user row if missing.

        New users are created with ``self.reply_chance``; ``0.0`` is returned
        only if no row can be read.
        """
        self.ensure_user(guild_id, member_id)
        if result := self.execute(
            "SELECT reply_chance FROM user WHERE guild_id = ? AND member_id = ?",
            (guild_id, member_id),
        ):
            return result[0][0]
        return 0.0

    def set_reply_chance(self, guild_id: int, member_id: int, chance: float):
        """Store the user's ``reply_chance`` (creating the user row if needed) and commit."""
        self.ensure_user(guild_id, member_id)
        self.execute(
            "UPDATE user SET reply_chance = ? WHERE guild_id = ? AND member_id = ?",
            (chance, guild_id, member_id),
        )
        self.commit()

    def get_user_id(self, guild_id: int, member_id: int) -> Optional[int]:
        """Return the ``user.id`` for (guild_id, member_id), or ``None`` if absent."""
        if result := self.execute(
            "SELECT id FROM user WHERE guild_id = ? AND member_id = ?",
            (guild_id, member_id),
        ):
            return result[0][0]
        return None

    def ensure_user(self, guild_id: int, member_id: int):
        """Insert a user row with ``reply_chance = self.reply_chance`` if none exists.

        Does not commit.
        """
        # Compare against None explicitly: a row id of 0 is falsy but valid.
        if self.get_user_id(guild_id, member_id) is not None:
            return
        self.execute(
            "INSERT INTO user (guild_id, member_id, reply_chance) VALUES (?, ?, ?)",
            (guild_id, member_id, self.reply_chance),
        )

    def ensure_key(self, guild_id: int, member_id: int, key: str, next: str):
        """Make sure a ``chain`` row exists for (user, key, next).

        Missing rows are inserted with weight 0 so that :meth:`update_chain`'s
        ``UPDATE`` always has a row to increment. Does not commit.
        """
        assert next is not None
        self.ensure_user(guild_id, member_id)
        # Targeted existence check instead of loading every successor of `key`.
        exists = self.execute(
            """
            SELECT 1 FROM chain
            WHERE user = (SELECT id FROM user WHERE guild_id = ? AND member_id = ?)
            AND value = ? AND next = ?
            LIMIT 1
            """,
            (guild_id, member_id, key, next),
        )
        if exists:
            return
        self.execute(
            """
            INSERT INTO chain (user, value, weight, next)
            VALUES (
                (SELECT id FROM user WHERE guild_id = ? AND member_id = ?),
                ?, 0, ?
            )
            """,
            (guild_id, member_id, key, next),
        )

    def add(self, guild_id: int, member_id: int, text: str, commit=True):
        """Learn the word transitions in ``text`` for the given user.

        :param commit: When true (default), commit after all transitions are recorded.
        """
        # str.split() with no argument already discards surrounding whitespace.
        parts: List[Any] = text.split()
        if not parts:
            return
        # Append "" so the final window records the end-of-message transition.
        for fragment in windows(parts + [""], self.order + 1):
            key = " ".join(fragment[:-1])
            self.update_chain(guild_id, member_id, key, fragment[-1])
        if commit:
            self.commit()

    def update_chain(
        self,
        guild_id: int,
        member_id: int,
        key: str,
        next: str,
        weight: int = 1,
    ):
        """Add ``weight`` to the (user, key, next) transition, creating it if needed.

        Does not commit.
        """
        # Guarantee the row exists so the UPDATE below always matches it.
        self.ensure_key(guild_id, member_id, key, next)
        self.execute(
            """
            UPDATE chain SET weight = weight + :weight
            WHERE user = (SELECT id FROM user WHERE guild_id = :guild_id AND member_id = :member_id)
            AND value = :key
            AND next = :next
            """,
            {
                "guild_id": guild_id,
                "member_id": member_id,
                "key": key,
                "next": next,
                "weight": weight,
            },
        )

    def get(self, guild_id: int, member_id: int, key: str) -> dict[str, int]:
        """Return ``{next word: weight}`` for ``key`` in the given user's chain."""
        rows = self.execute(
            """
            SELECT next, weight FROM chain
            WHERE user = (SELECT id FROM user WHERE guild_id = ? AND member_id = ?)
            AND value = ?
            """,
            (guild_id, member_id, key),
        )
        return {next_word: weight for next_word, weight in rows}

    @staticmethod
    def _weighted_pick(choices: dict[str, int]) -> str:
        """Pick a key of ``choices`` with probability proportional to its value.

        Entries with non-positive weight are ignored; returns "" if nothing
        remains (random.choices rejects an all-zero weight total).
        """
        candidates = [(word, w) for word, w in choices.items() if w > 0]
        if not candidates:
            return ""
        words, weights = zip(*candidates)
        return random.choices(words, weights)[0]

    @staticmethod
    def _walk(start: str, choose_next: Callable[[str], str]) -> str:
        """Extend the start key word by word until ``choose_next`` returns ""."""
        node = start.split(" ")
        words: List[str] = list(node)
        next_word = choose_next(" ".join(node))
        while next_word:
            words.append(next_word)
            # Slide the key window forward by one word.
            node = [*node[1:], next_word]
            next_word = choose_next(" ".join(node))
        return " ".join(words)

    def generate(self, guild_id: int, member_id: int) -> Optional[str]:
        """Generate text from a single user's chain.

        Starts at a random key from that user's rows. Returns ``None`` if the
        user does not exist or has no chain rows yet.
        """
        user_id = self.get_user_id(guild_id, member_id)
        if user_id is None:
            return None
        rows = self.execute(
            """
            SELECT value FROM chain
            WHERE user = ?
            ORDER BY RANDOM()
            LIMIT 1
            """,
            (user_id,),
        )
        if not rows:
            # Known user with nothing learned yet.
            return None
        return self._walk(
            rows[0][0], lambda head: self.choose_next(guild_id, member_id, head)
        )

    def choose_next(self, guild_id: int, member_id: int, head: str) -> str:
        """Pick the next word after ``head`` in the user's chain, weighted; "" if none."""
        return self._weighted_pick(self.get(guild_id, member_id, head))

    def generate_all(self, guild_id: int) -> Optional[str]:
        """Generate text from the pooled chains of every user in ``guild_id``.

        Returns ``None`` if the guild has no chain rows.
        """
        rows = self.execute(
            """
            SELECT value FROM chain
            WHERE user IN (SELECT user.id FROM user WHERE guild_id = ?)
            ORDER BY RANDOM()
            LIMIT 1
            """,
            (guild_id,),
        )
        if not rows:
            return None
        return self._walk(rows[0][0], lambda head: self.all_choose_next(guild_id, head))

    def all_choose_next(self, guild_id: int, head: str) -> str:
        """Pick the next word after ``head`` across all users in ``guild_id``.

        Weights for the same ``next`` word are summed over users before the
        weighted draw. Returns "" if there is no successor.
        """
        # Restrict to this guild and fetch *all* successors: selecting a single
        # random row would make the weights meaningless.
        rows = self.execute(
            """
            SELECT next, weight FROM chain
            WHERE value = ?
            AND user IN (SELECT user.id FROM user WHERE guild_id = ?)
            """,
            (head, guild_id),
        )
        choices: DefaultDict[str, int] = defaultdict(int)
        for next_word, weight in rows:
            choices[next_word] += weight
        return self._weighted_pick(choices)