From 061cf9ee7b1de283a800a3921e0ef2ca35bee36c Mon Sep 17 00:00:00 2001 From: Alek Ratzloff Date: Mon, 30 May 2022 18:14:48 -0700 Subject: [PATCH] Add get_message_types() to plugin API This allows plugins to specify the types of messages they handle. This will be used specifically for the nickserv plugin, but could be useful for other things too. Signed-off-by: Alek Ratzloff --- omnibot/bot.py | 16 +++++++++++----- omnibot/plugin.py | 10 ++++++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/omnibot/bot.py b/omnibot/bot.py index a2baac3..675b4d3 100644 --- a/omnibot/bot.py +++ b/omnibot/bot.py @@ -61,7 +61,7 @@ class Bot: self.connection.register("JOIN", self.on_join) self.connection.register("PART", self.on_part) self.connection.register("KICK", self.on_kick) - self.connection.register("PRIVMSG", self.on_message) + self.connection.register("*", self.on_message) # Connect log.info("Connecting to %s", self.server_config.server) await self.connection.connect() @@ -135,10 +135,16 @@ class Bot: # Don't raise on_message events for ourselves. return line = message.parameters[1] - plugins = self.channel_plugins(channel) - await asyncio.gather( - *[plugin.on_message(conn, channel, who, line) for plugin in plugins] - ) + # Filter plugins by get_message_types() and channel + plugins = [ + plugin + for plugin in self.channel_plugins(channel) + if message.command in plugin.get_message_types() + ] + if plugins: + await asyncio.gather( + *[plugin.on_message(conn, channel, who, line) for plugin in plugins] + ) async def keepalive(self): # loop while we're connected, check every second diff --git a/omnibot/plugin.py b/omnibot/plugin.py index d644aa8..b515bd0 100644 --- a/omnibot/plugin.py +++ b/omnibot/plugin.py @@ -19,6 +19,16 @@ class Plugin: self.__plugin_config = plugin_config self.__bot = bot + def get_message_types(self) -> Sequence[str]: + """ + Gets the message types that this plugin listens for. + + This is usually going to be just PRIVMSG by itself, i.e. `["PRIVMSG"]`. + However, if you want to handle different messages (such as PART, JOIN, + KICK, etc) then you can change that here. + """ + return ["PRIVMSG"] + @property def channels(self) -> Sequence[str]: if "channels" in self.plugin_config: