diff --git a/plugins/wordbot.py b/plugins/wordbot.py new file mode 100644 index 0000000..249ba1f --- /dev/null +++ b/plugins/wordbot.py @@ -0,0 +1,281 @@ +import asyncio +import itertools +import logging +from pathlib import Path +import random +import sqlite3 +import time +from typing import Set + +from asyncirc.protocol import IrcProtocol +from irclib.parser import Prefix +from omnibot.plugin import Plugin + + +log = logging.getLogger(__name__) + + +class Db: + def __init__(self, path: Path): + self.path = path + + def ensure_db(self): + self.path.parent.mkdir(parents=True, exist_ok=True) + with sqlite3.connect(self.path) as conn: + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS game ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + start INTEGER NOT NULL, + end INTEGER NOT NULL, + channel VARCHAR(40) NOT NULL + ); + CREATE TABLE IF NOT EXISTS word( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + game INTEGER NOT NULL, + word VARCHAR(40) NOT NULL, + FOREIGN KEY (game) REFERENCES game(id), + UNIQUE (game, word) + ); + CREATE TABLE IF NOT EXISTS score ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + game INTEGER NOT NULL, + word INTEGER NOT NULL, + user VARCHAR(40) NOT NULL, + line VARCHAR(1024) NOT NULL, + FOREIGN KEY (game) REFERENCES game(id), + FOREIGN KEY (word) REFERENCES word(id), + UNIQUE(game, word) + ); + """ + ) + + def current_game(self, channel: str) -> int | None: + self.ensure_db() + with sqlite3.connect(self.path) as conn: + cur = conn.cursor() + cur.execute("SELECT MAX(id) FROM game WHERE channel = ?", (channel,)) + row = cur.fetchone() + if row: + return row[0] + else: + return None + + def is_game_active(self, channel: str) -> bool: + self.ensure_db() + game_id = self.current_game(channel) + if not game_id: + return False + with sqlite3.connect(self.path) as conn: + cur = conn.cursor() + cur.execute("SELECT end FROM game WHERE id = ?", (game_id,)) + row = cur.fetchone() + if row: + now = time.time() + return now < row[0] + else: + return False + + def start_round( + self, channel: str, duration: int, words: Set[str], allow_early_end=False + ): + self.ensure_db() + if self.is_game_active(channel) and not allow_early_end: + # Don't start a new game if you don't have to + raise Exception(f"Wordbot game is already running on {channel}") + start = time.time() + end = start + duration + with sqlite3.connect(self.path) as conn: + conn.execute( + "INSERT INTO game (start, end, channel) VALUES (?, ?, ?)", + (start, end, channel), + ) + # Mass insert some words + game_id = self.current_game(channel) + game_words_iter = zip(itertools.repeat(game_id), words) + with sqlite3.connect(self.path) as conn: + conn.executemany( + "INSERT INTO word (game, word) VALUES (?, ?)", game_words_iter + ) + + def add_score(self, channel: str, user: str, word: str, line: str): + self.ensure_db() + game_id = self.current_game(channel) + if not game_id: + log.warning( + "Tried to add score, but no active wordbot game for channel %s", channel + ) + return + with sqlite3.connect(self.path) as conn: + conn.execute( + """ + INSERT INTO score (game, word, user, line) + VALUES ( + :game_id, + (SELECT word.id FROM word WHERE game = :game_id AND word = :word), + :user, + :line + ) + """, + {"game_id": game_id, "word": word, "user": user, "line": line}, + ) + + def scores(self, channel: str): + # This differs from .leaderboard() by using a specific game ID, rather + # than all games for the channel. + game_id = self.current_game(channel) + with sqlite3.connect(self.path) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT user, COUNT(score.id) AS score + FROM score + JOIN game ON score.game = game.id + WHERE game.id = ? + GROUP BY user + """, + (game_id,), + ) + rows = cur.fetchall() + return {row[0]: row[1] for row in rows} + + def leaderboard(self, channel: str): + # This differs from .scores() by using the game.channel = ?, rather than + # a specific game id. + with sqlite3.connect(self.path) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT user, COUNT(score.id) AS score + FROM score + JOIN game ON score.game = game.id + WHERE game.channel = ? + GROUP BY user + """, + (channel,), + ) + rows = cur.fetchall() + return {row[0]: row[1] for row in rows} + + def unmatched_words(self, channel: str) -> Set[str]: + game_id = self.current_game(channel) + with sqlite3.connect(self.path) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT word + FROM word + WHERE word.game = :game_id + AND id NOT IN (SELECT score.word FROM score WHERE game = :game_id) + """, + {"game_id": game_id}, + ) + rows = cur.fetchall() + return {word[0] for word in rows} + + +class Wordbot(Plugin): + def __init__(self, *args, **kwargs): + super(Wordbot, self).__init__(*args, **kwargs) + self.db_path = Path( + self.plugin_config.get("db_path", "data/wordbot/wordbot.db") + ) + self.words_path = Path(self.plugin_config.get("words_path", "data/words.txt")) + self.db = Db(self.db_path) + self.duration = int(self.plugin_config.get("hours_per_round", 5)) * 3600 + self.words_per_round = int(self.plugin_config.get("words_per_round", 300)) + self.__watch_games_task = None + self.__db_lock = asyncio.Lock() + + def get_words(self) -> Set[str]: + with open(self.words_path) as fp: + return {word.strip().lower() for word in fp} + + async def on_load(self): + # Make sure games are running on all channels + # This happens before on_connect + for channel in self.channels: + if not self.db.is_game_active(channel): + self.start_round(channel) + + async def on_connect(self, conn: IrcProtocol): + # Start watcher up to end games + self.__watch_games_task = asyncio.create_task(self.__watch_games(conn)) + + async def __watch_games(self, conn: IrcProtocol): + while True: + await asyncio.sleep(1.0) + for channel in self.bot.joined_channels: + if not self.db.is_game_active(channel): + async with self.__db_lock: + # End round + self.end_round(conn, channel) + # Create new round + self.start_round(channel) + + async def on_unload(self, conn: IrcProtocol): + if self.__watch_games_task: + self.__watch_games_task.cancel() + + async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str): + if who.nick == self.server_config.nick: + return + line = line.strip() + if not line: + return + elif line[0] == "!": + await self.handle_command(conn, channel, who, line) + else: + async with self.__db_lock: + if not self.db.is_game_active(channel): + # Don't try to score words for inactive games + return + parts = {word.strip().lower() for word in line.split()} + matches = parts & self.db.unmatched_words(channel) + for word in matches: + self.send_to( + conn, channel, f"Congrats! '{word}' is good for 1 point." + ) + self.db.add_score(channel, who.nick, word, line) + + async def handle_command( + self, conn: IrcProtocol, channel: str, who: Prefix, line: str + ): + parts = line.strip().split() + match parts: + case ["!wordbot", "end_now"]: + async with self.__db_lock: + self.end_round(conn, channel) + self.start_round(channel, allow_early_end=True) + case _: + pass + + def start_round(self, channel: str, allow_early_end: bool = False): + # Choose words for new round + with open(self.words_path) as fp: + words = [word.strip() for word in fp] + random.shuffle(words) + words = words[: self.words_per_round] + log.debug("%s", words) + self.db.start_round(channel, self.duration, words, allow_early_end) + + def end_round(self, conn: IrcProtocol, channel: str): + # Sort the scores + scores = sorted(self.db.scores(channel).items(), key=lambda value: -value[1]) + # Add their ordering + rankings = { + score: rank + for rank, score in enumerate( + sorted( + set(map(lambda value: value[1], scores)), key=lambda value: -value + ) + ) + } + log.debug("%r", rankings) + log.debug("%r", scores) + self.send_to(conn, channel, "Game over. Here were the scores:") + for user, score in scores: + self.send_to(conn, channel, f"{rankings[score] + 1}. {user}. {score}") + + +PLUGIN_TYPE = Wordbot