86 lines
2.6 KiB
Python
86 lines
2.6 KiB
Python
|
|
import dataclasses
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Any, Mapping, Sequence, Set
|
||
|
|
|
||
|
|
import toml
|
||
|
|
|
||
|
|
|
||
|
|
# Configuration for a single plugin: a mapping from setting names to values.
# ServerConfig.plugins holds a sequence of these, and ServerConfig.all_channels
# reads an optional "channels" key from each one.
PluginConfig = Mapping[str, Any]
|
||
|
|
|
||
|
|
|
||
|
|
class ConfigError(Exception):
|
||
|
|
def __init__(self, which: str, hint: str | None, plugin: str | None = None):
|
||
|
|
self.which = which
|
||
|
|
self.hint = hint
|
||
|
|
self.plugin = plugin
|
||
|
|
msg = f"{self.which}"
|
||
|
|
if self.hint:
|
||
|
|
msg = f"{msg}: {self.hint}"
|
||
|
|
if self.plugin:
|
||
|
|
msg = f"in config for plugin {plugin}: {msg}"
|
||
|
|
else:
|
||
|
|
msg = f"in server config: {msg}"
|
||
|
|
super(ConfigError, self).__init__(msg)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclasses.dataclass
class ServerConfig:
    """Settings for connecting to one server, plus per-plugin configuration.

    Fields start at the defaults below. :meth:`load` (from a TOML file) or
    :meth:`load_mapping` (from an already-parsed mapping) overwrites them.
    """

    server: str = ""
    use_ssl: bool = False
    port: int = 6667
    plugins: Sequence[PluginConfig] = dataclasses.field(default_factory=list)
    channels: Sequence[str] = dataclasses.field(default_factory=list)
    nick: str = "omnibot"

    def load(self, path: Path | str) -> None:
        """Parse the TOML file at *path* and apply its settings.

        Args:
            path: Location of the TOML config file.

        Raises:
            OSError: If the file cannot be opened.
            ConfigError: If a setting is missing or invalid (see
                :meth:`load_mapping`).

        Errors raised by ``toml.load`` for malformed TOML propagate unchanged.
        """
        # TOML documents are UTF-8 by specification; don't depend on the locale.
        with open(Path(path), encoding="utf-8") as fp:
            obj = toml.load(fp)
        self.load_mapping(obj)

    def load_mapping(self, obj: Mapping[str, Any]) -> None:
        """Validate a parsed config mapping and apply it to this instance.

        ``server`` is required. ``use_ssl`` and ``port`` are reset to their
        defaults when absent; the default port depends on ``use_ssl``.
        ``plugins``, ``channels`` and ``nick`` keep their current values when
        absent. Fields are assigned as they are validated, so a failure part
        way through leaves earlier fields already updated.

        Args:
            obj: Top-level config table, e.g. the result of ``toml.load``.

        Raises:
            ConfigError: If a setting is missing or has an invalid type/value.
        """
        if "server" not in obj:
            raise ConfigError("server", "must be present")
        if not isinstance(obj["server"], str):
            raise ConfigError("server", "must be a string")
        self.server = obj["server"]

        # SSL is off unless explicitly enabled.
        use_ssl = obj.get("use_ssl", False)
        if not isinstance(use_ssl, bool):
            raise ConfigError("use_ssl", "must be a boolean")
        self.use_ssl = use_ssl

        if "port" in obj:
            port = obj["port"]
            # bool is a subclass of int, so `port = true` must be rejected explicitly.
            if isinstance(port, bool) or not isinstance(port, int):
                raise ConfigError("port", "must be an integer")
            if not 0 < port <= 65535:
                raise ConfigError("port", "must be between 1 and 65535")
            self.port = port
        else:
            # Relies on use_ssl having been resolved above.
            self.port = 6697 if self.use_ssl else 6667

        if "plugins" in obj:
            plugins = obj["plugins"]
            if not self._is_list(plugins):
                raise ConfigError(
                    "plugins", "must be a list of mappings of configuration values"
                )
            # all_channels (and presumably plugin loaders) index each entry by key.
            if not all(isinstance(plugin, Mapping) for plugin in plugins):
                raise ConfigError(
                    "plugins", "each entry must be a mapping of configuration values"
                )
            self.plugins = plugins

        if "channels" in obj:
            channels = obj["channels"]
            if not self._is_list(channels) or not all(
                isinstance(channel, str) for channel in channels
            ):
                raise ConfigError("channels", "must be a list of strings")
            self.channels = channels

        if "nick" in obj:
            if not isinstance(obj["nick"], str):
                raise ConfigError("nick", "must be a string")
            self.nick = obj["nick"]

    @staticmethod
    def _is_list(value: Any) -> bool:
        """Return True for a non-string sequence (such as a TOML array)."""
        # str/bytes are Sequences too, but a bare string is never a valid list here.
        return isinstance(value, Sequence) and not isinstance(value, (str, bytes))

    @property
    def all_channels(self) -> Set[str]:
        """Return the union of ``channels`` and every plugin's ``channels`` entry.

        Plugins without a ``channels`` key contribute nothing. Assumes each
        plugin's ``channels`` value is an iterable of channel names.
        """
        channels = set(self.channels)
        for plugin in self.plugins:
            if "channels" in plugin:
                channels.update(plugin["channels"])
        return channels
|