130 lines
4.7 KiB
Python
130 lines
4.7 KiB
Python
|
|
import asyncio
|
||
|
|
import logging
|
||
|
|
from typing import Sequence, Set
|
||
|
|
|
||
|
|
from asyncirc.protocol import IrcProtocol
|
||
|
|
from asyncirc.server import Server
|
||
|
|
from irclib.parser import Message
|
||
|
|
|
||
|
|
from .config import ServerConfig
|
||
|
|
from . import plugin
|
||
|
|
|
||
|
|
|
||
|
|
# Logger shared by every part of this module, named after the module itself.
log = logging.getLogger(name=__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class Bot:
    """An IRC bot that connects to one server and dispatches events to plugins.

    Plugins are loaded from ``server_config.plugins`` when the bot is created.
    Each event handler forwards the event only to plugins whose ``channels``
    contain the event's channel (see ``channel_plugins``).
    """

    def __init__(self, server_config: ServerConfig):
        self.__server_config = server_config
        self.__plugins = [
            plugin.load_plugin(server_config, config)
            for config in server_config.plugins
        ]
        # Channels the bot itself is currently in, maintained from the bot's
        # own JOIN / PART / KICK events.
        # TODO - this may not be needed
        self.__channels: Set[str] = set()

    @property
    def server_config(self) -> ServerConfig:
        """The server configuration this bot was constructed with."""
        return self.__server_config

    @property
    def plugins(self) -> Sequence[plugin.Plugin]:
        """All plugins loaded for this server, in configuration order."""
        return self.__plugins

    def channel_plugins(self, channel: str) -> Sequence[plugin.Plugin]:
        """Return the loaded plugins whose ``channels`` include *channel*."""
        # Use ``p`` so the ``plugin`` module name is not shadowed.
        return [p for p in self.plugins if channel in p.channels]

    async def run(self):
        """Connect to the configured server and wait until the connection drops."""
        loop = asyncio.get_running_loop()
        server = Server(
            self.server_config.server,
            self.server_config.port,
            self.server_config.use_ssl,
        )
        self.connection = IrcProtocol([server], self.server_config.nick, loop=loop)

        # Register events
        self.connection.register("001", self.on_connect)
        self.connection.register("JOIN", self.on_join)
        self.connection.register("PART", self.on_part)
        self.connection.register("KICK", self.on_kick)
        self.connection.register("PRIVMSG", self.on_message)

        # Connect
        log.info("Connecting to %s", self.server_config.server)
        await self.connection.connect()

        # Keepalive loop
        await self.keepalive()

    async def on_connect(self, conn: IrcProtocol, message: Message):
        """Join every configured channel once the server sends ``001``."""
        for ch in self.server_config.all_channels:
            conn.send(str(Message(None, None, "JOIN", ch)))

    async def on_join(self, conn: IrcProtocol, message: Message):
        """Handle a JOIN: track our own joins and notify the channel's plugins.

        If the bot itself ends up in a channel that is not in its
        configuration, it immediately sends a PART for that channel.
        """
        channel = message.parameters[0]
        who = message.prefix
        if who.nick == self.server_config.nick:
            self.__channels.add(channel)
            if channel not in self.server_config.all_channels:
                # Try to leave this channel that we were forced to join.
                conn.send(str(Message(None, None, "PART", channel)))

        # Pass the message along to available plugins
        plugins = self.channel_plugins(channel)
        await asyncio.gather(*(p.on_join(conn, channel, who) for p in plugins))

    async def __on_part(self, conn: IrcProtocol, channel: str, nick: str):
        """Common logic for on_part and on_kick. Don't call this directly.

        :param channel: the channel the user was removed from.
        :param nick: the nick of the user who was removed (for KICK this is
            the kicked user, not the kicker).
        """
        if nick != self.server_config.nick:
            return
        self.__channels.discard(channel)
        # Only rejoin channels we are configured to be in. Rejoining
        # unconfigured channels (the old ``not in`` check) fought with
        # on_join, which deliberately PARTs them, producing an endless
        # JOIN/PART loop, and never rejoined the channels we actually want.
        if channel in self.server_config.all_channels:
            conn.send(str(Message(None, None, "JOIN", channel)))

    async def on_part(self, conn: IrcProtocol, message: Message):
        """Handle a PART: rejoin if we were removed, then notify plugins."""
        channel = message.parameters[0]
        who = message.prefix
        # For PART the message prefix is the user who left.
        await self.__on_part(conn, channel, who.nick)

        # Pass the message along to available plugins
        plugins = self.channel_plugins(channel)
        await asyncio.gather(*(p.on_part(conn, channel, who) for p in plugins))

    async def on_kick(self, conn: IrcProtocol, message: Message):
        """Handle a KICK (``KICK <channel> <user> [<comment>]``).

        Rejoins if the bot was the kicked user, then notifies plugins.
        """
        channel = message.parameters[0]
        who = message.prefix
        # For KICK the prefix is the *kicker*; the removed user is the second
        # parameter. Checking the prefix meant the bot never noticed it had
        # been kicked.
        kicked_nick = message.parameters[1]
        await self.__on_part(conn, channel, kicked_nick)

        # Pass the message along to available plugins (plugins still receive
        # the message prefix, as before).
        plugins = self.channel_plugins(channel)
        await asyncio.gather(*(p.on_kick(conn, channel, who) for p in plugins))

    async def on_message(self, conn: IrcProtocol, message: Message):
        """Forward a PRIVMSG to the plugins of its target, ignoring our own lines."""
        log.debug("%s", message)
        channel = message.parameters[0]
        who = message.prefix
        if who.nick == self.server_config.nick:
            # Don't raise on_message events for ourselves.
            return
        line = message.parameters[1]
        plugins = self.channel_plugins(channel)
        await asyncio.gather(
            *(p.on_message(conn, channel, who, line) for p in plugins)
        )

    async def keepalive(self):
        """Poll once per second until ``self.connection`` reports disconnected."""
        log.info("Starting keepalive loop")
        while self.connection.connected:
            await asyncio.sleep(1.0)
|