from asyncio import Event
from datetime import datetime
import logging
import math
import random
import sqlite3

import discord
from discord import Guild, TextChannel
from discord.abc import GuildChannel

from . import markov
from . import util

log = logging.getLogger(__name__)

# This is the minimum value for a discord snowflake timestamp
JAN_1_2015 = 1420070400.0

# Maximum length of a Discord message; every reply is truncated to this.
MAX_MESSAGE_LEN = 2000


class Client(discord.Client):
    """Discord client that trains per-guild markov chains on chat messages.

    On startup (and when joining a guild or gaining a channel) it catches the
    chain up on channel history; afterwards it trains on live messages,
    answers ``!markov`` commands and occasionally replies on its own.
    """

    def __init__(self, *, db: sqlite3.Connection, chain: markov.Chain, **kwargs):
        super().__init__(**kwargs)
        self.chain = chain
        self.db = db
        # guild id -> Event that is set once that guild's history has been
        # synchronized; live messages are only trained while it is set.
        self.__chains_ready: dict[int, Event] = {}

    def execute(self, *args, **kwargs) -> list:
        "Execute a database query and return all resulting rows as a list."
        cursor = self.db.cursor()
        try:
            cursor.execute(*args, **kwargs)
            return list(cursor.fetchall())
        finally:
            cursor.close()

    def track_channel(self, guild_id: int, channel_id: int):
        "Insert a row for the channel into the tracking table if it is missing."
        if self.is_channel_tracked(channel_id):
            return
        self.execute(
            "INSERT INTO channel (guild_id, channel_id) VALUES (?, ?)",
            (guild_id, channel_id),
        )
        self.db.commit()

    def is_channel_tracked(self, channel_id: int) -> bool:
        "Gets whether the database is tracking a specified channel."
        result = self.execute(
            "SELECT COUNT(*) FROM channel WHERE channel_id = ?",
            (channel_id,),
        )
        return bool(result[0][0])

    def is_listening_to_user(self, guild_id: int, member_id: int) -> bool:
        "Gets whether the member has allowed markov to learn from their messages."
        if result := self.execute(
            "SELECT listen FROM user WHERE guild_id = ? AND member_id = ?",
            (guild_id, member_id),
        ):
            return bool(result[0][0])
        # By default we listen to all users
        return True

    def get_last_tracked_message(self, channel_id: int) -> datetime:
        """Return the creation time of the newest message already trained on.

        Falls back to JAN_1_2015 (the earliest possible Discord timestamp) when
        the channel has no row, or its row has no recorded ``last_message``
        yet (e.g. it was just inserted by ``track_channel``).
        """
        result = self.execute(
            "SELECT last_message FROM channel WHERE channel_id = ?", (channel_id,)
        )
        if result and result[0][0] is not None:
            return util.utc_from_timestamp(result[0][0])
        return util.utc_from_timestamp(JAN_1_2015)

    def set_last_tracked_message(self, channel_id: int, last_message: datetime):
        "Record the creation time of the newest message trained on in a channel."
        self.execute(
            "UPDATE channel SET last_message = ? WHERE channel_id = ?",
            (last_message.timestamp(), channel_id),
        )
        self.db.commit()

    async def on_ready(self):
        log.info("Synchronizing markov chains with guilds")
        for guild in self.guilds:
            await self.synchronize_guild(guild)
        log.info("Ready to start serving messages")

    async def on_guild_channel_create(self, channel: GuildChannel):
        # only handle text channels
        if not isinstance(channel, TextChannel):
            return
        await self.synchronize_channel(channel)

    async def on_guild_join(self, guild: Guild):
        await self.synchronize_guild(guild)

    async def synchronize_guild(self, guild: Guild):
        """Catch the chain up on every text channel of *guild*.

        Live training for the guild is paused (its ready event cleared) while
        this runs and is resumed afterwards, even if some channels fail.
        """
        log.info("Synchronizing guild %s (id: %s)", guild.name, guild.id)
        if guild.id not in self.__chains_ready:
            self.__chains_ready[guild.id] = Event()
        ready = self.__chains_ready[guild.id]

        # guild ID is not ready
        ready.clear()
        try:
            for channel in guild.text_channels:
                try:
                    await self.synchronize_channel(channel)
                except discord.HTTPException as exc:
                    # e.g. no permission to read a channel's history; one
                    # unreadable channel must not block the rest of the guild
                    log.warning(
                        "Guild %s (%s): channel %s (%s): synchronization failed: %s",
                        guild.id,
                        guild.name,
                        channel.id,
                        channel.name,
                        exc,
                    )
        finally:
            # guild ID is ready; without this a failure above would stop live
            # training for this guild until the next restart
            ready.set()

    async def synchronize_channel(self, channel: TextChannel):
        """Train the chain on messages sent in *channel* since the last one
        tracked, then advance the tracked timestamp.

        Bot messages, command-like messages (starting with ``!``) and messages
        from users who turned listening off are not trained on.

        Raises whatever ``channel.history`` raises (e.g. ``discord.Forbidden``);
        progress made before the error is still committed and recorded.
        """
        guild = channel.guild

        # Make sure this channel is tracked
        self.track_channel(guild.id, channel.id)

        # In each channel, all messages sent since the last tracked message
        # was recorded
        last_message = self.get_last_tracked_message(channel.id)
        log.debug(
            "Guild %s (%s): channel %s (%s): fetching messages and training model",
            guild.id,
            guild.name,
            channel.id,
            channel.name,
        )

        count = 0
        skipped = 0
        ignored = 0
        # author id -> listen setting, so each author is looked up only once
        listening: dict[int, bool] = {}
        # creation time of the newest message seen during this sync
        newest = None
        try:
            # With ``after`` given, history() yields messages oldest-first, so
            # ``newest`` always covers everything processed so far.
            async for message in channel.history(after=last_message, limit=None):
                if newest is None or message.created_at > newest:
                    newest = message.created_at
                if message.author.bot:
                    continue
                if message.content.strip().startswith("!"):
                    skipped += 1
                    continue
                author_id = message.author.id
                if author_id not in listening:
                    listening[author_id] = self.is_listening_to_user(
                        guild.id, author_id
                    )
                if not listening[author_id]:
                    # Respect `!markov listen off` for historical messages too
                    ignored += 1
                    continue
                count += 1
                self.chain.add(guild.id, author_id, message.content, commit=False)
                if count % 1000 == 0:
                    log.debug(
                        "%s (%s) %s (%s) - %s messages",
                        guild.id,
                        guild.name,
                        channel.id,
                        channel.name,
                        count,
                    )
        finally:
            # Persist what was trained and advance the tracked timestamp only
            # as far as the messages actually iterated. Using the iterated
            # message (rather than channel.last_message_id) avoids marking
            # messages that arrived during the sync as already trained.
            self.db.commit()
            if newest is not None:
                self.set_last_tracked_message(channel.id, newest)

        log.info(
            "Guild %s (%s): channel %s (%s): synchronized %s messages (skipped %s messages that looked like commands, ignored %s messages from users not being listened to)",
            guild.id,
            guild.name,
            channel.id,
            channel.name,
            count,
            skipped,
            ignored,
        )

    async def on_message(self, message: discord.Message):
        """Handle ``!markov`` commands; otherwise learn from the message and
        sometimes reply with generated text."""
        if message.author.bot:
            return
        guild = message.guild
        if not guild:
            # Not sent in a guild (no per-guild chain to use)
            return
        if guild.id not in self.__chains_ready:
            # Guild has not started synchronizing yet
            return

        text = message.content.strip()
        parts = text.split()
        if parts and parts[0] == "!markov":
            await self._handle_command(message, guild, parts)
        elif text and text[0] != "!":
            await self._learn_and_respond(message, guild)

    async def _handle_command(self, message: discord.Message, guild: Guild, parts: list):
        "Dispatch a `!markov ...` command; *parts* is the whitespace-split text."
        argc = len(parts)
        if argc == 1:
            return
        subcommand = parts[1]
        if subcommand == "help":
            help_messages = [
                "Markov is filmed in front of a live studio audience.",
                "Help! I'm having a heart attack!",
                "Where is that gun?",
                "I'm walkin' heah!",
                "Any resemblance to persons, living or fictional, is entirely coincidental.",
                "ge8hg809wga987haw4go9hag897hagndfnam3n342ui128",
                "alek die",
            ]
            lines = [
                random.choice(help_messages),
                "",
                "`!markov help` - idiot",
                "`!markov force` - force markov to say something",
                "`!markov force username` - force markov to say something, using someone else's words",
                "`!markov trigger` - synonym for `!markov force`",
                f"`!markov chance n` - set your reply chance to some number between 0.0 and {self.chain.reply_chance}",
                "`!markov listen on|off` - tell markov to start or stop listening to what you say",
                "`!markov all` - force markov to say something based off of everything said in the server",
                "",
                "Also, kill yourself!",
            ]
            await message.reply("\n".join(lines))
        elif subcommand in ("force", "trigger"):
            # Use the first mentioned member's chain, or the sender's if
            # nobody was mentioned.
            member_id = (
                message.mentions[0].id if message.mentions else message.author.id
            )
            line = self.chain.generate(guild.id, member_id)
            if line:
                await message.reply(line[:MAX_MESSAGE_LEN])
        elif subcommand == "all":
            line = self.chain.generate_all(guild.id)
            if line:
                await message.reply(line[:MAX_MESSAGE_LEN])
        elif subcommand == "listen":
            if argc == 2:
                listening = self.chain.get_user_listen(guild.id, message.author.id)
                if listening:
                    await message.reply("Markov is listening to your messages")
                else:
                    await message.reply("Markov is not listening to your messages")
            elif parts[2].lower() == "on":
                self.chain.set_user_listen(guild.id, message.author.id, True)
                await message.reply(
                    "Markov is now listening to everything you say. Turn this off with `!markov listen off`"
                )
            elif parts[2].lower() == "off":
                self.chain.set_user_listen(guild.id, message.author.id, False)
                await message.reply(
                    "Markov is no longer listening to everything you say, but you can still use commands. To resume listening, use `!markov listen on`"
                )
        elif subcommand == "chance":
            if argc > 2:
                try:
                    chance = float(parts[2])
                    if math.isnan(chance):
                        raise ValueError()
                    # Clamp into [0.0, the chain's maximum reply chance]
                    chance = min(self.chain.reply_chance, max(chance, 0.0))
                    assert 0.0 <= chance <= self.chain.reply_chance
                    self.chain.set_reply_chance(guild.id, message.author.id, chance)
                except ValueError:
                    # Not a number: leave the setting unchanged, just report it
                    pass
            chance = self.chain.get_reply_chance(guild.id, message.author.id)
            await message.reply(f"Reply chance: {chance}")

    async def _learn_and_respond(self, message: discord.Message, guild: Guild):
        "Train on a plain chat message, then reply with the author's reply chance."
        if not self.is_listening_to_user(guild.id, message.author.id):
            return

        # The bot is not ready to train the model yet, or is probably already
        # listening to this message.
        if not self.__chains_ready[guild.id].is_set():
            return

        # Train the model
        log.debug("%s: %s", message.author, message.content)
        self.chain.add(guild.id, message.author.id, message.content)

        # Update the channel's last_message time
        self.set_last_tracked_message(message.channel.id, message.created_at)

        # Speak your mind, markov
        chance = self.chain.get_reply_chance(guild.id, message.author.id)
        if random.random() < chance:
            line = self.chain.generate(guild.id, message.author.id)
            if line:
                # Truncate like every other reply path; Discord rejects
                # messages over MAX_MESSAGE_LEN characters.
                await message.reply(line[:MAX_MESSAGE_LEN])