diff --git a/lib/contrib/markov/chain.ex b/lib/contrib/markov/chain.ex index e60dd6d..f7e6eda 100644 --- a/lib/contrib/markov/chain.ex +++ b/lib/contrib/markov/chain.ex @@ -21,17 +21,17 @@ defmodule Omnibot.Contrib.Markov.Chain do end) end - def add_weight(%Chain{chain: chain, order: order}, key, word, increment \\ 1) do - if length(key) != order do - raise(ArgumentError, message: "invalid key (length #{length(key)} vs. order #{order})") + def add_weight(chain = %Chain{}, key, word, increment \\ 1) do + if length(key) != chain.order do + raise(ArgumentError, message: "invalid key (length #{length(key)} vs. order #{chain.order})") end # %{ # ["word1", "word2"] => %{"target" => weight} # } - chain = Map.update(chain, key, %{word => increment}, + chain_map = Map.update(chain.chain, key, %{word => increment}, fn weights -> Map.update(weights, word, increment, &(increment + &1)) end) - %Chain{chain: chain, order: order} + %Chain{chain | chain: chain_map} end def generate(chain) do diff --git a/test/contrib/markov/chain_test.exs b/test/contrib/markov/chain_test.exs index 78bdf01..729d183 100644 --- a/test/contrib/markov/chain_test.exs +++ b/test/contrib/markov/chain_test.exs @@ -68,6 +68,23 @@ defmodule MarkovChainTest do } end + test "chain add_weight does not reset reply_chance" do + chain = %Chain {order: 2, reply_chance: 0.0} + |> Chain.add_weight(["foo", "bar"], "baz") + + chain = chain |> Chain.add_weight(["foo", "bar"], "baz", 2) + assert chain.reply_chance == 0.0 + + chain = chain |> Chain.add_weight(["foo", "bar"], "qux") + assert chain.reply_chance == 0.0 + + chain = chain |> Chain.add_weight(["bar", "baz"], "qux") + assert chain.reply_chance == 0.0 + + chain = chain |> Chain.add_weight(["bar", "baz"], nil) + assert chain.reply_chance == 0.0 + end + test "chain merge works correctly" do chain1 = %Chain {order: 2} |> Chain.add_weight(["foo", "bar"], "baz")