Files
omnibot22/omnibot/bot.py
Alek Ratzloff f0cfe53c8e Add on_load for plugins
This asynchronous function is called on all plugins right before the IRC
connection is made.

Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
2022-05-26 19:06:48 -07:00

142 lines
5.1 KiB
Python

import asyncio
import logging
from typing import Sequence, Set
from asyncirc.protocol import IrcProtocol
from asyncirc.server import Server
from irclib.parser import Message
from .config import ServerConfig
from . import plugin
log = logging.getLogger(__name__)
class Bot:
    """An IRC bot for one server that forwards IRC events to its plugins.

    The bot loads the plugins listed in the server configuration, connects
    to the configured server, joins the configured channels, and dispatches
    JOIN / PART / KICK / PRIVMSG events to every plugin that lists the
    affected channel in its ``channels``.
    """

    def __init__(self, server_config: ServerConfig):
        """Store the configuration and load every configured plugin.

        Args:
            server_config: Connection settings, nick, channels and plugin
                configurations for the server this bot talks to.
        """
        self.__server_config = server_config
        self.__plugins = [
            plugin.load_plugin(server_config, config)
            for config in server_config.plugins
        ]
        # TODO - this may not be needed
        # Channels the bot currently believes it is in.
        self.__channels: Set[str] = set()
        # Set by quit(); keepalive() blocks on it until shutdown is requested.
        self.__quitting = asyncio.Event()

    @property
    def server_config(self) -> ServerConfig:
        """The server configuration this bot was created with."""
        return self.__server_config

    @property
    def plugins(self) -> Sequence[plugin.Plugin]:
        """All plugins loaded for this bot."""
        return self.__plugins

    def quit(self):
        """Request shutdown by waking up keepalive()."""
        self.__quitting.set()

    def channel_plugins(self, channel: str) -> Sequence[plugin.Plugin]:
        """Return the plugins whose ``channels`` contain ``channel``."""
        # The loop variable is deliberately not called "plugin" so that it
        # does not shadow the imported plugin module.
        return [p for p in self.plugins if channel in p.channels]

    async def run(self):
        """Initialize plugins, connect to the server and wait for quit().

        Every plugin's ``on_load`` is awaited before the IRC connection is
        created, so plugins are initialized before any event can arrive.
        """
        loop = asyncio.get_running_loop()
        server = Server(
            self.server_config.server,
            self.server_config.port,
            self.server_config.use_ssl,
        )

        log.info("Initializing plugins")
        await asyncio.gather(*[p.on_load() for p in self.plugins])

        self.connection = IrcProtocol([server], self.server_config.nick, loop=loop)

        # Register events
        # self.connection.register("*", self.on_message)
        self.connection.register("001", self.on_connect)
        self.connection.register("JOIN", self.on_join)
        self.connection.register("PART", self.on_part)
        self.connection.register("KICK", self.on_kick)
        self.connection.register("PRIVMSG", self.on_message)

        # Connect
        log.info("Connecting to %s", self.server_config.server)
        await self.connection.connect()

        # Block here until quit() is called, then unload the plugins.
        await self.keepalive()

    async def on_connect(self, conn: IrcProtocol, message: Message):
        """Join every configured channel once numeric 001 is received."""
        for ch in self.server_config.all_channels:
            msg = Message(None, None, "JOIN", ch)
            conn.send(str(msg))

    async def on_join(self, conn: IrcProtocol, message: Message):
        """Handle a JOIN: track our own joins and notify channel plugins.

        If the bot itself joined a channel that is not in the configuration,
        it immediately sends a PART for that channel.
        """
        log.debug("%s", message)
        channel = message.parameters[0]
        who = message.prefix
        if who.nick == self.server_config.nick:
            self.__channels.add(channel)
            if channel not in self.server_config.all_channels:
                # Try to leave this channel that we were forced to join like some kind of dog
                msg = Message(None, None, "PART", channel)
                conn.send(str(msg))

        # Pass the message along to available plugins
        plugins = self.channel_plugins(channel)
        await asyncio.gather(*[p.on_join(conn, channel, who) for p in plugins])

    async def __on_part(self, conn: IrcProtocol, message: Message, nick: str = ""):
        """This is the common logic between on_part and on_kick. Don't call this.

        Args:
            conn: Connection used to send the rejoin request.
            message: The PART or KICK message; ``parameters[0]`` is the channel.
            nick: Nick of the user who left the channel. When empty, the nick
                of the message prefix is used, which is correct for PART. For
                KICK the prefix is the kicker, so on_kick passes the kicked
                nick explicitly.
        """
        channel = message.parameters[0]
        departed = nick or message.prefix.nick
        if departed != self.server_config.nick:
            return

        self.__channels.discard(channel)
        # Only rejoin channels we are configured to be in. Rejoining an
        # unconfigured channel would make on_join part it again, and the two
        # handlers would then chase each other in an endless JOIN/PART loop.
        if channel in self.server_config.all_channels:
            # Try to rejoin this channel that we were force-parted from
            msg = Message(None, None, "JOIN", channel)
            conn.send(str(msg))

    async def on_part(self, conn: IrcProtocol, message: Message):
        """Handle a PART: update our channel state and notify channel plugins."""
        log.debug("%s", message)
        await self.__on_part(conn, message)

        # Pass the message along to available plugins
        channel = message.parameters[0]
        who = message.prefix
        plugins = self.channel_plugins(channel)
        await asyncio.gather(*[p.on_part(conn, channel, who) for p in plugins])

    async def on_kick(self, conn: IrcProtocol, message: Message):
        """Handle a KICK: update our channel state and notify channel plugins.

        Plugins receive the message prefix as ``who``, exactly as before; for
        a KICK that is the user who issued the kick.
        """
        log.debug("%s", message)
        # In a KICK the removed user is the second parameter; the prefix is
        # whoever issued the kick.
        params = message.parameters
        kicked = params[1] if len(params) > 1 else ""
        if kicked:
            await self.__on_part(conn, message, kicked)

        # Pass the message along to available plugins
        channel = params[0]
        who = message.prefix
        plugins = self.channel_plugins(channel)
        await asyncio.gather(*[p.on_kick(conn, channel, who) for p in plugins])

    async def on_message(self, conn: IrcProtocol, message: Message):
        """Forward a PRIVMSG to channel plugins, ignoring our own messages."""
        log.debug("%s", message)
        channel = message.parameters[0]
        who = message.prefix
        if who.nick == self.server_config.nick:
            # Don't raise on_message events for ourselves.
            return

        line = message.parameters[1]
        plugins = self.channel_plugins(channel)
        await asyncio.gather(
            *[p.on_message(conn, channel, who, line) for p in plugins]
        )

    async def keepalive(self):
        """Wait until quit() is called, then unload every plugin."""
        # No polling: this simply sleeps until the quitting event is set.
        await self.__quitting.wait()
        log.info("Shutting down gracefully")
        await asyncio.gather(*[p.on_unload(self.connection) for p in self.plugins])