markov: Remove allchain
The allchain has been a source of headaches because it takes up a lot of memory and slows everything down. With the new database model, however, we can generate Markov sentences from all of the rows, since they are stored as a flat collection. This reduces disk usage and increases import speed significantly.

Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
This commit is contained in:
@@ -4,7 +4,7 @@ import math
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import random
|
import random
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from typing import Any, List, Sequence
|
from typing import Any, DefaultDict, List, Sequence
|
||||||
|
|
||||||
from asyncirc.protocol import IrcProtocol
|
from asyncirc.protocol import IrcProtocol
|
||||||
from irclib.parser import Prefix
|
from irclib.parser import Prefix
|
||||||
@@ -13,7 +13,6 @@ from omnibot.plugin import Plugin
|
|||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
ALLCHAIN = "ALL!CHAIN"
|
|
||||||
|
|
||||||
|
|
||||||
def chain_inner_default() -> defaultdict[str | None, int]:
|
def chain_inner_default() -> defaultdict[str | None, int]:
|
||||||
@@ -191,6 +190,48 @@ class Chain:
|
|||||||
return ""
|
return ""
|
||||||
return random.choices(words, weights)[0]
|
return random.choices(words, weights)[0]
|
||||||
|
|
||||||
|
def generate_all(self) -> str | None:
|
||||||
|
words: List[str] = []
|
||||||
|
cursor = self.db.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT value
|
||||||
|
FROM chain
|
||||||
|
ORDER BY RANDOM()
|
||||||
|
LIMIT 1
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
node = cursor.fetchone()[0].split(" ")
|
||||||
|
words += node
|
||||||
|
next: str = self.all_choose_next(" ".join(node))
|
||||||
|
while next:
|
||||||
|
words += [next]
|
||||||
|
node = [*node[1:], next]
|
||||||
|
next = self.all_choose_next(" ".join(node))
|
||||||
|
return " ".join(words)
|
||||||
|
|
||||||
|
def all_choose_next(self, head: str) -> str:
    """Pick the word that follows ``head``, weighted across all rows.

    Every row of the ``chain`` table whose ``value`` equals ``head`` is
    considered. Rows sharing the same ``next`` word have their weights
    summed, and one word is then drawn with probability proportional to
    its total weight.

    Args:
        head: The space-joined node to look up in the ``value`` column.

    Returns:
        The chosen following word, or ``""`` when no row matches ``head``.
        NOTE(review): if ``next`` can be NULL in the table, the chosen
        value may be ``None``; callers in view only test truthiness.
    """
    cursor = self.db.cursor()
    # Fetch *all* successors, collapsed by weight. Limiting the query to
    # a single random row would discard the weights entirely.
    cursor.execute(
        """
        SELECT next, SUM(weight)
        FROM chain
        WHERE value = ?
        GROUP BY next
        """,
        (head,),
    )
    rows = cursor.fetchall()
    if not rows:
        return ""
    words = [word for word, _ in rows]
    weights = [weight for _, weight in rows]
    return random.choices(words, weights)[0]
|
||||||
|
|
||||||
|
|
||||||
class Markov(Plugin):
|
class Markov(Plugin):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
@@ -226,7 +267,6 @@ class Markov(Plugin):
|
|||||||
if who == self.server_config.nick:
|
if who == self.server_config.nick:
|
||||||
return
|
return
|
||||||
self.chain.add(channel, who, line, commit)
|
self.chain.add(channel, who, line, commit)
|
||||||
self.chain.add(channel, ALLCHAIN, line, commit)
|
|
||||||
|
|
||||||
def handle_command(
|
def handle_command(
|
||||||
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
|
self, conn: IrcProtocol, channel: str, who: Prefix, parts: Sequence[str]
|
||||||
@@ -244,7 +284,7 @@ class Markov(Plugin):
|
|||||||
if message:
|
if message:
|
||||||
self.send_to(conn, channel, f"{who.nick}: {message}")
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
||||||
case ["all"]:
|
case ["all"]:
|
||||||
message = self.chain.generate(channel, ALLCHAIN)
|
message = self.chain.generate_all()
|
||||||
if message:
|
if message:
|
||||||
self.send_to(conn, channel, f"{who.nick}: {message}")
|
self.send_to(conn, channel, f"{who.nick}: {message}")
|
||||||
case ["chance"]:
|
case ["chance"]:
|
||||||
|
|||||||
Reference in New Issue
Block a user