from typing import Dict
class BayesCategory:
    """
    Represents a trainable category of content for bayesian classification.

    Keeps a per-token occurrence count (``tokens``) plus a running total
    (``tally``) of the counts added by :meth:`train_token` minus those
    removed by :meth:`untrain_token`.
    """

    def __init__(self, name: str):
        """
        :param name: The name of the category we're creating
        :type name: str
        """
        self.name: str = name
        # token -> accumulated count for this category; entries whose count
        # drops to zero or below during untraining are removed
        self.tokens: Dict[str, int] = {}
        # running total of counts trained minus counts untrained
        self.tally: int = 0

    def train_token(self, word: str, count: int) -> None:
        """
        Trains a particular token (increases the weight/count of it).

        :param word: the token we're going to train
        :type word: str
        :param count: the number of occurrences in the sample; it is added
            to the token's count and to the tally as-is
        :type count: int
        """
        self.tokens[word] = self.tokens.get(word, 0) + count
        self.tally += count

    def untrain_token(self, word: str, count: int) -> None:
        """
        Untrains a particular token (decreases the weight/count of it).

        Unknown tokens are ignored. The amount removed is clamped so that a
        token never goes below zero, and a negative ``count`` removes nothing
        (untraining must never add weight). A token whose count reaches zero
        or below is dropped from ``tokens``.

        :param word: the token we're going to untrain
        :type word: str
        :param count: the number of occurrences in the sample
        :type count: int
        """
        current = self.tokens.get(word)
        if current is None:
            return
        # Ignore negative requests, and never remove more than is stored
        # (if we try to untrain more tokens than we have, we end at 0).
        count = min(max(count, 0), current)
        remaining = current - count
        self.tally -= count
        if remaining <= 0:
            del self.tokens[word]
        else:
            self.tokens[word] = remaining

    def get_token_count(self, word: str) -> int:
        """
        Gets the count associated with a provided token/word.

        :param word: the token we're getting the weight of
        :type word: str
        :return: the weight/count of the token, or 0 if it is unknown
        :rtype: int
        """
        return self.tokens.get(word, 0)

    def get_tally(self) -> int:
        """
        Gets the total count of all tokens trained into this category.

        :return: the running total of trained minus untrained token counts
        :rtype: int
        """
        return self.tally