diff --git a/lib/contrib/markov/chain.ex b/lib/contrib/markov/chain.ex index 2e95a8a..95776b3 100644 --- a/lib/contrib/markov/chain.ex +++ b/lib/contrib/markov/chain.ex @@ -4,6 +4,10 @@ defmodule Omnibot.Contrib.Markov.Chain do @enforce_keys [:order] defstruct order: 2, chain: [] + def train(chain, line) when is_binary(line) do + train(chain, line |> String.split(~r/\s+/)) + end + def train(chain, words) when is_list(words) do order = chain.order @@ -12,22 +16,12 @@ defmodule Omnibot.Contrib.Markov.Chain do |> Enum.reduce(chain, &case Enum.split(&1, order) do {words, []} -> if length(words) == order, # Null case for the chain; this is an "end" state - do: add_weight(&2, words, nil) - # else: TODO ? train [a, nil] -> b ? - {words, [next]} -> - add_weight(&2, words, next) - end - ) + do: add_weight(&2, words, nil), + else: &2 # TODO ? train [a, nil] -> b ? + {words, [next]} -> add_weight(&2, words, next) + end) end - #def lookup(%Chain {chain: chain, order: order}, key) do - # if length(key) != order, do: raise(ArgumentError, message: "invalid key (length #{length(key)} vs. order #{order})") - # case Util.binary_search(chain, key) do - # {_index, value} -> value[word] - # nil -> nil - # end - #end - def add_weight(%Chain {chain: chain, order: order}, key, word, increment \\ 1) do if length(key) != order, do: raise(ArgumentError, message: "invalid key (length #{length(key)} vs. order #{order})") chain = case Enum.find_index(chain, fn {listkey, _} -> listkey == key end) do diff --git a/lib/contrib/markov/markov.ex b/lib/contrib/markov/markov.ex index f75ebfc..eb5f4dd 100644 --- a/lib/contrib/markov/markov.ex +++ b/lib/contrib/markov/markov.ex @@ -3,20 +3,44 @@ defmodule Omnibot.Contrib.Markov do alias Omnibot.Contrib.Markov.Chain - @default_config path: :"wordbot.ets", order: 2 + @default_config path: "markov", order: 2 @impl true def on_init(cfg) do # Create the markov database - path = if is_atom(cfg[:path]), - do: cfg[:path], - else: String.to_atom(cfg[:path]) - {:ok, db} = :dets.open_file(path) - db + path = String.to_atom(cfg[:path]) + :ets.new(path, [:public]) end @impl true - def on_channel_msg(_irc, _channel, _nick, msg) do - _words = String.split(msg, ~r/\s+/) + def on_channel_msg(_irc, channel, nick, msg) do + train(channel, nick, msg) + end + + def train(channel, user, msg) do + chain = (user_chain(channel, user) || create_user_chain(channel, user)) + |> Chain.train(msg) + true = update_user_chain(channel, user, chain) + end + + def user_chain(channel, user) do + db = state() + case :ets.lookup(db, {channel, user}) do + [] -> nil + [{{^channel, ^user}, chains}] -> chains + end + end + + def update_user_chain(channel, user, chain) do + db = state() + case user_chain(channel, user) do + nil -> :ets.insert_new(db, {{channel, user}, chain}) + chain -> :ets.insert(db, {{channel, user}, chain}) + end + end + + defp create_user_chain(channel, user) do + true = update_user_chain(channel, user, %Chain{order: cfg()[:order]}) + user_chain(channel, user) end end diff --git a/lib/supervisor.ex b/lib/supervisor.ex index f94255f..5df4311 100644 --- a/lib/supervisor.ex +++ b/lib/supervisor.ex @@ -10,7 +10,9 @@ defmodule Omnibot.Supervisor do @impl true def init(:ok) do - {_, bindings} = Code.eval_file("omnibot.exs") + + {_, bindings} = System.get_env("OMNIBOT_CFG", "omnibot.exs") + |> Code.eval_file() cfg = bindings[:config] children = [ diff --git a/test/contrib/markov/chain_test.exs b/test/contrib/markov/chain_test.exs index 630ade5..89287b1 100644 --- a/test/contrib/markov/chain_test.exs +++ b/test/contrib/markov/chain_test.exs @@ -25,7 +25,6 @@ defmodule MarkovChainTest do ] chain = chain |> Chain.add_weight(["foo", "bar"], "qux") - assert chain.chain == [ {["foo", "bar"], %{"baz" => 3, "qux" => 1}} ]