diff --git a/lib/contrib/markov/chain_server.ex b/lib/contrib/markov/chain_server.ex new file mode 100644 index 0000000..8e82751 --- /dev/null +++ b/lib/contrib/markov/chain_server.ex @@ -0,0 +1,86 @@ +defmodule Omnibot.Contrib.Markov.ChainServer do + use GenServer + alias Omnibot.Contrib.Markov + require Logger + + ## Client API + + def start_link(opts) do + {cfg, opts} = Keyword.pop(opts, :cfg) + {channel, opts} = Keyword.pop(opts, :channel) + {user, opts} = Keyword.pop(opts, :user) + + chain = case load(channel, user) do + {:ok, chain} -> chain + {:error, _} -> %Markov.Chain{order: cfg[:order]} + end + GenServer.start_link(__MODULE__, {chain, channel, user}, opts) + end + + @compile :inline + def user_path(channel, user), do: Path.join(channel_dir(channel), "#{user}.chain") + + @compile :inline + def channel_dir(channel), do: Path.join(Markov.save_dir(), channel) + + def load(channel, user) do + with {:ok, contents} <- user_path(channel, user) |> File.read(), + do: {:ok, :erlang.binary_to_term(contents)} + end + + def save(server) do + GenServer.call(server, :save) + end + + def train(server, msg) do + GenServer.call(server, {:train, msg}) + end + + def chain(server) do + GenServer.call(server, :chain) + end + + def channel(server) do + GenServer.call(server, :channel) + end + + def user(server) do + GenServer.call(server, :user) + end + + ## Server callbacks + + @impl true + def init({chain, channel, user}) do + {:ok, {chain, channel, user}} + end + + @impl true + def handle_call(:save, _from, state = {chain, channel, user}) do + File.mkdir_p!(channel_dir(channel)) + path = user_path(channel, user) + Logger.debug("Saving chain for #{user} on #{channel} to #{path}") + File.write!(path, :erlang.term_to_binary(chain)) + {:reply, :ok, state} + end + + @impl true + def handle_call({:train, msg}, _from, {chain, channel, user}) do + {:reply, :ok, {Markov.Chain.train(chain, msg), channel, user}} + end + + @impl true + def handle_call(:chain, _from, state = {chain, _channel, _user}) do + {:reply, chain, state} + end + + @impl true + def handle_call(:channel, _from, state = {_chain, channel, _user}) do + {:reply, channel, state} + end + + @impl true + def handle_call(:user, _from, state = {_chain, _channel, user}) do + {:reply, user, state} + end +end diff --git a/lib/contrib/markov/markov.ex b/lib/contrib/markov/markov.ex index 08b99da..58589da 100644 --- a/lib/contrib/markov/markov.ex +++ b/lib/contrib/markov/markov.ex @@ -1,12 +1,27 @@ defmodule Omnibot.Contrib.Markov do use Omnibot.Plugin - alias Omnibot.{Contrib.Markov.Chain, Util} + alias Omnibot.{Contrib.Markov.ChainServer, Util} require Logger - @default_config path: "markov.ets", order: 2, save_every: 5 * 60 + @default_config save_dir: "markov", order: 2, save_every: 5 * 60 + + @registry __MODULE__.Registry + @supervisor __MODULE__.ChainSupervisor + + @impl true + def children(cfg) do + [ + {Task, fn -> Stream.timer(cfg[:save_every] * 1000) + |> Stream.cycle() + |> Stream.each(fn _ -> save_chains() end) + |> Stream.run() + end}, + {Registry, keys: :unique, name: @registry}, + {DynamicSupervisor, name: @supervisor, strategy: :one_for_one}, + ] + end command "!markov", ["force"] do - # Choose a random value from the sender Irc.send_to(irc, channel, "TODO") end @@ -18,27 +33,8 @@ defmodule Omnibot.Contrib.Markov do Irc.send_to(irc, channel, "TODO") end - @impl true - def children(cfg) do - [ - {Task, fn -> - Stream.timer(cfg[:save_every] * 1000) - |> Stream.cycle() - |> Stream.each(fn _ -> save_chains() end) - |> Stream.run() - end} - ] - end - - @impl true - def on_init(_cfg) do - # Create the markov database - path = String.to_atom(cfg()[:path]) - {:ok, db} = :dets.open_file(path, []) - chains = :ets.new(:markov_chains, [:named_table, :public]) - :dets.to_ets(db, chains) - :ok = :dets.close(db) - chains + def save_dir() do + cfg()[:save_dir] end @impl true @@ -47,39 +43,57 @@ defmodule Omnibot.Contrib.Markov do end def train(channel, user, msg) do - chain = (user_chain(channel, user) || create_user_chain(channel, user)) - |> Chain.train(msg) - true = update_user_chain(channel, user, chain) + server = ensure_chain_server(channel, user) + ChainServer.train(server, msg) + end + + def ensure_chain(channel, user) do + ensure_chain_server(channel, user) + |> ChainServer.chain() end def user_chain(channel, user) do - db = state() - case :ets.lookup(db, {channel, user}) do + chain_server(channel, user) |> ChainServer.chain() + end + + def chain_server(:all) do + # See https://hexdocs.pm/elixir/Registry.html#select/2-examples to understand what the hell is going on here + # (it just selects the PID of all chain_server processes) + for {pid} <- Registry.select(@registry, [{{:_, :"$1", :_}, [], [{{:"$1"}}]}]), + do: pid + end + + def chain_server(channel, user) do + case Registry.lookup(@registry, {channel, user}) do [] -> nil - [{{^channel, ^user}, chains}] -> chains + [{pid, _} | _] -> pid end end - def update_user_chain(channel, user, chain) do - db = state() - case user_chain(channel, user) do - nil -> :ets.insert_new(db, {{channel, user}, chain}) - _old_chain -> :ets.insert(db, {{channel, user}, chain}) + def ensure_chain_server(channel, user) do + case chain_server(channel, user) do + nil -> start_chain!(channel, user) + pid -> pid end end - defp create_user_chain(channel, user) do - true = update_user_chain(channel, user, %Chain{order: cfg()[:order]}) - user_chain(channel, user) + defp start_chain!(channel, user) do + {:ok, chain} = start_chain(channel, user) + chain + end + + defp start_chain(channel, user) do + DynamicSupervisor.start_child( + @supervisor, + {ChainServer, cfg: cfg(), channel: channel, user: user, name: {:via, Registry, {@registry, {channel, user}}}} + ) end def save_chains() do start = Util.now_unix() Logger.debug("Saving markov chains") - {:ok, db} = :dets.open_file(cfg()[:path], []) - :ets.to_dets(state(), db) - :ok = :dets.close(db) + chain_server(:all) |> Enum.each(&ChainServer.save/1) stop = Util.now_unix() Logger.info("Saved markov chains in #{stop - start} seconds")