Initial commit
The bot currently works under Python 3.10; compatibility with Python 3.9 still needs to be checked. Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
This commit is contained in:
0
discord_markov/__init__.py
Normal file
0
discord_markov/__init__.py
Normal file
89
discord_markov/__main__.py
Normal file
89
discord_markov/__main__.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sqlite3
|
||||
import sys
|
||||
|
||||
import discord
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from .client import Client
|
||||
from .markov import Chain
|
||||
|
||||
################################################################################
# Environment
################################################################################
# Merge settings from a .env file (if present) into the process environment
# before any of the values below are read.
load_dotenv()

# More common values
token = os.getenv("TOKEN")
log = os.getenv("LOG", "stderr")
loglevel = os.getenv("LOGLEVEL", "INFO")
db_path = Path(os.getenv("DB_PATH", "markov.db"))

# Less common values
sql_path = Path(os.getenv("SQL_PATH", "db.sql"))

if not token:
    print("ERROR: TOKEN environment variable not set. Exiting.")
    sys.exit(1)

# Path.exists is a method: without the call the bound method object is always
# truthy and a missing schema file would never be reported here.
if not sql_path.exists():
    print("ERROR: could not find database SQL file")
    sys.exit(1)

if loglevel.upper() not in ("DEBUG", "INFO", "WARNING", "WARN", "ERROR", "CRITICAL"):
    print(f"WARNING: unknown loglevel {loglevel} - defaulting to INFO")
    loglevel = "INFO"

# The check above is case-insensitive, so store the canonical upper-case name.
# A lower-case value such as "debug" would otherwise resolve to the function
# logging.debug instead of the level constant logging.DEBUG.
loglevel = loglevel.upper()
|
||||
|
||||
|
||||
################################################################################
# Logging setup
################################################################################
# 5 megabytes per log
MAX_LOG_SIZE = 5 * (2**20)
# Keep up to 5 logs
MAX_LOG_BACKUP = 5

handler: logging.Handler

# LOG selects the destination: "stderr" (default), "stdout", or any other value,
# which is used as the path of a size-rotated log file.
log_target = (log or "stderr").lower()
if log_target == "stderr":
    handler = logging.StreamHandler(sys.stderr)
elif log_target == "stdout":
    handler = logging.StreamHandler(sys.stdout)
else:
    handler = logging.handlers.RotatingFileHandler(
        log, maxBytes=MAX_LOG_SIZE, backupCount=MAX_LOG_BACKUP
    )

# The level name is upper-cased before the lookup: getattr(logging, "debug")
# would return the logging.debug function rather than the integer level.
discord.utils.setup_logging(
    handler=handler, level=getattr(logging, loglevel.upper())
)
|
||||
|
||||
# Single SQLite connection shared by the markov chain and the Discord client.
db = sqlite3.connect(db_path)

################################################################################
# Create the database
################################################################################
# Run the whole schema script against the connection, then persist the result.
db.executescript(sql_path.read_text())
db.commit()

################################################################################
# Client setup and bot run
################################################################################
# Start from the default intents and additionally request message content and
# member information.
intents = discord.Intents.default()
intents.members = True
intents.message_content = True

################################################################################
# Set up markov chain
################################################################################
chain = Chain(db=db, order=2, reply_chance=0.01)

# Blocks until the bot shuts down.
client = Client(intents=intents, chain=chain, db=db)
client.run(token)
|
||||
250
discord_markov/client.py
Normal file
250
discord_markov/client.py
Normal file
@@ -0,0 +1,250 @@
|
||||
from asyncio import Event
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
import sqlite3
|
||||
|
||||
import discord
|
||||
from discord import Guild, TextChannel
|
||||
from discord.abc import GuildChannel
|
||||
|
||||
from . import markov
|
||||
from . import util
|
||||
|
||||
|
||||
# Logger for this module, named after the module itself.
log = logging.getLogger(__name__)

# POSIX timestamp of 2015-01-01T00:00:00Z. This is the minimum value for a
# discord snowflake timestamp, used as the "never synchronized" starting point.
JAN_1_2015 = 1_420_070_400.0

# Generated replies are cut down to this many characters before being sent.
MAX_MESSAGE_LEN = 2_000
|
||||
|
||||
|
||||
class Client(discord.Client):
    """Discord client that trains markov chains from guild text channels.

    On startup (and when joining a guild or seeing a new text channel) the
    client reads the messages sent since the last tracked message of each
    channel and feeds them to the chain. Afterwards it trains on live
    messages and answers the ``!markov`` commands.

    Args:
        db: SQLite connection holding the ``channel`` and ``user`` tables.
        chain: Markov chain that messages are added to and generated from.
        **kwargs: Forwarded unchanged to ``discord.Client``.
    """

    def __init__(self, *, db: sqlite3.Connection, chain: markov.Chain, **kwargs):
        super().__init__(**kwargs)
        self.chain = chain
        self.db = db
        # One event per guild ID. The event is cleared while the guild's
        # history is being synchronized and set once live training may start.
        self.__chains_ready: dict[int, Event] = {}

    def execute(self, *args, **kwargs):
        """Execute a database query and return all resulting rows as a list."""
        cursor = self.db.cursor()
        try:
            cursor.execute(*args, **kwargs)
            return cursor.fetchall()
        finally:
            # Release the cursor even when the statement fails.
            cursor.close()

    def track_channel(self, guild_id: int, channel_id: int):
        """Insert a row for the channel unless it is already tracked."""
        if self.is_channel_tracked(channel_id):
            return
        self.execute(
            "INSERT INTO channel (guild_id, channel_id) VALUES (?, ?)",
            (guild_id, channel_id),
        )
        self.db.commit()

    def is_channel_tracked(self, channel_id: int) -> bool:
        """Gets whether the database is tracking a specified channel."""
        result = self.execute(
            "SELECT COUNT(*) FROM channel WHERE channel_id = ?",
            (channel_id,),
        )
        return bool(result[0][0])

    def is_listening_to_user(self, guild_id: int, member_id: int) -> bool:
        """Return whether messages of this member should train the chain."""
        result = self.execute(
            "SELECT listen FROM user WHERE guild_id = ? AND member_id = ?",
            (guild_id, member_id),
        )
        if result:
            return bool(result[0][0])
        # By default we listen to all users
        return True

    def get_last_tracked_message(self, channel_id: int) -> datetime:
        """Return the time of the last message recorded for the channel.

        Falls back to ``JAN_1_2015`` when the channel has no row yet, or when
        the stored value is NULL (track_channel() does not supply
        ``last_message``; whether the schema gives it a default is not
        visible from this file).
        """
        result = self.execute(
            "SELECT last_message FROM channel WHERE channel_id = ?", (channel_id,)
        )
        if result and result[0][0] is not None:
            return util.utc_from_timestamp(result[0][0])
        return util.utc_from_timestamp(JAN_1_2015)

    def set_last_tracked_message(self, channel_id: int, last_message: datetime):
        """Store ``last_message`` (as a POSIX timestamp) for the channel."""
        self.execute(
            "UPDATE channel SET last_message = ? WHERE channel_id = ?",
            (last_message.timestamp(), channel_id),
        )
        self.db.commit()

    async def on_ready(self):
        """Synchronize every guild the client is a member of."""
        log.info("Synchronizing markov chains with guilds")
        for guild in self.guilds:
            await self.synchronize_guild(guild)
        log.info("Ready to start serving messages")

    async def on_guild_channel_create(self, channel: GuildChannel):
        """Start tracking a newly created text channel."""
        # only handle text channels
        if not isinstance(channel, TextChannel):
            return
        await self.synchronize_channel(channel)

    async def on_guild_join(self, guild: Guild):
        """Synchronize a guild the client has just joined."""
        await self.synchronize_guild(guild)

    async def synchronize_guild(self, guild: Guild):
        """Train the chain on the backlog of every text channel of ``guild``.

        The guild's ready event is cleared for the duration of the backlog
        import and set again afterwards, even if the import fails, so that a
        single error cannot disable live training for the guild permanently.
        """
        log.info("Synchronizing guild %s (id: %s)", guild.name, guild.id)

        ready = self.__chains_ready.setdefault(guild.id, Event())

        # guild ID is not ready
        ready.clear()
        try:
            for channel in guild.text_channels:
                await self.synchronize_channel(channel)
        finally:
            # guild ID is ready
            ready.set()

    async def synchronize_channel(self, channel: TextChannel):
        """Train the chain on messages sent since the channel was last tracked.

        Messages written by bots are ignored, and messages starting with
        ``!`` are skipped because they look like commands. Channels whose
        history the bot is not allowed to read are skipped with a warning.
        """
        guild = channel.guild

        # Make sure this channel is tracked
        self.track_channel(guild.id, channel.id)

        # In each channel, read all messages sent since the last tracked
        # message was recorded
        last_message = self.get_last_tracked_message(channel.id)
        log.debug(
            "Guild %s: channel %s: fetching messages and training model",
            guild.id,
            channel.id,
        )
        count = 0
        skipped = 0
        try:
            # limit=None: history() only returns 100 messages by default,
            # which would silently drop the rest of a longer backlog while
            # the bookmark below still moves past it.
            async for message in channel.history(limit=None, after=last_message):
                if message.author.bot:
                    continue
                if message.content.strip().startswith("!"):
                    skipped += 1
                    continue
                count += 1
                self.chain.add(
                    guild.id, message.author.id, message.content, commit=False
                )
        except discord.Forbidden:
            # The bot may not read this channel. Discard anything read so far
            # and leave the bookmark untouched so a later sync can retry.
            self.db.rollback()
            log.warning(
                "Guild %s: channel %s: missing permission to read history, skipping",
                guild.id,
                channel.id,
            )
            return
        self.db.commit()
        log.info(
            "Guild %s: channel %s: synchronized %s messages (skipped %s messages that looked like commands)",
            guild.id,
            channel.id,
            count,
            skipped,
        )

        if channel.last_message_id:
            # Update the last message timestamp in the tracking database, if
            # there's a last message available.
            # The creation time is derived from the snowflake ID itself, so no
            # extra request is needed and a deleted last message cannot make
            # the lookup fail the way fetch_message() would.
            self.set_last_tracked_message(
                channel.id, discord.utils.snowflake_time(channel.last_message_id)
            )

    async def on_message(self, message: discord.Message):
        """Handle ``!markov`` commands and train the chain on other messages."""
        if message.author.bot:
            return
        guild = message.guild

        # Ignore direct messages and guilds that were never synchronized.
        if not guild:
            return
        if guild.id not in self.__chains_ready:
            return

        text = message.content.strip()
        parts = text.split()

        # Do commands
        if parts and parts[0] == "!markov":
            await self._run_markov_command(message, guild, parts)
        elif text and text[0] != "!":
            if not self.is_listening_to_user(guild.id, message.author.id):
                return

            # The bot is not ready to train the model yet, or is probably
            # already reading this message as part of the backlog import.
            if not self.__chains_ready[guild.id].is_set():
                return

            # Train the model
            log.debug("%s: %s", message.author, message.content)
            self.chain.add(guild.id, message.author.id, message.content)

            # Update the channel's last_message time
            self.set_last_tracked_message(message.channel.id, message.created_at)

    async def _run_markov_command(
        self, message: discord.Message, guild: Guild, parts: list
    ):
        """Dispatch one ``!markov <subcommand>`` message; unknown ones are ignored."""
        argc = len(parts)
        if argc == 1:
            return

        command = parts[1]
        author_id = message.author.id

        if command == "help":
            help_messages = [
                "Markov is filmed in front of a live studio audience.",
                "Help! I'm having a heart attack!",
                "Where is that gun?",
                "I'm walkin' heah!",
                "Any resemblance to persons, living or fictional, is entirely coincidental.",
                "ge8hg809wga987haw4go9hag897hagndfnam3n342ui128",
                "alek die",
            ]
            # NOTE(review): the closing line below is sent to users verbatim;
            # consider replacing it before running the bot in a public server.
            lines = [
                random.choice(help_messages),
                "",
                "`!markov help` - idiot",
                "`!markov force` - force markov to say something",
                "`!markov force username` - force markov to say something, using someone else's words",
                "`!markov trigger` - synonym for `!markov force`",
                f"`!markov chance n` - set your reply chance to some number between 0.0 and {self.chain.reply_chance}",
                "`!markov listen on|off` - tell markov to start or stop listening to what you say",
                "`!markov all` - force markov to say something based off of everything said in the server",
                "",
                "Also, kill yourself!",
            ]
            await message.reply("\n".join(lines))
        elif command in ("force", "trigger"):
            # Use the chain of the first mentioned member, or the sender's
            # chain when nobody is mentioned.
            member_id = message.mentions[0].id if message.mentions else author_id
            line = self.chain.generate(guild.id, member_id)
            if line:
                await message.reply(line[:MAX_MESSAGE_LEN])
        elif command == "all":
            line = self.chain.generate_all(guild.id)
            if line:
                await message.reply(line[:MAX_MESSAGE_LEN])
        elif command == "listen":
            if argc == 2:
                # No argument: report the current setting.
                if self.chain.get_user_listen(guild.id, author_id):
                    await message.reply("Markov is listening to your messages")
                else:
                    await message.reply("Markov is not listening to your messages")
            elif parts[2].lower() == "on":
                self.chain.set_user_listen(guild.id, author_id, True)
                await message.reply(
                    "Markov is now listening to everything you say. Turn this off with `!markov listen off`"
                )
            elif parts[2].lower() == "off":
                self.chain.set_user_listen(guild.id, author_id, False)
                await message.reply(
                    "Markov is no longer listening to everything you say, but you can still use commands. To resume listening, use `!markov listen on`"
                )
        elif command == "chance":
            if argc > 2:
                try:
                    requested = float(parts[2])
                except ValueError:
                    requested = math.nan
                # Unparseable and NaN values are ignored; everything else is
                # clamped into [0.0, chain.reply_chance] before being stored.
                if not math.isnan(requested):
                    clamped = min(self.chain.reply_chance, max(requested, 0.0))
                    self.chain.set_reply_chance(guild.id, author_id, clamped)
            chance = self.chain.get_reply_chance(guild.id, author_id)
            await message.reply(f"Reply chance: {chance}")
|
||||
244
discord_markov/markov.py
Normal file
244
discord_markov/markov.py
Normal file
@@ -0,0 +1,244 @@
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
import random
|
||||
import sqlite3
|
||||
from typing import Any, DefaultDict, List, Optional, Sequence
|
||||
|
||||
|
||||
# Logger for this module, named after the module itself.
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def chain_inner_default() -> defaultdict[Optional[str], int]:
    """Build an empty counter mapping whose missing keys start at ``0``."""
    counts: defaultdict[Optional[str], int] = defaultdict(int)
    return counts
|
||||
|
||||
|
||||
def chain_default() -> defaultdict[str, defaultdict[Optional[str], int]]:
    """Build an empty two-level mapping whose inner values are counters.

    Missing outer keys are filled in by ``chain_inner_default``.
    """
    table: defaultdict[str, defaultdict[Optional[str], int]] = defaultdict(
        chain_inner_default
    )
    return table
|
||||
|
||||
|
||||
def windows(items: Sequence[Any], size: int):
    """Yield each contiguous slice of ``size`` elements of ``items`` in order.

    When ``items`` holds fewer than ``size`` elements, ``items`` itself is
    yielded once and nothing else.
    """
    total = len(items)
    if total < size:
        yield items
        return
    last_start = total - size
    start = 0
    while start <= last_start:
        yield items[start : start + size]
        start += 1
|
||||
|
||||
|
||||
class Chain:
    """Markov chain stored in SQLite, keyed per guild member.

    Every chain row links a key (``value``: up to ``order`` space-joined
    words) to one possible following word (``next``) with a ``weight``
    counting how often that transition was seen. An empty ``next`` marks the
    end of a message.

    Args:
        order: Number of words that make up a key.
        reply_chance: Reply chance stored for newly created users.
        db: SQLite connection holding the ``user`` and ``chain`` tables.
    """

    def __init__(self, order: int, reply_chance: float, db: sqlite3.Connection):
        self.order = order
        self.reply_chance = reply_chance
        self.db = db

    def commit(self):
        """Commit the pending transaction on the underlying connection."""
        self.db.commit()

    def execute(self, *args, **kwargs):
        """Execute a database query and return all resulting rows as a list."""
        cursor = self.db.cursor()
        try:
            cursor.execute(*args, **kwargs)
            return cursor.fetchall()
        finally:
            # Release the cursor even when the statement fails.
            cursor.close()

    def get_user_listen(self, guild_id: int, member_id: int) -> bool:
        """Return whether the member's messages are used for training."""
        self.ensure_user(guild_id, member_id)
        result = self.execute(
            "SELECT listen FROM user WHERE guild_id = ? AND member_id = ?",
            (guild_id, member_id),
        )
        if result:
            return bool(result[0][0])
        return True

    def set_user_listen(self, guild_id: int, member_id: int, listen: bool):
        """Store the member's listen flag and commit."""
        self.ensure_user(guild_id, member_id)
        self.execute(
            "UPDATE user SET listen = ? WHERE guild_id = ? AND member_id = ?",
            (listen, guild_id, member_id),
        )
        self.commit()

    def get_reply_chance(self, guild_id: int, member_id: int) -> float:
        """Return the member's stored reply chance (``0.0`` if no row exists)."""
        self.ensure_user(guild_id, member_id)
        result = self.execute(
            "SELECT reply_chance FROM user WHERE guild_id = ? AND member_id = ?",
            (guild_id, member_id),
        )
        if result:
            return result[0][0]
        return 0.0

    def set_reply_chance(self, guild_id: int, member_id: int, chance: float):
        """Store the member's reply chance and commit."""
        self.ensure_user(guild_id, member_id)
        self.execute(
            "UPDATE user SET reply_chance = ? WHERE guild_id = ? AND member_id = ?",
            (chance, guild_id, member_id),
        )
        self.commit()

    def get_user_id(self, guild_id: int, member_id: int) -> Optional[int]:
        """Return the ``user.id`` of the member, or ``None`` if not stored."""
        result = self.execute(
            "SELECT id FROM user WHERE guild_id = ? AND member_id = ?",
            (guild_id, member_id),
        )
        if result:
            return result[0][0]
        return None

    def ensure_user(self, guild_id: int, member_id: int):
        """Insert a user row for the member if none exists (does not commit)."""
        if self.get_user_id(guild_id, member_id) is not None:
            return
        self.execute(
            "INSERT INTO user (guild_id, member_id, reply_chance) VALUES (?, ?, ?)",
            (guild_id, member_id, self.reply_chance),
        )

    def ensure_key(self, guild_id: int, member_id: int, key: str, next: str):
        """Insert a zero-weight ``key -> next`` row if it is missing (no commit)."""
        assert next is not None

        self.ensure_user(guild_id, member_id)
        if next in self.get(guild_id, member_id, key):
            return
        self.execute(
            """
            INSERT INTO chain (user, value, weight, next)
            VALUES (
                (SELECT id FROM user WHERE guild_id = ? AND member_id = ?),
                ?, 0, ?
            )
            """,
            (guild_id, member_id, key, next),
        )

    def add(self, guild_id: int, member_id: int, text: str, commit=True):
        """Train the member's chain on ``text``.

        The text is split on whitespace and an empty string is appended as
        the end-of-message marker. Each window of ``order + 1`` words then
        contributes one transition from its leading words to its last word.

        Args:
            guild_id: Guild the message was sent in.
            member_id: Author of the message.
            text: Message content; blank text is ignored.
            commit: Commit after adding. Pass ``False`` to batch several calls.
        """
        parts: List[Any] = text.strip().split()
        if not parts:
            return
        for fragment in windows(parts + [""], self.order + 1):
            head = fragment[0:-1]
            tail = fragment[-1]
            key = " ".join(head)
            self.update_chain(guild_id, member_id, key, tail)
        if commit:
            self.commit()

    def update_chain(
        self,
        guild_id: int,
        member_id: int,
        key: str,
        next: str,
        weight: int = 1,
    ):
        """Add ``weight`` to the ``key -> next`` transition, creating it if needed."""
        self.ensure_key(guild_id, member_id, key, next)
        self.execute(
            """
            UPDATE chain
            SET weight = weight + :weight
            WHERE user = (SELECT id FROM user WHERE guild_id = :guild_id AND member_id = :member_id)
            AND value = :key
            AND next = :next
            """,
            {
                "guild_id": guild_id,
                "member_id": member_id,
                "key": key,
                "next": next,
                "weight": weight,
            },
        )

    def get(self, guild_id: int, member_id: int, key: str) -> dict[str, int]:
        """Return ``{next_word: weight}`` for ``key`` in the member's chain."""
        rows = self.execute(
            """
            SELECT next, weight
            FROM chain
            WHERE
                user = (SELECT id FROM user WHERE guild_id = ? AND member_id = ?)
                AND value = ?
            """,
            (guild_id, member_id, key),
        )
        return dict(rows)

    def generate(self, guild_id: int, member_id: int) -> Optional[str]:
        """Generate a line from one member's chain.

        Returns:
            The generated text, or ``None`` when the member is unknown or
            has no chain entries yet.
        """
        user_id = self.get_user_id(guild_id, member_id)
        if user_id is None:
            return None

        rows = self.execute(
            """
            SELECT value
            FROM chain
            WHERE user = ?
            ORDER BY RANDOM()
            LIMIT 1
            """,
            (user_id,),
        )
        # A user row can exist without any chain rows (e.g. created by a
        # settings command), so there may be nothing to start from.
        if not rows:
            return None
        return self._walk(
            rows[0][0], lambda head: self.choose_next(guild_id, member_id, head)
        )

    def choose_next(self, guild_id: int, member_id: int, head: str) -> str:
        """Pick a weighted-random successor of ``head`` in the member's chain.

        Returns an empty string when ``head`` has no successors.
        """
        return self._weighted_choice(self.get(guild_id, member_id, head))

    def generate_all(self, guild_id: int) -> Optional[str]:
        """Generate a line from the combined chains of every member of a guild.

        Returns:
            The generated text, or ``None`` when the guild has no chain entries.
        """
        rows = self.execute(
            """
            SELECT value
            FROM chain
            WHERE user IN (SELECT user.id FROM user WHERE guild_id = ?)
            ORDER BY RANDOM()
            LIMIT 1
            """,
            (guild_id,),
        )
        if not rows:
            return None
        return self._walk(
            rows[0][0], lambda head: self.all_choose_next(guild_id, head)
        )

    def all_choose_next(self, guild_id: int, head: str) -> str:
        """Pick a weighted-random successor of ``head`` across a whole guild.

        Only rows of users in ``guild_id`` are considered, and the weights
        of identical successors are summed across those users. Returns an
        empty string when ``head`` has no successors in the guild.
        """
        rows = self.execute(
            """
            SELECT next, SUM(weight)
            FROM chain
            WHERE value = ?
            AND user IN (SELECT user.id FROM user WHERE guild_id = ?)
            GROUP BY next
            """,
            (head, guild_id),
        )
        return self._weighted_choice(dict(rows))

    def _walk(self, start: str, choose) -> str:
        """Extend ``start`` word by word until ``choose`` returns an empty string.

        ``choose`` is called with the current space-joined key and returns
        the next word. The key slides forward by one word per step.
        """
        node = start.split(" ")
        words: List[str] = list(node)
        following = choose(" ".join(node))
        while following:
            words.append(following)
            node = [*node[1:], following]
            following = choose(" ".join(node))
        return " ".join(words)

    @staticmethod
    def _weighted_choice(choices: dict) -> str:
        """Pick one key of ``choices`` at random, weighted by its value.

        Returns an empty string for an empty mapping.
        """
        if not choices:
            return ""
        words = list(choices)
        weights = list(choices.values())
        if sum(weights) <= 0:
            # ensure_key() creates rows with weight 0, and random.choices
            # rejects weights whose total is not positive, so pick uniformly.
            return random.choice(words)
        return random.choices(words, weights)[0]
|
||||
9
discord_markov/util.py
Normal file
9
discord_markov/util.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def utcnow() -> datetime:
    """Return the current time as a timezone-aware datetime in UTC."""
    utc = timezone.utc
    return datetime.now(tz=utc)
|
||||
|
||||
|
||||
def utc_from_timestamp(timestamp: float):
    """Convert a POSIX timestamp into a timezone-aware datetime in UTC."""
    utc = timezone.utc
    return datetime.fromtimestamp(timestamp, tz=utc)
|
||||
Reference in New Issue
Block a user