Files
omnibot22/omnibot/bot.py

174 lines
6.4 KiB
Python
Raw Normal View History

import asyncio
import logging
from typing import Sequence, Set
from asyncirc.protocol import IrcProtocol
from asyncirc.server import Server
from irclib.parser import Message
from .config import ServerConfig
from . import plugin
# Module-level logger, named after this module (via __name__).
log = logging.getLogger(__name__)
class Bot:
    """
    An IRC bot for a single server that forwards IRC events to plugins.

    Plugins are built from ``server_config.plugins`` at construction time;
    entries whose ``"enabled"`` key is false are skipped. Channel-scoped
    events are only delivered to plugins whose ``channels`` contain the
    event's channel (see ``channel_plugins``).
    """

    def __init__(self, server_config: ServerConfig):
        self.__server_config = server_config
        # Channels the server has told us we are in, maintained from the
        # JOIN/PART/KICK messages we receive about ourselves.
        self.__channels: Set[str] = set()
        # Set by quit(); keepalive() waits on it before shutting down.
        self.__quitting = asyncio.Event()
        # The IrcProtocol connection; assigned in run().
        self.connection = None
        self.__plugins = [
            plugin.load_plugin(self, config)
            for config in server_config.plugins
            if config.get("enabled", True)
        ]

    @property
    def server_config(self) -> ServerConfig:
        return self.__server_config

    @property
    def plugins(self) -> Sequence[plugin.Plugin]:
        return self.__plugins

    @property
    def joined_channels(self) -> Set[str]:
        """
        Return the set of channels this bot is currently joined to.

        This is the bot's live internal set, not a copy; do not mutate it.
        """
        return self.__channels

    def quit(self):
        """Request shutdown; keepalive() (and therefore run()) then finishes."""
        self.__quitting.set()

    def channel_plugins(self, channel: str) -> Sequence[plugin.Plugin]:
        """Return the loaded plugins whose ``channels`` include ``channel``."""
        return [p for p in self.plugins if channel in p.channels]

    def __is_self(self, nick) -> bool:
        """
        Return True if ``nick`` is this bot's configured nickname.

        IRC nicknames are case-insensitive (RFC 1459 section 2.2); ASCII case
        folding covers the common case. ``None`` (no nick) is never us.
        """
        return nick is not None and nick.lower() == self.server_config.nick.lower()

    @staticmethod
    def __param(message: Message, index: int) -> str:
        """Return ``message.parameters[index]``, or "" when it is absent."""
        try:
            return message.parameters[index]
        except IndexError:
            return ""

    async def run(self):
        """
        Load plugins, connect to the configured server, then block until quit().
        """
        loop = asyncio.get_running_loop()
        server = Server(
            self.server_config.server,
            self.server_config.port,
            self.server_config.use_ssl,
        )

        log.info("Initializing plugins")
        await asyncio.gather(*[p.on_load() for p in self.plugins])

        self.connection = IrcProtocol([server], self.server_config.nick, loop=loop)

        # Register event handlers. "*" (presumably asyncirc's wildcard) makes
        # on_message see every command, including those handled above it.
        self.connection.register("001", self.on_connect)
        self.connection.register("JOIN", self.on_join)
        self.connection.register("PART", self.on_part)
        self.connection.register("KICK", self.on_kick)
        self.connection.register("*", self.on_message)

        log.info("Connecting to %s", self.server_config.server)
        await self.connection.connect()

        await self.keepalive()

    async def on_connect(self, conn: IrcProtocol, message: Message):
        """Handle 001 (welcome): join every configured channel, notify plugins."""
        for ch in self.server_config.all_channels:
            conn.send(str(Message(None, None, "JOIN", ch)))
        await asyncio.gather(*[p.on_connect(conn) for p in self.plugins])

    async def on_join(self, conn: IrcProtocol, message: Message):
        """
        Handle JOIN: track our own joins and forward the event to plugins.

        If the bot lands in a channel it is not configured for, it PARTs it.
        """
        log.debug("%s", message)
        channel = message.parameters[0]
        who = message.prefix
        if self.__is_self(who.nick):
            self.__channels.add(channel)
            if channel not in self.server_config.all_channels:
                # We were put into a channel we never asked for; leave it.
                conn.send(str(Message(None, None, "PART", channel)))
        plugins = self.channel_plugins(channel)
        await asyncio.gather(*[p.on_join(conn, channel, who) for p in plugins])

    def __on_part(self, conn: IrcProtocol, channel: str, nick: str):
        """
        Shared PART/KICK bookkeeping; ``nick`` is the user who left ``channel``.

        When that user is this bot, forget the channel, and if it is one we
        are configured to be in, ask to rejoin it. Channels we are *not*
        configured for are left alone; rejoining them would fight on_join's
        automatic PART and loop forever.
        """
        if not self.__is_self(nick):
            return
        self.__channels.discard(channel)
        if channel in self.server_config.all_channels:
            conn.send(str(Message(None, None, "JOIN", channel)))

    async def on_part(self, conn: IrcProtocol, message: Message):
        """Handle PART: update our channel state and forward to channel plugins."""
        log.debug("%s", message)
        channel = message.parameters[0]
        who = message.prefix
        # For PART, the prefix is the user who left.
        self.__on_part(conn, channel, who.nick)
        plugins = self.channel_plugins(channel)
        await asyncio.gather(*[p.on_part(conn, channel, who) for p in plugins])

    async def on_kick(self, conn: IrcProtocol, message: Message):
        """
        Handle KICK: update our channel state and forward to channel plugins.

        Plugins receive the message prefix (the kicker) as ``who``.
        """
        log.debug("%s", message)
        channel = message.parameters[0]
        # KICK <channel> <user> [<comment>]: the prefix is whoever issued the
        # kick; the second parameter is the user who was removed.
        kicked = self.__param(message, 1)
        who = message.prefix
        self.__on_part(conn, channel, kicked)
        plugins = self.channel_plugins(channel)
        await asyncio.gather(*[p.on_kick(conn, channel, who) for p in plugins])

    async def on_message(self, conn: IrcProtocol, message: Message):
        """
        Forward any message to plugins subscribed to its command.

        A plugin is subscribed when ``message.command`` is in its
        ``get_message_types()``. Messages we sent ourselves are ignored.
        Missing parameters are passed to plugins as "".
        """
        log.debug("%s", message)
        who = message.prefix
        # Some messages (e.g. server PINGs) may carry no prefix at all.
        if who is not None and self.__is_self(who.nick):
            return
        channel = self.__param(message, 0)
        line = self.__param(message, 1)

        # Plugins are normally multiplexed by channel, but that only makes
        # sense for channel-targeted messages (PRIVMSG to #chan, ...). Server
        # messages such as 001 or 372 (MOTD) go to every plugin. A first
        # parameter starting with "#" is treated as a channel.
        if channel.startswith("#"):
            plugin_pool = self.channel_plugins(channel)
        else:
            plugin_pool = self.plugins

        plugins = [p for p in plugin_pool if message.command in p.get_message_types()]
        if plugins:
            await asyncio.gather(
                *[p.on_message(conn, channel, who, line) for p in plugins]
            )

    async def keepalive(self):
        """Wait until quit() is called, then run every plugin's on_unload hook."""
        await self.__quitting.wait()
        log.info("Shutting down gracefully")
        # TODO: unload modules
        # NOTE(review): no QUIT is sent and self.connection is not closed
        # here; confirm whether that is handled elsewhere.
        await asyncio.gather(*[p.on_unload(self.connection) for p in self.plugins])