diff --git a/lib/contrib/markov/chain_server.ex b/lib/contrib/markov/chain_server.ex index 8e82751..aa4ec40 100644 --- a/lib/contrib/markov/chain_server.ex +++ b/lib/contrib/markov/chain_server.ex @@ -6,15 +6,10 @@ defmodule Omnibot.Contrib.Markov.ChainServer do ## Client API def start_link(opts) do - {cfg, opts} = Keyword.pop(opts, :cfg) {channel, opts} = Keyword.pop(opts, :channel) {user, opts} = Keyword.pop(opts, :user) - chain = case load(channel, user) do - {:ok, chain} -> chain - {:error, _} -> %Markov.Chain{order: cfg[:order]} - end - GenServer.start_link(__MODULE__, {chain, channel, user}, opts) + GenServer.start_link(__MODULE__, {channel, user}, opts) end @compile :inline @@ -23,7 +18,7 @@ defmodule Omnibot.Contrib.Markov.ChainServer do @compile :inline def channel_dir(channel), do: Path.join(Markov.save_dir(), channel) - def load(channel, user) do + def load(channel, user) when user != :all do with {:ok, contents} <- user_path(channel, user) |> File.read(), do: {:ok, :erlang.binary_to_term(contents)} end @@ -48,19 +43,49 @@ defmodule Omnibot.Contrib.Markov.ChainServer do GenServer.call(server, :user) end + def generate(server) do + GenServer.call(server, :generate) + end + ## Server callbacks + + @impl true + def init({channel, :all}) do + Logger.debug("Creating allchain for channel #{channel}") + + chain = File.ls!(channel_dir(channel)) + |> Enum.map(&(Path.join(channel_dir(channel), &1) |> Markov.Chain.load!())) + |> Markov.Chain.merge() + {:ok, {chain, channel, :all}} + # TODO: load allchain + #chain = case load(channel, user) do + #{:ok, chain} -> chain + #{:error, _} -> %Markov.Chain{order: cfg()[:order]} + #end + #{:ok, {chain, channel, user}} + end @impl true - def init({chain, channel, user}) do + def init({channel, user}) do + chain = case load(channel, user) do + {:ok, chain} -> chain + {:error, _} -> %Markov.Chain{order: Markov.cfg()[:order]} + end {:ok, {chain, channel, user}} end + @impl true + def handle_call(:save, _from, state = {_chain, channel, :all}) do + Logger.debug("Not saving :all chain for #{channel}") + {:reply, :ok, state} + end + @impl true def handle_call(:save, _from, state = {chain, channel, user}) do File.mkdir_p!(channel_dir(channel)) path = user_path(channel, user) Logger.debug("Saving chain for #{user} on #{channel} to #{path}") - File.write!(path, :erlang.term_to_binary(chain)) + :ok = Markov.Chain.save!(chain, path) {:reply, :ok, state} end @@ -83,4 +108,9 @@ defmodule Omnibot.Contrib.Markov.ChainServer do def handle_call(:user, _from, state = {_chain, _channel, user}) do {:reply, user, state} end + + @impl true + def handle_call(:generate, _from, state = {chain, _channel, _user}) do + {:reply, Markov.Chain.generate(chain), state} + end end diff --git a/lib/contrib/markov/markov.ex b/lib/contrib/markov/markov.ex index 58589da..a417121 100644 --- a/lib/contrib/markov/markov.ex +++ b/lib/contrib/markov/markov.ex @@ -3,7 +3,7 @@ defmodule Omnibot.Contrib.Markov do alias Omnibot.{Contrib.Markov.ChainServer, Util} require Logger - @default_config save_dir: "markov", order: 2, save_every: 5 * 60 + @default_config save_dir: "markov", order: 2, save_every: 5 * 60, ignore: [] @registry __MODULE__.Registry @supervisor __MODULE__.ChainSupervisor @@ -22,11 +22,18 @@ defmodule Omnibot.Contrib.Markov do end command "!markov", ["force"] do - Irc.send_to(irc, channel, "TODO") + reply = chain_server(channel, nick) |> ChainServer.generate() + Irc.send_to(irc, channel, "#{nick}: #{reply}") + end + + command "!markov", ["emulate", emulate] do + reply = chain_server(channel, emulate) |> ChainServer.generate() + Irc.send_to(irc, channel, "#{nick}: #{reply}") end command "!markov", ["all"] do - Irc.send_to(irc, channel, "TODO") + reply = chain_server(channel, :all) |> ChainServer.generate() + Irc.send_to(irc, channel, "#{nick}: #{reply}") end command "!markov", ["status"] do @@ -37,26 +44,12 @@ defmodule Omnibot.Contrib.Markov do cfg()[:save_dir] end - @impl true - def on_channel_msg(_irc, channel, nick, msg) do - train(channel, nick, msg) - end - def train(channel, user, msg) do - server = ensure_chain_server(channel, user) + server = chain_server(channel, user) ChainServer.train(server, msg) end - def ensure_chain(channel, user) do - ensure_chain_server(channel, user) - |> ChainServer.chain() - end - - def user_chain(channel, user) do - chain_server(channel, user) |> ChainServer.chain() - end - - def chain_server(:all) do + def chain_servers() do # See https://hexdocs.pm/elixir/Registry.html#select/2-examples to understand what the hell is going on here # (it just selects the PID of all chain_server processes) for {pid} <- Registry.select(@registry, [{{:_, :"$1", :_}, [], [{{:"$1"}}]}]), @@ -65,27 +58,20 @@ defmodule Omnibot.Contrib.Markov do def chain_server(channel, user) do case Registry.lookup(@registry, {channel, user}) do - [] -> nil + [] -> start_chain_server!(channel, user) [{pid, _} | _] -> pid end end - def ensure_chain_server(channel, user) do - case chain_server(channel, user) do - nil -> start_chain!(channel, user) - pid -> pid - end - end - - defp start_chain!(channel, user) do - {:ok, chain} = start_chain(channel, user) + defp start_chain_server!(channel, user) do + {:ok, chain} = start_chain_server(channel, user) chain end - defp start_chain(channel, user) do + defp start_chain_server(channel, user) do DynamicSupervisor.start_child( @supervisor, - {ChainServer, cfg: cfg(), channel: channel, user: user, name: {:via, Registry, {@registry, {channel, user}}}} + {ChainServer, channel: channel, user: user, name: {:via, Registry, {@registry, {channel, user}}}} ) end @@ -93,9 +79,19 @@ defmodule Omnibot.Contrib.Markov do start = Util.now_unix() Logger.debug("Saving markov chains") - chain_server(:all) |> Enum.each(&ChainServer.save/1) + chain_servers() |> Enum.each(&ChainServer.save/1) stop = Util.now_unix() Logger.info("Saved markov chains in #{stop - start} seconds") end + + @impl true + def on_channel_msg(_irc, channel, nick, msg) do + # self-messages are already ignored, so just check the configured ignore-list + filter = nick in cfg()[:ignore] + || (String.trim(msg) |> String.starts_with?("!")) + if !filter do + train(channel, nick, msg) + end + end end