From 4c93b42fdcc95cde35f122bd937765f8d8905b7e Mon Sep 17 00:00:00 2001
From: Alek Ratzloff
Date: Wed, 15 Jul 2020 16:25:25 -0700
Subject: [PATCH] Finish markov chain generation impl

* Markov chains will train and generate chains correctly now
* Implement Markov.save_chains/0
* Add a couple more utils that help accomplish the above

Signed-off-by: Alek Ratzloff
---
Review notes (commentary only; text placed between the "---" marker and the
diffstat is not recorded in the commit when the patch is applied with git am):

* Formatting: this copy was re-flowed from a whitespace-collapsed paste. The
  line breaks were checked against every hunk header's line counts and the
  diffstat totals, but leading indentation is reconstructed (two-space Elixir
  style) and is a best guess. If context lines fail to match, apply with
  `git apply --ignore-whitespace`.
* Markov.on_init/1 opens the DETS table as String.to_atom(cfg()[:path]), while
  Markov.save_chains/0 passes cfg()[:path] unconverted. TODO(review): confirm
  both calls resolve to the same table name and file.
* Util.weighted_random/1 passes the summed weights to :rand.uniform/1, which
  requires a positive integer; a non-empty weight list whose total is 0 would
  raise instead of returning nil.
* Chain.generate/1 selects its seed with Enum.random/1, which raises
  Enum.EmptyError when the chain has no entries (an untrained chain).
* The "!markov" force / all / status commands are stubs that reply "TODO".

 lib/contrib/markov/chain.ex        | 49 ++++++++++++++++++++++++----
 lib/contrib/markov/markov.ex       | 51 ++++++++++++++++++++++--------
 lib/util.ex                        | 23 ++++++++++++++
 test/contrib/markov/chain_test.exs | 38 ++++++++++++++++++++++
 test/util_test.exs                 | 10 ++++++
 5 files changed, 150 insertions(+), 21 deletions(-)

diff --git a/lib/contrib/markov/chain.ex b/lib/contrib/markov/chain.ex
index 95776b3..cd5175a 100644
--- a/lib/contrib/markov/chain.ex
+++ b/lib/contrib/markov/chain.ex
@@ -1,5 +1,5 @@
 defmodule Omnibot.Contrib.Markov.Chain do
-  alias Omnibot.Contrib.Markov.Chain
+  alias Omnibot.{Contrib.Markov.Chain, Util}
 
   @enforce_keys [:order]
   defstruct order: 2, chain: []
@@ -14,17 +14,18 @@ defmodule Omnibot.Contrib.Markov.Chain do
     Enum.filter(words, &(String.length(&1) > 0))
     |> Enum.chunk_every(order + 1, 1) # this gives us a "sliding window" effect
     |> Enum.reduce(chain, &case Enum.split(&1, order) do
-      {words, []} -> if length(words) == order,
-        # Null case for the chain; this is an "end" state
-        do: add_weight(&2, words, nil),
-        else: &2 # TODO ? train [a, nil] -> b ?
+      {words, []} when length(words) == order -> add_weight(&2, words, nil)
+      {words, []} -> add_weight(&2, Util.pad_trailing(words, nil, order), nil)
       {words, [next]} -> add_weight(&2, words, next)
     end)
   end
 
   def add_weight(%Chain {chain: chain, order: order}, key, word, increment \\ 1) do
-    if length(key) != order, do: raise(ArgumentError, message: "invalid key (length #{length(key)} vs. order #{order})")
-    chain = case Enum.find_index(chain, fn {listkey, _} -> listkey == key end) do
+    if length(key) != order do
+      raise(ArgumentError, message: "invalid key (length #{length(key)} vs. order #{order})")
+    end
+
+    chain = case find_index(chain, key) do
       # Insert weight
       nil -> [{key, %{word => increment}} | chain]
       # Update weight
@@ -36,4 +37,38 @@ defmodule Omnibot.Contrib.Markov.Chain do
     end
     %Chain{chain: chain, order: order}
   end
+
+  def get(chain, key) when is_list(chain) do
+    item = Enum.find(chain, fn {listkey, _} -> listkey == key end)
+    case item do
+      nil -> nil
+      {_, weights} -> weights
+    end
+  end
+
+  def get(chain, key), do: get(chain.chain, key)
+
+  def find_index(chain, key) when is_list(chain) do
+    Enum.find_index(chain, fn {listkey, _} -> listkey == key end)
+  end
+
+  def find_index(chain = %Chain{}, key), do: find_index(chain.chain, key)
+
+  def generate(chain) do
+    {seed, _} = Stream.filter(chain.chain, fn {key, _} -> length(key) == chain.order end)
+                |> Enum.random()
+    generate(chain, seed)
+  end
+
+  def generate(chain, key) do
+    do_generate(chain, key) |> Enum.join(" ")
+  end
+
+  defp do_generate(_chain, [nil | _]), do: []
+
+  defp do_generate(chain, key) do
+    weights = get(chain, key) || []
+    [next | key] = key ++ [Util.weighted_random(weights)]
+    [next | do_generate(chain, key)]
+  end
 end
diff --git a/lib/contrib/markov/markov.ex b/lib/contrib/markov/markov.ex
index ee769d2..08b99da 100644
--- a/lib/contrib/markov/markov.ex
+++ b/lib/contrib/markov/markov.ex
@@ -1,28 +1,44 @@
 defmodule Omnibot.Contrib.Markov do
   use Omnibot.Plugin
-  alias Omnibot.Contrib.Markov.Chain
+  alias Omnibot.{Contrib.Markov.Chain, Util}
   require Logger
 
-  @default_config path: "markov", order: 2, save_every: 5 * 60
+  @default_config path: "markov.ets", order: 2, save_every: 5 * 60
+
+  command "!markov", ["force"] do
+    # Choose a random value from the sender
+    Irc.send_to(irc, channel, "TODO")
+  end
+
+  command "!markov", ["all"] do
+    Irc.send_to(irc, channel, "TODO")
+  end
+
+  command "!markov", ["status"] do
+    Irc.send_to(irc, channel, "TODO")
+  end
 
   @impl true
   def children(cfg) do
-    [{Task, fn ->
-      Stream.timer(cfg[:save_every] * 1000)
-      |> Stream.cycle()
-      |> Stream.each(fn _ -> save_chains() end)
-      |> Stream.run()
-    end}]
+    [
+      {Task, fn ->
+        Stream.timer(cfg[:save_every] * 1000)
+        |> Stream.cycle()
+        |> Stream.each(fn _ -> save_chains() end)
+        |> Stream.run()
+      end}
+    ]
   end
 
   @impl true
   def on_init(_cfg) do
     # Create the markov database
     path = String.to_atom(cfg()[:path])
-    {:ok, db} = :dets.open_file(path, [:named_table])
-    chains = :ets.new(:markov_chains, [:public])
+    {:ok, db} = :dets.open_file(path, [])
+    chains = :ets.new(:markov_chains, [:named_table, :public])
     :dets.to_ets(db, chains)
-    :dets.close(db)
+    :ok = :dets.close(db)
+    chains
   end
 
   @impl true
@@ -48,7 +64,7 @@ defmodule Omnibot.Contrib.Markov do
     db = state()
     case user_chain(channel, user) do
       nil -> :ets.insert_new(db, {{channel, user}, chain})
-      chain -> :ets.insert(db, {{channel, user}, chain})
+      _old_chain -> :ets.insert(db, {{channel, user}, chain})
     end
   end
 
@@ -58,7 +74,14 @@ defmodule Omnibot.Contrib.Markov do
   end
 
   def save_chains() do
-    # TODO
-    Logger.info("Saved markov chains")
+    start = Util.now_unix()
+    Logger.debug("Saving markov chains")
+
+    {:ok, db} = :dets.open_file(cfg()[:path], [])
+    :ets.to_dets(state(), db)
+    :ok = :dets.close(db)
+
+    stop = Util.now_unix()
+    Logger.info("Saved markov chains in #{stop - start} seconds")
   end
 end
diff --git a/lib/util.ex b/lib/util.ex
index af710e6..986274c 100644
--- a/lib/util.ex
+++ b/lib/util.ex
@@ -14,4 +14,27 @@ defmodule Omnibot.Util do
   def denotify_nick(nick) do
     String.graphemes(nick) |> Enum.join("\u200b")
   end
+
+  def weighted_random(items) when is_map(items) do
+    Enum.to_list(items) |> weighted_random()
+  end
+
+  def weighted_random([]), do: nil
+
+  def weighted_random(items) do
+    value = items
+            |> Enum.reduce(0, fn {_, weight}, total -> total + weight end)
+            |> :rand.uniform()
+    select_item(items, value)
+  end
+
+  defp select_item([{item, _}], _), do: item
+
+  defp select_item([{item, weight} | _], index) when weight >= index, do: item
+
+  defp select_item([{_, weight} | tail], index), do: select_item(tail, index - weight)
+
+  def pad_trailing(list, _what, len) when length(list) >= len, do: list
+
+  def pad_trailing(list, what, len), do: pad_trailing(list ++ [what], what, len)
 end
diff --git a/test/contrib/markov/chain_test.exs b/test/contrib/markov/chain_test.exs
index 89287b1..67253a8 100644
--- a/test/contrib/markov/chain_test.exs
+++ b/test/contrib/markov/chain_test.exs
@@ -9,6 +9,32 @@ defmodule MarkovChainTest do
       {["bar", "baz"], %{nil => 1}},
       {["foo", "bar"], %{"baz" => 1}},
     ]
+
+    chain = chain |> Chain.train(~w(foo bar baz))
+
+    assert chain.chain == [
+      {["bar", "baz"], %{nil => 2}},
+      {["foo", "bar"], %{"baz" => 2}},
+    ]
+
+    chain = chain |> Chain.train(~w(baz bar foo))
+
+    assert chain.chain == [
+      {["bar", "foo"], %{nil => 1}},
+      {["baz", "bar"], %{"foo" => 1}},
+      {["bar", "baz"], %{nil => 2}},
+      {["foo", "bar"], %{"baz" => 2}},
+    ]
+
+    chain = chain |> Chain.train(~w(a b c))
+    assert chain.chain == [
+      {["b", "c"], %{nil => 1}},
+      {["a", "b"], %{"c" => 1}},
+      {["bar", "foo"], %{nil => 1}},
+      {["baz", "bar"], %{"foo" => 1}},
+      {["bar", "baz"], %{nil => 2}},
+      {["foo", "bar"], %{"baz" => 2}},
+    ]
   end
 
   test "chain add_weight works correctly" do
@@ -28,5 +54,17 @@ defmodule MarkovChainTest do
     assert chain.chain == [
       {["foo", "bar"], %{"baz" => 3, "qux" => 1}}
     ]
+
+    chain = chain |> Chain.add_weight(["bar", "baz"], "qux")
+    assert chain.chain == [
+      {["bar", "baz"], %{"qux" => 1}},
+      {["foo", "bar"], %{"baz" => 3, "qux" => 1}},
+    ]
+
+    chain = chain |> Chain.add_weight(["bar", "baz"], nil)
+    assert chain.chain == [
+      {["bar", "baz"], %{"qux" => 1, nil => 1}},
+      {["foo", "bar"], %{"baz" => 3, "qux" => 1}},
+    ]
   end
 end
diff --git a/test/util_test.exs b/test/util_test.exs
index d7921e5..07ac11e 100644
--- a/test/util_test.exs
+++ b/test/util_test.exs
@@ -12,4 +12,14 @@ defmodule Omnibot.UtilTest do
     assert Util.string_or_nil("") == nil
     assert Util.string_or_nil("asdf") == "asdf"
   end
+
+  test "pad_trailing" do
+    assert Util.pad_trailing([1, 2, 3, 4], nil, 7) == [1, 2, 3, 4, nil, nil, nil]
+    assert Util.pad_trailing([1, 2, 3, 4], nil, 6) == [1, 2, 3, 4, nil, nil]
+    assert Util.pad_trailing([1, 2, 3, 4], nil, 5) == [1, 2, 3, 4, nil]
+    assert Util.pad_trailing([1, 2, 3, 4], nil, 4) == [1, 2, 3, 4]
+    assert Util.pad_trailing([1, 2, 3, 4], nil, 3) == [1, 2, 3, 4]
+    assert Util.pad_trailing([1, 2, 3, 4], nil, 2) == [1, 2, 3, 4]
+    assert Util.pad_trailing([1, 2, 3, 4], nil, 1) == [1, 2, 3, 4]
+  end
 end