2022-10-22 18:00:52 -07:00
|
|
|
from asyncio import Event
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
import logging
|
|
|
|
|
import math
|
|
|
|
|
import random
|
|
|
|
|
import sqlite3
|
|
|
|
|
|
|
|
|
|
import discord
|
|
|
|
|
from discord import Guild, TextChannel
|
|
|
|
|
from discord.abc import GuildChannel
|
|
|
|
|
|
|
|
|
|
from . import markov
|
|
|
|
|
from . import util
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log = logging.getLogger(__name__)

# Discord's snowflake epoch (2015-01-01T00:00:00 UTC) expressed as a POSIX
# timestamp: the earliest time any Discord message can carry.
JAN_1_2015 = 1_420_070_400.0

# Longest text the bot will send in a single reply (Discord's message limit).
MAX_MESSAGE_LEN = 2_000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Client(discord.Client):
    """Discord client that trains a markov chain per guild member and replies with it."""

    def __init__(self, *, db: sqlite3.Connection, chain: markov.Chain, **kwargs):
        """Create the client.

        Args:
            db: SQLite connection used for the ``channel`` and ``user`` tables
                that this class queries.
            chain: markov model that messages are added to and generated from.
            **kwargs: forwarded unchanged to ``discord.Client.__init__``.
        """
        super().__init__(**kwargs)
        self.db = db
        self.chain = chain
        # guild id -> Event; set once that guild's message backlog has been
        # synchronized, cleared while a synchronization is in progress.
        self.__chains_ready: dict[int, Event] = {}
|
|
|
|
|
|
|
|
|
|
def execute(self, *args, **kwargs):
    """Execute a database query on ``self.db`` and return all result rows.

    Arguments are forwarded unchanged to ``sqlite3.Cursor.execute``
    (typically the SQL text followed by its parameter sequence).

    Returns:
        A list of rows (tuples unless the connection sets a ``row_factory``);
        empty for statements that produce no rows.
    """
    cursor = self.db.cursor()
    try:
        cursor.execute(*args, **kwargs)
        # fetchall() already returns a fresh list.
        return cursor.fetchall()
    finally:
        # Release the cursor deterministically instead of waiting for GC.
        cursor.close()
|
|
|
|
|
|
|
|
|
|
def track_channel(self, guild_id: int, channel_id: int):
    """Register ``channel_id`` (belonging to ``guild_id``) in the ``channel`` table.

    Does nothing when the channel is already registered; otherwise inserts
    the row and commits immediately.
    """
    already_known = self.is_channel_tracked(channel_id)
    if not already_known:
        self.execute(
            "INSERT INTO channel (guild_id, channel_id) VALUES (?, ?)",
            (guild_id, channel_id),
        )
        self.db.commit()
|
|
|
|
|
|
|
|
|
|
def is_channel_tracked(self, channel_id: int) -> bool:
    """Return whether the ``channel`` table has a row for ``channel_id``.

    Args:
        channel_id: Discord channel id to look up.

    Returns:
        True if the channel is registered for tracking, False otherwise.
    """
    # EXISTS stops at the first match instead of counting every row.
    rows = self.execute(
        "SELECT EXISTS(SELECT 1 FROM channel WHERE channel_id = ?)",
        (channel_id,),
    )
    return bool(rows[0][0])
|
|
|
|
|
|
|
|
|
|
def is_listening_to_user(self, guild_id: int, member_id: int) -> bool:
    """Return whether messages from ``member_id`` in ``guild_id`` should be learned.

    Reads the ``listen`` column of the member's ``user`` row. A member with no
    row has never expressed a preference and is listened to.
    """
    rows = self.execute(
        "SELECT listen FROM user WHERE guild_id = ? AND member_id = ?",
        (guild_id, member_id),
    )
    if not rows:
        # No stored preference: listening is the default.
        return True
    return bool(rows[0][0])
|
|
|
|
|
|
|
|
|
|
def get_last_tracked_message(self, channel_id: int) -> datetime:
    """Return the time of the newest message already learned from a channel.

    Falls back to ``JAN_1_2015`` (the earliest possible Discord message time)
    when the channel has no row, or when its ``last_message`` column is still
    NULL. ``track_channel`` inserts rows without a ``last_message`` value, so
    unless the schema supplies a default, a freshly tracked channel hits the
    NULL case. Previously this passed ``None`` to ``util.utc_from_timestamp``.

    Args:
        channel_id: Discord channel id to look up.

    Returns:
        The value produced by ``util.utc_from_timestamp`` for the stored (or
        fallback) POSIX timestamp.
    """
    rows = self.execute(
        "SELECT last_message FROM channel WHERE channel_id = ?", (channel_id,)
    )
    if rows and rows[0][0] is not None:
        return util.utc_from_timestamp(rows[0][0])
    return util.utc_from_timestamp(JAN_1_2015)
|
|
|
|
|
|
|
|
|
|
def set_last_tracked_message(self, channel_id: int, last_message: datetime):
    """Store ``last_message`` as the channel's newest learned message time.

    The datetime is persisted as a POSIX timestamp via ``datetime.timestamp()``,
    and the change is committed immediately. Channels without a row are left
    untouched (the UPDATE simply matches nothing).
    """
    posix_time = last_message.timestamp()
    self.execute(
        "UPDATE channel SET last_message = ? WHERE channel_id = ?",
        (posix_time, channel_id),
    )
    self.db.commit()
|
|
|
|
|
|
|
|
|
|
async def on_ready(self):
    """Catch up on missed messages in every visible guild, one guild at a time."""
    log.info("Synchronizing markov chains with guilds")
    for joined_guild in self.guilds:
        await self.synchronize_guild(joined_guild)
    log.info("Ready to start serving messages")
|
|
|
|
|
|
|
|
|
|
async def on_guild_channel_create(self, channel: GuildChannel):
    """Start tracking a newly created channel; non-text channels are ignored."""
    if isinstance(channel, TextChannel):
        await self.synchronize_channel(channel)
|
|
|
|
|
|
|
|
|
|
async def on_guild_join(self, guild: Guild):
    """Synchronize a guild as soon as the bot is added to it."""
    joined = guild
    await self.synchronize_guild(joined)
|
|
|
|
|
|
|
|
|
|
async def synchronize_guild(self, guild: Guild):
    """Bring the learned data for every text channel of ``guild`` up to date.

    While this runs, the guild's readiness event is cleared; ``on_message``
    checks it and skips training until it is set again. The event is re-set
    in a ``finally`` clause. Otherwise an exception partway through would
    leave the guild permanently "not ready" and silently stop all live
    training for it.

    Channels the bot may not read (``discord.Forbidden``) are logged and
    skipped, so one restricted channel does not abort the whole guild. Other
    exceptions still propagate.
    """
    log.info("Synchronizing guild %s (id: %s)", guild.name, guild.id)

    # Create the guild's readiness event on first sight, then mark it not ready.
    ready = self.__chains_ready.setdefault(guild.id, Event())
    ready.clear()

    try:
        for channel in guild.text_channels:
            try:
                await self.synchronize_channel(channel)
            except discord.Forbidden:
                log.warning(
                    "Guild %s (id: %s): no permission to read channel %s (id: %s); skipping",
                    guild.name,
                    guild.id,
                    channel.name,
                    channel.id,
                )
    finally:
        # Always let live training resume, even if synchronization failed.
        ready.set()
|
|
|
|
|
|
|
|
|
|
async def synchronize_channel(self, channel: TextChannel):
    """Learn every message posted in ``channel`` since its last recorded message.

    The channel is first registered via ``track_channel``. Its history after
    the stored last-message time is then replayed into ``self.chain``.
    Messages are skipped when they:

    - come from bots;
    - are empty;
    - start with ``!`` (assumed to be commands);
    - come from members who turned listening off.

    These are the same filters ``on_message`` applies to live messages.
    Afterwards the newest message time seen is stored, so the next call
    resumes from there.

    Raises:
        Whatever ``channel.history`` raises, e.g. ``discord.Forbidden`` when
        the bot cannot read the channel's history.
    """
    guild = channel.guild

    # Make sure this channel is tracked
    self.track_channel(guild.id, channel.id)

    # Only messages newer than the last recorded one still need to be learned.
    last_message = self.get_last_tracked_message(channel.id)
    log.debug(
        "Guild %s (%s): channel %s (%s): fetching messages and training model",
        guild.id,
        guild.name,
        channel.id,
        channel.name,
    )

    count = 0
    skipped = 0
    newest_seen = None
    # Per-author opt-out lookups, cached because history repeats authors a lot.
    listening: dict[int, bool] = {}

    # limit=None: discord.py's history() defaults to limit=100. With that
    # default, any backlog beyond 100 messages would be dropped, and the
    # stored last-message time would then skip past it permanently.
    async for message in channel.history(limit=None, after=last_message):
        # Track the newest time seen (including skipped messages) so the next
        # sync starts after it, without relying on channel.last_message_id.
        if newest_seen is None or message.created_at > newest_seen:
            newest_seen = message.created_at

        if message.author.bot:
            continue

        text = message.content.strip()
        if not text:
            # e.g. attachment-only messages; on_message does not learn these either.
            continue
        if text.startswith("!"):
            skipped += 1
            continue

        author_id = message.author.id
        if author_id not in listening:
            listening[author_id] = self.is_listening_to_user(guild.id, author_id)
        if not listening[author_id]:
            # Respect `!markov listen off` for backlog messages too.
            continue

        count += 1
        self.chain.add(guild.id, author_id, message.content, commit=False)
    self.db.commit()

    log.info(
        "Guild %s (%s): channel %s (%s): synchronized %s messages (skipped %s messages that looked like commands)",
        guild.id,
        guild.name,
        channel.id,
        channel.name,
        count,
        skipped,
    )

    if newest_seen is not None:
        self.set_last_tracked_message(channel.id, newest_seen)
|
|
|
|
|
|
|
|
|
|
async def on_message(self, message: discord.Message):
    """Dispatch an incoming message.

    ``!markov ...`` messages are run as commands. Any other message that does
    not start with ``!`` is treated as chat: it may be learned from and
    replied to. The following are ignored entirely:

    - messages from bots;
    - messages without a guild (e.g. direct messages);
    - messages from guilds whose synchronization has never started.
    """
    if message.author.bot:
        return

    guild = message.guild
    if not guild:
        return
    if guild.id not in self.__chains_ready:
        return

    text = message.content.strip()
    parts = text.split()

    if parts and parts[0] == "!markov":
        await self._run_markov_command(message, guild, parts)
    elif text and text[0] != "!":
        # Other "!"-prefixed text is presumably a command for some other bot.
        await self._learn_and_maybe_reply(message, guild)


async def _run_markov_command(
    self, message: discord.Message, guild: Guild, parts: list[str]
):
    """Execute a ``!markov <subcommand> ...`` message; ``parts`` is its whitespace split."""
    argc = len(parts)
    if argc == 1:
        return
    subcommand = parts[1]

    if subcommand == "help":
        # NOTE(review): several of these strings are hostile toward users;
        # consider whether they belong in a public-facing bot.
        help_messages = [
            "Markov is filmed in front of a live studio audience.",
            "Help! I'm having a heart attack!",
            "Where is that gun?",
            "I'm walkin' heah!",
            "Any resemblance to persons, living or fictional, is entirely coincidental.",
            "ge8hg809wga987haw4go9hag897hagndfnam3n342ui128",
            "alek die",
        ]
        lines = [
            random.choice(help_messages),
            "",
            "`!markov help` - idiot",
            "`!markov force` - force markov to say something",
            "`!markov force username` - force markov to say something, using someone else's words",
            "`!markov trigger` - synonym for `!markov force`",
            f"`!markov chance n` - set your reply chance to some number between 0.0 and {self.chain.reply_chance}",
            "`!markov listen on|off` - tell markov to start or stop listening to what you say",
            "`!markov all` - force markov to say something based off of everything said in the server",
            "",
            "Also, kill yourself!",
        ]
        await message.reply("\n".join(lines))

    elif subcommand in ("force", "trigger"):
        # Generate from the first mentioned member's chain (other mentions are
        # ignored), or from the sender's own chain when nobody is mentioned.
        member_id = message.mentions[0].id if message.mentions else message.author.id
        line = self.chain.generate(guild.id, member_id)
        if line:
            await message.reply(line[:MAX_MESSAGE_LEN])

    elif subcommand == "all":
        line = self.chain.generate_all(guild.id)
        if line:
            await message.reply(line[:MAX_MESSAGE_LEN])

    elif subcommand == "listen":
        if argc == 2:
            if self.chain.get_user_listen(guild.id, message.author.id):
                await message.reply("Markov is listening to your messages")
            else:
                await message.reply("Markov is not listening to your messages")
        elif parts[2].lower() == "on":
            self.chain.set_user_listen(guild.id, message.author.id, True)
            await message.reply(
                "Markov is now listening to everything you say. Turn this off with `!markov listen off`"
            )
        elif parts[2].lower() == "off":
            self.chain.set_user_listen(guild.id, message.author.id, False)
            await message.reply(
                "Markov is no longer listening to everything you say, but you can still use commands. To resume listening, use `!markov listen on`"
            )

    elif subcommand == "chance":
        if argc > 2:
            try:
                requested = float(parts[2])
            except ValueError:
                # Unparseable input: keep the current setting and just report it.
                requested = math.nan
            if not math.isnan(requested):
                # Clamp into [0, global reply chance]; this also tames +/-inf.
                clamped = min(self.chain.reply_chance, max(requested, 0.0))
                self.chain.set_reply_chance(guild.id, message.author.id, clamped)
        chance = self.chain.get_reply_chance(guild.id, message.author.id)
        await message.reply(f"Reply chance: {chance}")


async def _learn_and_maybe_reply(self, message: discord.Message, guild: Guild):
    """Add an ordinary chat message to the chain, then reply with the sender's reply chance."""
    if not self.is_listening_to_user(guild.id, message.author.id):
        return

    # While synchronize_guild is replaying the backlog it owns training for
    # this guild; learning here as well could record the same message twice.
    if not self.__chains_ready[guild.id].is_set():
        return

    # Train the model
    log.debug("%s: %s", message.author, message.content)
    self.chain.add(guild.id, message.author.id, message.content)

    # Update the channel's last_message time so the next sync starts after it.
    self.set_last_tracked_message(message.channel.id, message.created_at)

    # Speak your mind, markov
    chance = self.chain.get_reply_chance(guild.id, message.author.id)
    if random.random() < chance:
        line = self.chain.generate(guild.id, message.author.id)
        if line:
            # Truncate like every other reply path; longer text would be
            # rejected by Discord and the reply would fail.
            await message.reply(line[:MAX_MESSAGE_LEN])
|