Add dirty flag to markov chains

If something is changed in a markov chain it gets flagged as dirty,
which is used to determine whether the chain should be saved.

Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
This commit is contained in:
2022-05-27 18:58:06 -07:00
parent cc30df8706
commit 887c8dc278

View File

@@ -40,6 +40,7 @@ class Chain:
self.path = path self.path = path
self.__cache = chain_default() self.__cache = chain_default()
self.__last_access = 0.0 self.__last_access = 0.0
self.__dirty = False
@property @property
def last_access(self) -> float: def last_access(self) -> float:
@@ -54,6 +55,7 @@ class Chain:
head = fragment[0:-1] head = fragment[0:-1]
tail = fragment[-1] tail = fragment[-1]
self.__cache[" ".join(head)][tail] += 1 self.__cache[" ".join(head)][tail] += 1
self.__dirty = True
def get(self, key: str) -> dict[str | None, int]: def get(self, key: str) -> dict[str | None, int]:
self.__last_access = asyncio.get_running_loop().time() self.__last_access = asyncio.get_running_loop().time()
@@ -73,6 +75,7 @@ class Chain:
# Attempt the cache before writing to it # Attempt the cache before writing to it
self.__load() self.__load()
self.__cache[key] = defaultdict(int, value) self.__cache[key] = defaultdict(int, value)
self.__dirty = True
def __load(self): def __load(self):
self.__last_access = asyncio.get_running_loop().time() self.__last_access = asyncio.get_running_loop().time()
@@ -95,14 +98,13 @@ class Chain:
for key, value in json.load(fp).items() for key, value in json.load(fp).items()
}, },
) )
self.__dirty = False
def save(self, retain: bool = True): def save(self, retain: bool = True):
if not self.__cache: if not self.__cache:
return return
if retain: if self.__dirty:
log.info("Saving markov chain to %s", self.path) log.info("Saving markov chain to %s", self.path)
else:
log.info("Saving markov chain to %s (not retaining)", self.path)
self.path.parent.mkdir(parents=True, exist_ok=True) self.path.parent.mkdir(parents=True, exist_ok=True)
with open(self.path, "w") as fp: with open(self.path, "w") as fp:
json.dump( json.dump(
@@ -115,11 +117,14 @@ class Chain:
}, },
fp, fp,
) )
self.__dirty = False
if not retain: if not retain:
log.debug("Pruning markov chain %s from memory", self.path)
self.clear_cache() self.clear_cache()
def clear_cache(self): def clear_cache(self):
self.__cache.clear() self.__cache.clear()
self.__dirty = False
def __bool__(self) -> bool: def __bool__(self) -> bool:
return self.path.exists() or bool(self.__cache) return self.path.exists() or bool(self.__cache)