diff --git a/lib/contrib/markov/chain.ex b/lib/contrib/markov/chain.ex index c54cf2a..895f241 100644 --- a/lib/contrib/markov/chain.ex +++ b/lib/contrib/markov/chain.ex @@ -1,5 +1,6 @@ defmodule Omnibot.Contrib.Markov.Chain do alias Omnibot.{Contrib.Markov.Chain, Util} + require Logger @enforce_keys [:order] defstruct order: 2, chain: %{} @@ -33,22 +34,6 @@ defmodule Omnibot.Contrib.Markov.Chain do %Chain{chain: chain, order: order} end - def get(chain, key) when is_list(chain) do - item = Enum.find(chain, fn {listkey, _} -> listkey == key end) - case item do - nil -> nil - {_, weights} -> weights - end - end - - def get(chain, key), do: get(chain.chain, key) - - def find_index(chain, key) when is_list(chain) do - Enum.find_index(chain, fn {listkey, _} -> listkey == key end) - end - - def find_index(chain = %Chain{}, key), do: find_index(chain.chain, key) - def generate(chain) do {seed, _} = Stream.filter(chain.chain, fn {key, _} -> length(key) == chain.order end) |> Enum.random() @@ -59,10 +44,44 @@ defmodule Omnibot.Contrib.Markov.Chain do do_generate(chain, key) |> Enum.join(" ") end + def load!(path) do + {:ok, chain} = load(path) + chain + end + + def load(path) do + Logger.debug("Loading markov chain #{path}") + with {:ok, contents} <- File.read(path), + do: {:ok, :erlang.binary_to_term(contents)} + end + + def save!(chain, path) do + :ok = save(chain, path) + end + + def save(chain, path) do + File.write!(path, :erlang.term_to_binary(chain)) + end + + def merge(lhs, rhs) do + if lhs.order != rhs.order do + raise(ArgumentError, message: "markov chain orders must match (#{lhs.order} vs #{rhs.order})") + end + + merged = Map.merge(lhs.chain, rhs.chain, + fn _k, lhs, rhs -> Map.merge(lhs, rhs, fn _k, w1, w2-> w1 + w2 end) end + ) + %Chain{order: lhs.order, chain: merged} + end + + def merge([chain]), do: chain + + def merge([chain | tail]), do: merge(tail) |> merge(chain) + defp do_generate(_chain, [nil | _]), do: [] defp do_generate(chain, key) do - weights = get(chain, key) || [] + weights = chain.chain[key] || %{} [next | key] = key ++ [Util.weighted_random(weights)] [next | do_generate(chain, key)] end diff --git a/test/contrib/markov/chain_test.exs b/test/contrib/markov/chain_test.exs index 446c497..78bdf01 100644 --- a/test/contrib/markov/chain_test.exs +++ b/test/contrib/markov/chain_test.exs @@ -67,4 +67,19 @@ defmodule MarkovChainTest do ["foo", "bar"] => %{"baz" => 3, "qux" => 1}, } end + + test "chain merge works correctly" do + chain1 = %Chain {order: 2} + |> Chain.add_weight(["foo", "bar"], "baz") + + chain2 = %Chain {order: 2} + |> Chain.add_weight(["foo", "bar"], "baz") + |> Chain.add_weight(["bar", "baz"], "qux") + + merged = Chain.merge(chain1, chain2) + assert merged.chain == %{ + ["foo", "bar"] => %{"baz" => 2}, + ["bar", "baz"] => %{"qux" => 1}, + } + end end