Initial commit
The bot is currently working under 3.10, need to check out 3.9 next. Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
This commit is contained in:
244
discord_markov/markov.py
Normal file
244
discord_markov/markov.py
Normal file
@@ -0,0 +1,244 @@
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
import random
|
||||
import sqlite3
|
||||
from typing import Any, DefaultDict, List, Optional, Sequence
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def chain_inner_default() -> defaultdict[Optional[str], int]:
|
||||
return defaultdict(int)
|
||||
|
||||
|
||||
def chain_default() -> defaultdict[str, defaultdict[Optional[str], int]]:
|
||||
return defaultdict(chain_inner_default)
|
||||
|
||||
|
||||
def windows(items: Sequence[Any], size: int):
|
||||
if len(items) < size:
|
||||
yield items
|
||||
else:
|
||||
for i in range(0, len(items) - size + 1):
|
||||
yield items[i : i + size]
|
||||
|
||||
|
||||
class Chain:
|
||||
def __init__(self, order: int, reply_chance: float, db: sqlite3.Connection):
|
||||
self.order = order
|
||||
self.reply_chance = reply_chance
|
||||
self.db = db
|
||||
|
||||
def commit(self):
|
||||
self.db.commit()
|
||||
|
||||
def execute(self, *args, **kwargs):
|
||||
cursor = self.db.cursor()
|
||||
cursor.execute(*args, **kwargs)
|
||||
return list(cursor.fetchall())
|
||||
|
||||
def get_user_listen(self, guild_id: int, member_id: int) -> bool:
|
||||
self.ensure_user(guild_id, member_id)
|
||||
if result := self.execute(
|
||||
"SELECT listen FROM user WHERE guild_id = ? AND member_id = ?",
|
||||
(guild_id, member_id),
|
||||
):
|
||||
return bool(result[0][0])
|
||||
else:
|
||||
return True
|
||||
|
||||
def set_user_listen(self, guild_id: int, member_id: int, listen: bool):
|
||||
self.ensure_user(guild_id, member_id)
|
||||
self.execute(
|
||||
"UPDATE user SET listen = ? WHERE guild_id = ? AND member_id = ?",
|
||||
(listen, guild_id, member_id),
|
||||
)
|
||||
self.commit()
|
||||
|
||||
def get_reply_chance(self, guild_id: int, member_id: int) -> float:
|
||||
self.ensure_user(guild_id, member_id)
|
||||
if result := self.execute(
|
||||
"SELECT reply_chance FROM user WHERE guild_id = ? AND member_id = ?",
|
||||
(guild_id, member_id),
|
||||
):
|
||||
return result[0][0]
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def set_reply_chance(self, guild_id: int, member_id: int, chance: float):
|
||||
self.ensure_user(guild_id, member_id)
|
||||
self.execute(
|
||||
"UPDATE user SET reply_chance = ? WHERE guild_id = ? AND member_id = ?",
|
||||
(chance, guild_id, member_id),
|
||||
)
|
||||
self.commit()
|
||||
|
||||
def get_user_id(self, guild_id: int, member_id: int) -> Optional[int]:
|
||||
if result := self.execute(
|
||||
"SELECT id FROM user WHERE guild_id = ? AND member_id = ?",
|
||||
(guild_id, member_id),
|
||||
):
|
||||
return result[0][0]
|
||||
else:
|
||||
return None
|
||||
|
||||
def ensure_user(self, guild_id: int, member_id: int):
|
||||
if self.get_user_id(guild_id, member_id):
|
||||
return
|
||||
self.execute(
|
||||
"INSERT INTO user (guild_id, member_id, reply_chance) VALUES (?, ?, ?)",
|
||||
(guild_id, member_id, self.reply_chance),
|
||||
)
|
||||
|
||||
def ensure_key(self, guild_id: int, member_id: int, key: str, next: str):
|
||||
assert next is not None
|
||||
|
||||
self.ensure_user(guild_id, member_id)
|
||||
if next in self.get(guild_id, member_id, key):
|
||||
return
|
||||
self.execute(
|
||||
"""
|
||||
INSERT INTO chain (user, value, weight, next)
|
||||
VALUES (
|
||||
(SELECT id FROM user WHERE guild_id = ? AND member_id = ?),
|
||||
?, 0, ?
|
||||
)
|
||||
""",
|
||||
(guild_id, member_id, key, next),
|
||||
)
|
||||
|
||||
def add(self, guild_id: int, member_id: int, text: str, commit=True):
|
||||
parts: List[Any] = text.strip().split()
|
||||
if not parts:
|
||||
return
|
||||
for fragment in windows(parts + [""], self.order + 1):
|
||||
head = fragment[0:-1]
|
||||
tail = fragment[-1]
|
||||
key = " ".join(head)
|
||||
self.update_chain(guild_id, member_id, key, tail)
|
||||
if commit:
|
||||
self.commit()
|
||||
|
||||
def update_chain(
|
||||
self,
|
||||
guild_id: int,
|
||||
member_id: int,
|
||||
key: str,
|
||||
next: str,
|
||||
weight: int = 1,
|
||||
):
|
||||
self.ensure_key(guild_id, member_id, key, next)
|
||||
# Get if the key exists
|
||||
self.execute(
|
||||
"""
|
||||
UPDATE chain
|
||||
SET weight = weight + :weight
|
||||
WHERE user = (SELECT id FROM user WHERE guild_id = :guild_id AND member_id = :member_id)
|
||||
AND value = :key
|
||||
AND next = :next
|
||||
""",
|
||||
{
|
||||
"guild_id": guild_id,
|
||||
"member_id": member_id,
|
||||
"key": key,
|
||||
"next": next,
|
||||
"weight": weight,
|
||||
},
|
||||
)
|
||||
|
||||
def get(self, guild_id: int, member_id: int, key: str) -> dict[str, int]:
|
||||
cursor = self.db.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT next, weight
|
||||
FROM chain
|
||||
WHERE
|
||||
user = (SELECT id FROM user WHERE guild_id = ? AND member_id = ?)
|
||||
AND value = ?
|
||||
""",
|
||||
(guild_id, member_id, key),
|
||||
)
|
||||
return {next: weight for next, weight in cursor.fetchall()}
|
||||
|
||||
def generate(self, guild_id: int, member_id: int) -> Optional[str]:
|
||||
user_id = self.get_user_id(guild_id, member_id)
|
||||
if not user_id:
|
||||
return None
|
||||
words: List[str] = []
|
||||
|
||||
cursor = self.db.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT value
|
||||
FROM chain
|
||||
WHERE user = ?
|
||||
ORDER BY RANDOM()
|
||||
LIMIT 1
|
||||
""",
|
||||
(user_id,),
|
||||
)
|
||||
node = cursor.fetchone()[0].split(" ")
|
||||
words += node
|
||||
|
||||
next: str = self.choose_next(guild_id, member_id, " ".join(node))
|
||||
while next:
|
||||
words += [next]
|
||||
node = [*node[1:], next]
|
||||
next = self.choose_next(guild_id, member_id, " ".join(node))
|
||||
return " ".join(words)
|
||||
|
||||
def choose_next(self, guild_id: int, member_id: int, head: str) -> str:
|
||||
choices = self.get(guild_id, member_id, head)
|
||||
words = list(choices.keys())
|
||||
weights = list(choices.values())
|
||||
if not words:
|
||||
return ""
|
||||
return random.choices(words, weights)[0]
|
||||
|
||||
def generate_all(self, guild_id: int) -> Optional[str]:
|
||||
words: List[str] = []
|
||||
cursor = self.db.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT value
|
||||
FROM chain
|
||||
WHERE user IN (SELECT user.id FROM user WHERE guild_id = ?)
|
||||
ORDER BY RANDOM()
|
||||
LIMIT 1
|
||||
""",
|
||||
(guild_id,),
|
||||
)
|
||||
first = cursor.fetchone()
|
||||
if not first:
|
||||
return None
|
||||
node = first[0].split(" ")
|
||||
words += node
|
||||
next: str = self.all_choose_next(guild_id, " ".join(node))
|
||||
while next:
|
||||
words += [next]
|
||||
node = [*node[1:], next]
|
||||
next = self.all_choose_next(guild_id, " ".join(node))
|
||||
return " ".join(words)
|
||||
|
||||
def all_choose_next(self, guild_id: int, head: str) -> str:
|
||||
cursor = self.db.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT next, weight
|
||||
FROM chain
|
||||
WHERE value = ?
|
||||
ORDER BY RANDOM()
|
||||
LIMIT 1
|
||||
""",
|
||||
(head,),
|
||||
)
|
||||
choices: DefaultDict[str, int] = defaultdict(lambda: 0)
|
||||
# Collapse all choices by weight
|
||||
for next, weight in cursor.fetchall():
|
||||
choices[next] += weight
|
||||
if not choices:
|
||||
return ""
|
||||
words = list(choices.keys())
|
||||
weights = list(choices.values())
|
||||
return random.choices(words, weights)[0]
|
||||
Reference in New Issue
Block a user