# discord-markov/discord_markov/client.py

from asyncio import Event
from datetime import datetime
import logging
import math
import random
import sqlite3
import discord
from discord import Guild, TextChannel
from discord.abc import GuildChannel
from . import markov
from . import util
log = logging.getLogger(name=__name__)

# Earliest instant a Discord snowflake can encode (the Discord epoch,
# 2015-01-01T00:00:00Z) as a POSIX timestamp in seconds; used as the resume
# point for channels that have never been synchronized.
JAN_1_2015 = 1_420_070_400.0

# Replies are truncated to at most this many characters.
MAX_MESSAGE_LEN = 2_000
class Client(discord.Client):
    """Discord client that trains per-guild Markov chains and replies with them.

    The ``channel`` table of ``db`` records, per channel, the timestamp of the
    newest message already processed, so that after a (re)connect only newer
    history is fed to ``chain``.
    """

    def __init__(self, *, db: sqlite3.Connection, chain: markov.Chain, **kwargs):
        super().__init__(**kwargs)
        self.chain = chain
        self.db = db
        # One Event per guild id. It is cleared while that guild's backlog is
        # being synchronized and set once live messages may be trained on.
        self.__chains_ready: dict[int, Event] = {}

    # ------------------------------------------------------------------
    # Database helpers
    # ------------------------------------------------------------------

    def execute(self, *args, **kwargs) -> list:
        """Execute a database query and return all resulting rows.

        Arguments are forwarded to ``sqlite3.Cursor.execute``. Does not commit.
        """
        cursor = self.db.cursor()
        try:
            cursor.execute(*args, **kwargs)
            return cursor.fetchall()
        finally:
            cursor.close()

    def track_channel(self, guild_id: int, channel_id: int):
        """Insert a row for ``channel_id`` into the ``channel`` table unless one exists."""
        if self.is_channel_tracked(channel_id):
            return
        self.execute(
            "INSERT INTO channel (guild_id, channel_id) VALUES (?, ?)",
            (guild_id, channel_id),
        )
        self.db.commit()

    def is_channel_tracked(self, channel_id: int) -> bool:
        """Return whether the database is tracking the specified channel."""
        rows = self.execute(
            "SELECT COUNT(*) FROM channel WHERE channel_id = ?",
            (channel_id,),
        )
        return bool(rows[0][0])

    def is_listening_to_user(self, guild_id: int, member_id: int) -> bool:
        """Return the stored ``listen`` flag for a member; True when no row exists."""
        if rows := self.execute(
            "SELECT listen FROM user WHERE guild_id = ? AND member_id = ?",
            (guild_id, member_id),
        ):
            return bool(rows[0][0])
        # By default we listen to all users
        return True

    def get_last_tracked_message(self, channel_id: int) -> datetime:
        """Return the timestamp of the newest processed message in a channel.

        Falls back to the Discord epoch when the channel has no row or no
        recorded position yet.
        """
        rows = self.execute(
            "SELECT last_message FROM channel WHERE channel_id = ?", (channel_id,)
        )
        # track_channel() inserts rows without last_message. Depending on the
        # schema default, that column may be NULL, so treat NULL as "never
        # synchronized" instead of handing None to the converter.
        if rows and rows[0][0] is not None:
            return util.utc_from_timestamp(rows[0][0])
        return util.utc_from_timestamp(JAN_1_2015)

    def set_last_tracked_message(self, channel_id: int, last_message: datetime):
        """Store ``last_message`` as the channel's resume point and commit.

        The commit covers every pending write on ``self.db``, not only this
        UPDATE.
        """
        self.execute(
            "UPDATE channel SET last_message = ? WHERE channel_id = ?",
            (last_message.timestamp(), channel_id),
        )
        self.db.commit()

    # ------------------------------------------------------------------
    # Synchronization
    # ------------------------------------------------------------------

    async def on_ready(self):
        log.info("Synchronizing markov chains with guilds")
        for guild in self.guilds:
            await self.synchronize_guild(guild)
        log.info("Ready to start serving messages")

    async def on_guild_channel_create(self, channel: GuildChannel):
        # only handle text channels
        if not isinstance(channel, TextChannel):
            return
        await self.synchronize_channel(channel)

    async def on_guild_join(self, guild: Guild):
        await self.synchronize_guild(guild)

    async def synchronize_guild(self, guild: Guild):
        """Synchronize every text channel of ``guild``, pausing live training meanwhile."""
        log.info("Synchronizing guild %s (id: %s)", guild.name, guild.id)
        if guild.id not in self.__chains_ready:
            self.__chains_ready[guild.id] = Event()
        ready = self.__chains_ready[guild.id]
        # guild ID is not ready: on_message skips training for this guild
        ready.clear()
        try:
            for channel in guild.text_channels:
                await self.synchronize_channel(channel)
        finally:
            # Re-enable live training even if a channel failed unexpectedly.
            # Otherwise the guild would stay untrained until the next restart.
            ready.set()

    async def synchronize_channel(self, channel: TextChannel):
        """Train on every message posted in ``channel`` since its last synchronization.

        The channel is registered in the ``channel`` table if necessary. Bot
        messages and messages starting with ``!`` are not trained on. The
        resume position is advanced to the newest message actually read, and
        is checkpointed every 1000 trained messages and at the end. History
        that cannot be read (``discord.HTTPException``, e.g. ``Forbidden``)
        is logged and skipped instead of aborting the caller.
        """
        guild = channel.guild
        # Make sure this channel is tracked
        self.track_channel(guild.id, channel.id)
        resume_after = self.get_last_tracked_message(channel.id)
        log.debug(
            "Guild %s (%s): channel %s (%s): fetching messages and training model",
            guild.id,
            guild.name,
            channel.id,
            channel.name,
        )
        count = 0
        skipped = 0
        # created_at of the newest message read so far, bots and commands
        # included. Using what was actually read, rather than the channel's
        # current last message, means messages posted after the scan are not
        # marked as processed and will be picked up by the next sync.
        newest = None
        try:
            # oldest_first keeps "newest" monotonic, which makes the periodic
            # checkpoint below a safe resume point.
            async for message in channel.history(
                after=resume_after, limit=None, oldest_first=True
            ):
                newest = message.created_at
                if message.author.bot:
                    continue
                if message.content.strip().startswith("!"):
                    skipped += 1
                    continue
                count += 1
                self.chain.add(
                    guild.id, message.author.id, message.content, commit=False
                )
                if count % 1000 == 0:
                    log.debug(
                        "%s (%s) %s (%s) - %s messages",
                        guild.id,
                        guild.name,
                        channel.id,
                        channel.name,
                        count,
                    )
                    # Checkpoint. This commits self.db, so it assumes the
                    # chain's commit=False writes go through that same
                    # connection (the original commit pattern relied on it).
                    self.set_last_tracked_message(channel.id, newest)
        except discord.HTTPException:
            # e.g. Forbidden when the bot may not read this channel's history.
            # Keep what was read so far and let the rest of the guild proceed.
            log.warning(
                "Guild %s (%s): channel %s (%s): could not read message history",
                guild.id,
                guild.name,
                channel.id,
                channel.name,
                exc_info=True,
            )
        if newest is not None:
            # Commits the remaining chain additions together with the new
            # resume position.
            self.set_last_tracked_message(channel.id, newest)
        log.info(
            "Guild %s (%s): channel %s (%s): synchronized %s messages (skipped %s messages that looked like commands)",
            guild.id,
            guild.name,
            channel.id,
            channel.name,
            count,
            skipped,
        )

    # ------------------------------------------------------------------
    # Live messages
    # ------------------------------------------------------------------

    async def on_message(self, message: discord.Message):
        """Handle ``!markov`` commands, and train on / occasionally reply to chat."""
        if message.author.bot:
            return
        guild = message.guild
        # Ignore DMs and guilds whose synchronization has not started yet.
        if not guild or guild.id not in self.__chains_ready:
            return
        text = message.content.strip()
        parts = text.split()
        if parts and parts[0] == "!markov":
            await self._run_markov_command(message, guild, parts)
        elif text and not text.startswith("!"):
            await self._train_and_maybe_reply(message, guild)

    async def _reply_generated(self, message: discord.Message, line):
        """Reply to ``message`` with ``line`` truncated to MAX_MESSAGE_LEN; no-op if empty."""
        if line:
            await message.reply(line[:MAX_MESSAGE_LEN])

    async def _run_markov_command(
        self, message: discord.Message, guild: Guild, parts: list
    ):
        """Dispatch a ``!markov <subcommand> ...`` message; ``parts`` is its whitespace split."""
        argc = len(parts)
        if argc == 1:
            return
        subcommand = parts[1]
        if subcommand == "help":
            help_messages = [
                "Markov is filmed in front of a live studio audience.",
                "Help! I'm having a heart attack!",
                "Where is that gun?",
                "I'm walkin' heah!",
                "Any resemblance to persons, living or fictional, is entirely coincidental.",
                "ge8hg809wga987haw4go9hag897hagndfnam3n342ui128",
                "alek die",
            ]
            lines = [
                random.choice(help_messages),
                "",
                "`!markov help` - idiot",
                "`!markov force` - force markov to say something",
                "`!markov force username` - force markov to say something, using someone else's words",
                "`!markov trigger` - synonym for `!markov force`",
                f"`!markov chance n` - set your reply chance to some number between 0.0 and {self.chain.reply_chance}",
                "`!markov listen on|off` - tell markov to start or stop listening to what you say",
                "`!markov all` - force markov to say something based off of everything said in the server",
                "",
                "Also, kill yourself!",
            ]
            await message.reply("\n".join(lines))
        elif subcommand in ("force", "trigger"):
            # Use the first mentioned member's chain, or the sender's when
            # nobody is mentioned.
            member_id = (
                message.mentions[0].id if message.mentions else message.author.id
            )
            await self._reply_generated(
                message, self.chain.generate(guild.id, member_id)
            )
        elif subcommand == "all":
            await self._reply_generated(message, self.chain.generate_all(guild.id))
        elif subcommand == "listen":
            if argc == 2:
                listening = self.chain.get_user_listen(guild.id, message.author.id)
                if listening:
                    await message.reply("Markov is listening to your messages")
                else:
                    await message.reply("Markov is not listening to your messages")
            elif parts[2].lower() == "on":
                self.chain.set_user_listen(guild.id, message.author.id, True)
                await message.reply(
                    "Markov is now listening to everything you say. Turn this off with `!markov listen off`"
                )
            elif parts[2].lower() == "off":
                self.chain.set_user_listen(guild.id, message.author.id, False)
                await message.reply(
                    "Markov is no longer listening to everything you say, but you can still use commands. To resume listening, use `!markov listen on`"
                )
        elif subcommand == "chance":
            if argc > 2:
                try:
                    requested = float(parts[2])
                except ValueError:
                    # Not a number: leave the current setting unchanged.
                    requested = math.nan
                # NaN passes through min()/max() unchanged, so reject it explicitly.
                if not math.isnan(requested):
                    clamped = min(self.chain.reply_chance, max(requested, 0.0))
                    self.chain.set_reply_chance(guild.id, message.author.id, clamped)
            chance = self.chain.get_reply_chance(guild.id, message.author.id)
            await message.reply(f"Reply chance: {chance}")

    async def _train_and_maybe_reply(self, message: discord.Message, guild: Guild):
        """Train on an ordinary chat message, then reply with the user's reply chance."""
        if not self.is_listening_to_user(guild.id, message.author.id):
            return
        # This guild's backlog is still being synchronized. Skip live
        # training so the history scan and this handler don't both count
        # the same message.
        if not self.__chains_ready[guild.id].is_set():
            return
        # Train the model
        log.debug("%s: %s", message.author, message.content)
        self.chain.add(guild.id, message.author.id, message.content)
        # Update the channel's last_message time
        self.set_last_tracked_message(message.channel.id, message.created_at)
        # Speak your mind, markov
        chance = self.chain.get_reply_chance(guild.id, message.author.id)
        if random.random() < chance:
            # Truncated like the command replies; over-long content would make
            # the reply request fail.
            await self._reply_generated(
                message, self.chain.generate(guild.id, message.author.id)
            )