Initial commit with functional framework(!) and example plugin
Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
This commit is contained in:
1
omnibot/__init__.py
Normal file
1
omnibot/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from . import bot
|
||||
24
omnibot/__main__.py
Normal file
24
omnibot/__main__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from .config import ServerConfig
|
||||
from .bot import Bot
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s - %(name)-12s - %(levelname)-8s - %(message)s",
|
||||
)
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def main():
|
||||
log.debug("Loading config")
|
||||
config = ServerConfig()
|
||||
config.load("config.toml")
|
||||
log.debug("Using configuration: %s", config)
|
||||
|
||||
server = Bot(config)
|
||||
await server.run()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
BIN
omnibot/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
omnibot/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
omnibot/__pycache__/__main__.cpython-310.pyc
Normal file
BIN
omnibot/__pycache__/__main__.cpython-310.pyc
Normal file
Binary file not shown.
129
omnibot/bot.py
Normal file
129
omnibot/bot.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Sequence, Set
|
||||
|
||||
from asyncirc.protocol import IrcProtocol
|
||||
from asyncirc.server import Server
|
||||
from irclib.parser import Message
|
||||
|
||||
from .config import ServerConfig
|
||||
from . import plugin
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Bot:
|
||||
def __init__(self, server_config: ServerConfig):
|
||||
self.__server_config = server_config
|
||||
self.__plugins = [
|
||||
plugin.load_plugin(server_config, config)
|
||||
for config in server_config.plugins
|
||||
]
|
||||
# TODO - this may not be needed
|
||||
self.__channels: Set[str] = set()
|
||||
|
||||
@property
|
||||
def server_config(self) -> ServerConfig:
|
||||
return self.__server_config
|
||||
|
||||
@property
|
||||
def plugins(self) -> Sequence[plugin.Plugin]:
|
||||
return self.__plugins
|
||||
|
||||
def channel_plugins(self, channel: str) -> Sequence[plugin.Plugin]:
|
||||
return [plugin for plugin in self.plugins if channel in plugin.channels]
|
||||
|
||||
async def run(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
server = Server(
|
||||
self.server_config.server,
|
||||
self.server_config.port,
|
||||
self.server_config.use_ssl,
|
||||
)
|
||||
self.connection = IrcProtocol([server], self.server_config.nick, loop=loop)
|
||||
# Register events
|
||||
# self.connection.register("*", self.on_message)
|
||||
self.connection.register("001", self.on_connect)
|
||||
self.connection.register("JOIN", self.on_join)
|
||||
self.connection.register("PART", self.on_part)
|
||||
self.connection.register("KICK", self.on_kick)
|
||||
self.connection.register("PRIVMSG", self.on_message)
|
||||
# Connect
|
||||
log.info("Connecting to %s", self.server_config.server)
|
||||
await self.connection.connect()
|
||||
# Keepalive loop
|
||||
await self.keepalive()
|
||||
|
||||
async def on_connect(self, conn: IrcProtocol, message: Message):
|
||||
# Join rooms
|
||||
for ch in self.server_config.all_channels:
|
||||
msg = Message(None, None, "JOIN", ch)
|
||||
conn.send(str(msg))
|
||||
|
||||
async def on_join(self, conn: IrcProtocol, message: Message):
|
||||
channel = message.parameters[0]
|
||||
who = message.prefix
|
||||
if who.nick == self.server_config.nick:
|
||||
self.__channels |= {channel}
|
||||
if channel not in self.server_config.all_channels:
|
||||
# Try to leave this channel that we were forced to join like some kind of dog
|
||||
msg = Message(None, None, "PART", channel)
|
||||
conn.send(str(msg))
|
||||
|
||||
# Pass the message along to available plugins
|
||||
plugins = self.channel_plugins(channel)
|
||||
await asyncio.gather(
|
||||
*[plugin.on_join(conn, channel, who) for plugin in plugins]
|
||||
)
|
||||
|
||||
async def __on_part(self, conn: IrcProtocol, message: Message):
|
||||
"This is the common logic between on_part and on_kick. Don't call this."
|
||||
channel = message.parameters[0]
|
||||
who = message.prefix
|
||||
if who.nick == self.server_config.nick:
|
||||
self.__channels -= {channel}
|
||||
if channel not in self.server_config.all_channels:
|
||||
# Try to rejoin this channel that we were force-parted from
|
||||
msg = Message(None, None, "JOIN", channel)
|
||||
conn.send(str(msg))
|
||||
|
||||
async def on_part(self, conn: IrcProtocol, message: Message):
|
||||
await self.__on_part(conn, message)
|
||||
# Pass the message along to available plugins
|
||||
channel = message.parameters[0]
|
||||
who = message.prefix
|
||||
plugins = self.channel_plugins(channel)
|
||||
await asyncio.gather(
|
||||
*[plugin.on_part(conn, channel, who) for plugin in plugins]
|
||||
)
|
||||
|
||||
async def on_kick(self, conn: IrcProtocol, message: Message):
|
||||
await self.__on_part(conn, message)
|
||||
# Pass the message along to available plugins
|
||||
channel = message.parameters[0]
|
||||
who = message.prefix
|
||||
plugins = self.channel_plugins(channel)
|
||||
await asyncio.gather(
|
||||
*[plugin.on_kick(conn, channel, who) for plugin in plugins]
|
||||
)
|
||||
|
||||
async def on_message(self, conn: IrcProtocol, message: Message):
|
||||
# Pass the message to the plugins
|
||||
log.debug("%s", message)
|
||||
channel = message.parameters[0]
|
||||
who = message.prefix
|
||||
if who.nick == self.server_config.nick:
|
||||
# Don't raise on_message events for ourselves.
|
||||
return
|
||||
line = message.parameters[1]
|
||||
plugins = self.channel_plugins(channel)
|
||||
await asyncio.gather(
|
||||
*[plugin.on_message(conn, channel, who, line) for plugin in plugins]
|
||||
)
|
||||
|
||||
async def keepalive(self):
|
||||
# loop while we're connected, check every second
|
||||
log.info("Starting keepalive loop")
|
||||
while self.connection.connected:
|
||||
await asyncio.sleep(1.0)
|
||||
85
omnibot/config.py
Normal file
85
omnibot/config.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import dataclasses
|
||||
from pathlib import Path
|
||||
from typing import Any, Mapping, Sequence, Set
|
||||
|
||||
import toml
|
||||
|
||||
|
||||
PluginConfig = Mapping[str, Any]
|
||||
|
||||
|
||||
class ConfigError(Exception):
|
||||
def __init__(self, which: str, hint: str | None, plugin: str | None = None):
|
||||
self.which = which
|
||||
self.hint = hint
|
||||
self.plugin = plugin
|
||||
msg = f"{self.which}"
|
||||
if self.hint:
|
||||
msg = f"{msg}: {self.hint}"
|
||||
if self.plugin:
|
||||
msg = f"in config for plugin {plugin}: {msg}"
|
||||
else:
|
||||
msg = f"in server config: {msg}"
|
||||
super(ConfigError, self).__init__(msg)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ServerConfig:
|
||||
server: str = ""
|
||||
use_ssl: bool = False
|
||||
port: int = 6667
|
||||
plugins: Sequence[PluginConfig] = dataclasses.field(default_factory=list)
|
||||
channels: Sequence[str] = dataclasses.field(default_factory=list)
|
||||
nick: str = "omnibot"
|
||||
|
||||
def load(self, path: Path | str):
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
with open(path) as fp:
|
||||
obj = toml.load(fp)
|
||||
|
||||
if "server" not in obj:
|
||||
raise ConfigError("server", "must be present")
|
||||
if not isinstance(obj["server"], str):
|
||||
raise ConfigError("server", "must be a string")
|
||||
self.server = obj["server"]
|
||||
|
||||
if "use_ssl" in obj:
|
||||
if not isinstance(obj["use_ssl"], bool):
|
||||
raise ConfigError("use_ssl", "must be a boolean")
|
||||
self.use_ssl = obj["use_ssl"]
|
||||
else:
|
||||
# Don't use SSL by default
|
||||
self.use_ssl = False
|
||||
|
||||
if "port" in obj:
|
||||
if not isinstance(obj["port"], int):
|
||||
raise ConfigError("port", "must be an integer")
|
||||
if not (0 < obj["port"] <= 65535):
|
||||
raise ConfigError("port", "must be between 0 and 65535")
|
||||
self.port = obj["port"]
|
||||
else:
|
||||
if self.use_ssl:
|
||||
self.port = 6697
|
||||
else:
|
||||
self.port = 6667
|
||||
|
||||
if "plugins" in obj:
|
||||
if not isinstance(obj["plugins"], Sequence):
|
||||
raise ConfigError(
|
||||
"plugins", "must be a mapping of configuration values"
|
||||
)
|
||||
self.plugins = obj["plugins"]
|
||||
|
||||
if "nick" in obj:
|
||||
if not isinstance(obj["nick"], str):
|
||||
raise ConfigError("nick", "must be a string")
|
||||
self.nick = obj["nick"]
|
||||
|
||||
@property
|
||||
def all_channels(self) -> Set[str]:
|
||||
channels = set(self.channels)
|
||||
for plugin in self.plugins:
|
||||
if "channels" in plugin:
|
||||
channels |= set(plugin["channels"])
|
||||
return channels
|
||||
60
omnibot/plugin.py
Normal file
60
omnibot/plugin.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import importlib
|
||||
import logging
|
||||
from typing import Sequence
|
||||
|
||||
from asyncirc.protocol import IrcProtocol
|
||||
from irclib.parser import Message, Prefix
|
||||
|
||||
from .config import PluginConfig, ServerConfig
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Plugin:
|
||||
def __init__(self, server_config: ServerConfig, plugin_config: PluginConfig):
|
||||
self.__server_config = server_config
|
||||
self.__plugin_config = plugin_config
|
||||
|
||||
@property
|
||||
def channels(self) -> Sequence[str]:
|
||||
if "channels" in self.plugin_config:
|
||||
return self.plugin_config["channels"]
|
||||
else:
|
||||
return self.server_config.channels
|
||||
|
||||
@property
|
||||
def plugin_config(self) -> PluginConfig:
|
||||
return self.__plugin_config
|
||||
|
||||
@property
|
||||
def server_config(self) -> ServerConfig:
|
||||
return self.__server_config
|
||||
|
||||
@property
|
||||
def nick(self) -> str:
|
||||
return self.server_config.nick
|
||||
|
||||
def send_to(self, conn: IrcProtocol, who: str, message: str):
|
||||
message = Message(None, None, "PRIVMSG", who, message)
|
||||
conn.send(str(message))
|
||||
|
||||
async def on_join(self, conn: IrcProtocol, channel: str, who: Prefix):
|
||||
pass
|
||||
|
||||
async def on_part(self, conn: IrcProtocol, channel: str, who: Prefix):
|
||||
pass
|
||||
|
||||
async def on_kick(self, conn: IrcProtocol, channel: str, who: Prefix):
|
||||
pass
|
||||
|
||||
async def on_message(self, conn: IrcProtocol, channel: str, who: Prefix, line: str):
|
||||
pass
|
||||
|
||||
|
||||
def load_plugin(server_config: ServerConfig, plugin_config: PluginConfig) -> Plugin:
|
||||
name = plugin_config["module"]
|
||||
log.info("Loading plugin %s", name)
|
||||
plugin_module = importlib.import_module(name)
|
||||
PluginType = plugin_module.PLUGIN_TYPE
|
||||
return PluginType(server_config, plugin_config)
|
||||
0
omnibot/server.py
Normal file
0
omnibot/server.py
Normal file
Reference in New Issue
Block a user