Add dirty flag to markov chains
If something is changed in a markov chain it gets flagged as dirty, which is used to determine whether the chain should be saved. Signed-off-by: Alek Ratzloff <alekratz@gmail.com>
This commit is contained in:
@@ -40,6 +40,7 @@ class Chain:
|
|||||||
self.path = path
|
self.path = path
|
||||||
self.__cache = chain_default()
|
self.__cache = chain_default()
|
||||||
self.__last_access = 0.0
|
self.__last_access = 0.0
|
||||||
|
self.__dirty = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def last_access(self) -> float:
|
def last_access(self) -> float:
|
||||||
@@ -54,6 +55,7 @@ class Chain:
|
|||||||
head = fragment[0:-1]
|
head = fragment[0:-1]
|
||||||
tail = fragment[-1]
|
tail = fragment[-1]
|
||||||
self.__cache[" ".join(head)][tail] += 1
|
self.__cache[" ".join(head)][tail] += 1
|
||||||
|
self.__dirty = True
|
||||||
|
|
||||||
def get(self, key: str) -> dict[str | None, int]:
|
def get(self, key: str) -> dict[str | None, int]:
|
||||||
self.__last_access = asyncio.get_running_loop().time()
|
self.__last_access = asyncio.get_running_loop().time()
|
||||||
@@ -73,6 +75,7 @@ class Chain:
|
|||||||
# Attempt the cache before writing to it
|
# Attempt the cache before writing to it
|
||||||
self.__load()
|
self.__load()
|
||||||
self.__cache[key] = defaultdict(int, value)
|
self.__cache[key] = defaultdict(int, value)
|
||||||
|
self.__dirty = True
|
||||||
|
|
||||||
def __load(self):
|
def __load(self):
|
||||||
self.__last_access = asyncio.get_running_loop().time()
|
self.__last_access = asyncio.get_running_loop().time()
|
||||||
@@ -95,31 +98,33 @@ class Chain:
|
|||||||
for key, value in json.load(fp).items()
|
for key, value in json.load(fp).items()
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
self.__dirty = False
|
||||||
|
|
||||||
def save(self, retain: bool = True):
|
def save(self, retain: bool = True):
|
||||||
if not self.__cache:
|
if not self.__cache:
|
||||||
return
|
return
|
||||||
if retain:
|
if self.__dirty:
|
||||||
log.info("Saving markov chain to %s", self.path)
|
log.info("Saving markov chain to %s", self.path)
|
||||||
else:
|
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
log.info("Saving markov chain to %s (not retaining)", self.path)
|
with open(self.path, "w") as fp:
|
||||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
json.dump(
|
||||||
with open(self.path, "w") as fp:
|
{
|
||||||
json.dump(
|
key: {
|
||||||
{
|
("" if word is None else word): weight
|
||||||
key: {
|
for word, weight in value.items()
|
||||||
("" if word is None else word): weight
|
}
|
||||||
for word, weight in value.items()
|
for key, value in self.__cache.items()
|
||||||
}
|
},
|
||||||
for key, value in self.__cache.items()
|
fp,
|
||||||
},
|
)
|
||||||
fp,
|
self.__dirty = False
|
||||||
)
|
|
||||||
if not retain:
|
if not retain:
|
||||||
|
log.debug("Pruning markov chain %s from memory", self.path)
|
||||||
self.clear_cache()
|
self.clear_cache()
|
||||||
|
|
||||||
def clear_cache(self):
|
def clear_cache(self):
|
||||||
self.__cache.clear()
|
self.__cache.clear()
|
||||||
|
self.__dirty = False
|
||||||
|
|
||||||
def __bool__(self) -> bool:
|
def __bool__(self) -> bool:
|
||||||
return self.path.exists() or bool(self.__cache)
|
return self.path.exists() or bool(self.__cache)
|
||||||
|
|||||||
Reference in New Issue
Block a user