diff --git a/lib/contrib/markov/chain.ex b/lib/contrib/markov/chain.ex index 4c95a1d..2e95a8a 100644 --- a/lib/contrib/markov/chain.ex +++ b/lib/contrib/markov/chain.ex @@ -1,35 +1,45 @@ defmodule Omnibot.Contrib.Markov.Chain do - alias Omnibot.{Contrib.Markov.Chain, Util} + alias Omnibot.Contrib.Markov.Chain @enforce_keys [:order] defstruct order: 2, chain: [] - def train(%Chain {chain: chain, order: order}, words) when is_list(words) do + def train(chain, words) when is_list(words) do + order = chain.order Enum.filter(words, &(String.length(&1) > 0)) |> Enum.chunk_every(order + 1, 1) # this gives us a "sliding window" effect - |> Enum.reduce(chain, &case Enum.split(words, order) do - {words, []} -> if length(&1) == order, + |> Enum.reduce(chain, &case Enum.split(&1, order) do + {words, []} -> if length(words) == order, # Null case for the chain; this is an "end" state - do: train_one(%Chain {chain: &2, order: order}, words, nil) + do: add_weight(&2, words, nil) # else: TODO ? train [a, nil] -> b ? {words, [next]} -> - train_one(%Chain {chain: &2, order: order}, words, next) + add_weight(&2, words, next) end ) end - def train_one(%Chain {chain: _chain, order: _order}, _key, _value) do - end + #def lookup(%Chain {chain: chain, order: order}, key) do + # if length(key) != order, do: raise(ArgumentError, message: "invalid key (length #{length(key)} vs. order #{order})") + # case Util.binary_search(chain, key) do + # {_index, value} -> value[word] + # nil -> nil + # end + #end - def lookup(%Chain {chain: chain, order: order}, key) do + def add_weight(%Chain {chain: chain, order: order}, key, word, increment \\ 1) do if length(key) != order, do: raise(ArgumentError, message: "invalid key (length #{length(key)} vs. order #{order})") - case Util.binary_search(chain, key) do - {_index, value} -> value - nil -> nil + chain = case Enum.find_index(chain, fn {listkey, _} -> listkey == key end) do + # Insert weight + nil -> [{key, %{word => increment}} | chain] + # Update weight + index -> List.update_at( + chain, + index, + fn {key, mapping} -> {key, Map.update(mapping, word, increment, &(&1 + increment))} end + ) end - end - - def put(%Chain {chain: _chain, order: _order}, _key, _value) do + %Chain{chain: chain, order: order} end end diff --git a/test/contrib/markov/chain_test.exs b/test/contrib/markov/chain_test.exs index e0de4a9..630ade5 100644 --- a/test/contrib/markov/chain_test.exs +++ b/test/contrib/markov/chain_test.exs @@ -2,11 +2,32 @@ defmodule MarkovChainTest do use ExUnit.Case alias Omnibot.Contrib.Markov.Chain - test "chain train_one works correctly" do + test "chain train works correctly" do chain = %Chain {order: 2} - |> Chain.train_one(["foo", "bar"], "baz") - #assert chain.chain == [ - #{["foo", "bar"], {"baz", 1}} - #] + |> Chain.train(~w(foo bar baz)) + assert chain.chain == [ + {["bar", "baz"], %{nil => 1}}, + {["foo", "bar"], %{"baz" => 1}}, + ] + end + + test "chain add_weight works correctly" do + chain = %Chain {order: 2} + |> Chain.add_weight(["foo", "bar"], "baz") + assert chain.chain == [ + {["foo", "bar"], %{"baz" => 1}} + ] + + chain = chain |> Chain.add_weight(["foo", "bar"], "baz", 2) + + assert chain.chain == [ + {["foo", "bar"], %{"baz" => 3}} + ] + + chain = chain |> Chain.add_weight(["foo", "bar"], "qux") + + assert chain.chain == [ + {["foo", "bar"], %{"baz" => 3, "qux" => 1}} + ] end end