diff --git a/lib/contrib/markov/chain.ex b/lib/contrib/markov/chain.ex index cd5175a..c54cf2a 100644 --- a/lib/contrib/markov/chain.ex +++ b/lib/contrib/markov/chain.ex @@ -2,7 +2,7 @@ defmodule Omnibot.Contrib.Markov.Chain do alias Omnibot.{Contrib.Markov.Chain, Util} @enforce_keys [:order] - defstruct order: 2, chain: [] + defstruct order: 2, chain: %{} def train(chain, line) when is_binary(line) do train(chain, line |> String.split(~r/\s+/)) @@ -20,21 +20,16 @@ defmodule Omnibot.Contrib.Markov.Chain do end) end - def add_weight(%Chain {chain: chain, order: order}, key, word, increment \\ 1) do + def add_weight(%Chain{chain: chain, order: order}, key, word, increment \\ 1) do if length(key) != order do raise(ArgumentError, message: "invalid key (length #{length(key)} vs. order #{order})") end - chain = case find_index(chain, key) do - # Insert weight - nil -> [{key, %{word => increment}} | chain] - # Update weight - index -> List.update_at( - chain, - index, - fn {key, mapping} -> {key, Map.update(mapping, word, increment, &(&1 + increment))} end - ) - end + # %{ + # ["word1", "word2"] => %{"target" => weight} + # } + chain = Map.update(chain, key, %{word => increment}, + fn weights -> Map.update(weights, word, increment, &(increment + &1)) end) %Chain{chain: chain, order: order} end diff --git a/test/contrib/markov/chain_test.exs b/test/contrib/markov/chain_test.exs index 67253a8..446c497 100644 --- a/test/contrib/markov/chain_test.exs +++ b/test/contrib/markov/chain_test.exs @@ -5,66 +5,66 @@ defmodule MarkovChainTest do test "chain train works correctly" do chain = %Chain {order: 2} |> Chain.train(~w(foo bar baz)) - assert chain.chain == [ - {["bar", "baz"], %{nil => 1}}, - {["foo", "bar"], %{"baz" => 1}}, - ] + assert chain.chain == %{ + ["bar", "baz"] => %{nil => 1}, + ["foo", "bar"] => %{"baz" => 1}, + } chain = chain |> Chain.train(~w(foo bar baz)) - assert chain.chain == [ - {["bar", "baz"], %{nil => 2}}, - {["foo", "bar"], %{"baz" => 2}}, - ] + assert chain.chain == %{ + ["bar", "baz"] => %{nil => 2}, + ["foo", "bar"] => %{"baz" => 2}, + } chain = chain |> Chain.train(~w(baz bar foo)) - assert chain.chain == [ - {["bar", "foo"], %{nil => 1}}, - {["baz", "bar"], %{"foo" => 1}}, - {["bar", "baz"], %{nil => 2}}, - {["foo", "bar"], %{"baz" => 2}}, - ] + assert chain.chain == %{ + ["bar", "foo"] => %{nil => 1}, + ["baz", "bar"] => %{"foo" => 1}, + ["bar", "baz"] => %{nil => 2}, + ["foo", "bar"] => %{"baz" => 2}, + } chain = chain |> Chain.train(~w(a b c)) - assert chain.chain == [ - {["b", "c"], %{nil => 1}}, - {["a", "b"], %{"c" => 1}}, - {["bar", "foo"], %{nil => 1}}, - {["baz", "bar"], %{"foo" => 1}}, - {["bar", "baz"], %{nil => 2}}, - {["foo", "bar"], %{"baz" => 2}}, - ] + assert chain.chain == %{ + ["b", "c"] => %{nil => 1}, + ["a", "b"] => %{"c" => 1}, + ["bar", "foo"] => %{nil => 1}, + ["baz", "bar"] => %{"foo" => 1}, + ["bar", "baz"] => %{nil => 2}, + ["foo", "bar"] => %{"baz" => 2}, + } end test "chain add_weight works correctly" do chain = %Chain {order: 2} |> Chain.add_weight(["foo", "bar"], "baz") - assert chain.chain == [ - {["foo", "bar"], %{"baz" => 1}} - ] + assert chain.chain == %{ + ["foo", "bar"] => %{"baz" => 1}, + } chain = chain |> Chain.add_weight(["foo", "bar"], "baz", 2) - assert chain.chain == [ - {["foo", "bar"], %{"baz" => 3}} - ] + assert chain.chain == %{ + ["foo", "bar"] => %{"baz" => 3}, + } chain = chain |> Chain.add_weight(["foo", "bar"], "qux") - assert chain.chain == [ - {["foo", "bar"], %{"baz" => 3, "qux" => 1}} - ] + assert chain.chain == %{ + ["foo", "bar"] => %{"baz" => 3, "qux" => 1}, + } chain = chain |> Chain.add_weight(["bar", "baz"], "qux") - assert chain.chain == [ - {["bar", "baz"], %{"qux" => 1}}, - {["foo", "bar"], %{"baz" => 3, "qux" => 1}}, - ] + assert chain.chain == %{ + ["bar", "baz"] => %{"qux" => 1}, + ["foo", "bar"] => %{"baz" => 3, "qux" => 1}, + } chain = chain |> Chain.add_weight(["bar", "baz"], nil) - assert chain.chain == [ - {["bar", "baz"], %{"qux" => 1, nil => 1}}, - {["foo", "bar"], %{"baz" => 3, "qux" => 1}}, - ] + assert chain.chain == %{ + ["bar", "baz"] => %{"qux" => 1, nil => 1}, + ["foo", "bar"] => %{"baz" => 3, "qux" => 1}, + } end end