import dataclasses
from pathlib import Path
from typing import IO, Any, Mapping, Sequence, Set

import toml

PluginConfig = Mapping[str, Any]

# Values accepted for the ``loglevel`` setting.
_LOGLEVELS = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")


class ConfigError(Exception):
    """Raised when a configuration value is missing or invalid.

    Attributes:
        which: Name of the offending configuration key.
        hint: Optional explanation of what was expected; appended to the
            message when truthy.
        plugin: Plugin whose config is at fault, or ``None`` when the error
            is in the server config itself.
    """

    def __init__(self, which: str, hint: str | None, plugin: str | None = None):
        self.which = which
        self.hint = hint
        self.plugin = plugin
        msg = f"{self.which}"
        if self.hint:
            msg = f"{msg}: {self.hint}"
        if self.plugin:
            msg = f"in config for plugin {plugin}: {msg}"
        else:
            msg = f"in server config: {msg}"
        super().__init__(msg)


def _is_list(value: Any) -> bool:
    """Return True if *value* is a sequence other than a string.

    ``str`` and ``bytes`` are themselves ``Sequence``s, so a bare
    ``isinstance(value, Sequence)`` check would accept ``"#chan"`` as a list.
    """
    return isinstance(value, Sequence) and not isinstance(value, (str, bytes))


def _is_str_list(value: Any) -> bool:
    """Return True if *value* is a non-string sequence of strings."""
    return _is_list(value) and all(isinstance(item, str) for item in value)


@dataclasses.dataclass
class ServerConfig:
    """Server connection settings, populated from a TOML file by ``load``."""

    server: str = ""
    use_ssl: bool = False
    port: int = 6667
    # One mapping of configuration values per plugin.
    plugins: Sequence[PluginConfig] = dataclasses.field(default_factory=list)
    channels: Sequence[str] = dataclasses.field(default_factory=list)
    nick: str = "omnibot"
    # One of _LOGLEVELS, or None when not configured.
    loglevel: str | None = None

    def load(self, fp: IO[str]) -> None:
        """Populate this config from the TOML document read from *fp*.

        ``server`` is required. ``use_ssl`` and ``port`` are always assigned.
        When ``use_ssl`` is absent it becomes ``False``. When ``port`` is
        absent it becomes 6697 if SSL is enabled and 6667 otherwise.
        ``plugins``, ``channels``, ``nick`` and ``loglevel`` are only
        overwritten when present in the document.

        Raises:
            ConfigError: if a required key is missing or a value has the
                wrong type or is out of range.
        """
        obj = toml.load(fp)

        if "server" not in obj:
            raise ConfigError("server", "must be present")
        if not isinstance(obj["server"], str):
            raise ConfigError("server", "must be a string")
        self.server = obj["server"]

        if "use_ssl" in obj:
            if not isinstance(obj["use_ssl"], bool):
                raise ConfigError("use_ssl", "must be a boolean")
            self.use_ssl = obj["use_ssl"]
        else:
            # Don't use SSL by default
            self.use_ssl = False

        if "port" in obj:
            port = obj["port"]
            # bool is a subclass of int; reject it so `port = true` does not
            # silently become port 1.
            if isinstance(port, bool) or not isinstance(port, int):
                raise ConfigError("port", "must be an integer")
            if not 0 < port <= 65535:
                raise ConfigError("port", "must be between 1 and 65535")
            self.port = port
        else:
            # Default port depends on the transport chosen above.
            self.port = 6697 if self.use_ssl else 6667

        if "plugins" in obj:
            plugins = obj["plugins"]
            # all_channels indexes each entry as a mapping, so check every item.
            if not _is_list(plugins) or not all(
                isinstance(plugin, Mapping) for plugin in plugins
            ):
                raise ConfigError(
                    "plugins", "must be a list of mappings of configuration values"
                )
            self.plugins = plugins

        if "channels" in obj:
            if not _is_str_list(obj["channels"]):
                raise ConfigError("channels", "must be a list of strings")
            self.channels = obj["channels"]

        if "nick" in obj:
            if not isinstance(obj["nick"], str):
                raise ConfigError("nick", "must be a string")
            self.nick = obj["nick"]

        if "loglevel" in obj:
            if not isinstance(obj["loglevel"], str):
                raise ConfigError("loglevel", "must be a string")
            if obj["loglevel"] not in _LOGLEVELS:
                raise ConfigError(
                    "loglevel", "must be one of: " + " ".join(_LOGLEVELS)
                )
            self.loglevel = obj["loglevel"]

    @property
    def all_channels(self) -> Set[str]:
        """Return the server's channels plus every plugin's ``channels`` entry."""
        channels = set(self.channels)
        for plugin in self.plugins:
            if "channels" in plugin:
                channels |= set(plugin["channels"])
        return channels