# coding: utf-8
__version__ = '3.2.0'
import threading
from collections import Counter
from typing import Callable, Dict, List, Optional
from simplebayes.categories import BayesCategories
from simplebayes.constants import CATEGORY_PATTERN
from simplebayes.errors import InvalidCategoryError
from simplebayes.models import CategorySummary, ClassificationResult
from simplebayes.persistence import (
PERSISTED_MODEL_VERSION,
dump_model_state,
load_model_state,
load_model_state_from_file,
save_model_state_to_file,
validate_model_state,
)
from simplebayes.tokenization import create_tokenizer, default_tokenize_text
__all__ = ['SimpleBayes']
[docs]
class SimpleBayes:
    """An in-memory naïve bayesian text classifier with optional persistence."""

    def __init__(
        self,
        tokenizer: Optional[Callable[[str], List[str]]] = None,
        alpha: float = 0.0,
        language: str = "english",
        remove_stop_words: bool = False,
    ) -> None:
        """
        :param tokenizer: Optional tokenizer override; the built-in tokenizer
            is used when this is falsy.
        :param alpha: Laplace smoothing parameter. Values > 0 (e.g. 0.01 or
            1.0) prevent zero probabilities for tokens unseen in a category;
            the default of 0 keeps the historical behavior.
        :param language: Language code for the stemmer and stop-word list
            (e.g. "english", "spanish").
        :param remove_stop_words: When True, stop words are filtered out.
            Defaults to False for backwards compatibility.
        """
        self.categories = BayesCategories()
        # A falsy tokenizer argument falls through to the built-in tokenizer,
        # mirroring the original `tokenizer or create_tokenizer(...)` logic.
        if tokenizer:
            self.tokenizer = tokenizer
        else:
            self.tokenizer = create_tokenizer(
                language=language,
                remove_stop_words=remove_stop_words,
            )
        self.alpha = alpha
        # category name -> {'prc': P(category), 'prnc': 1 - P(category)}
        self.probabilities: Dict[str, Dict[str, float]] = {}
        # Reentrant lock guarding all mutations of categories/probabilities.
        self._lock = threading.RLock()
[docs]
@classmethod
def tokenize_text(cls, text: str) -> List[str]:
    """Split *text* into tokens with the package's default tokenizer.

    Subclasses may override this hook to customize tokenization.

    :param text: the raw text to tokenize
    :return: the resulting token list
    """
    return default_tokenize_text(text)
[docs]
@classmethod
def count_token_occurrences(cls, words: List[str]) -> Dict[str, int]:
"""
Creates a key/value set of word/count for a given sample of text
:param words: full list of all tokens, non-unique
:type words: list
:return: key/value pairs of words and their counts in the list
:rtype: dict
"""
return dict(Counter(words))
[docs]
def flush(self) -> None:
    """Discard every trained category, token tally, and cached probability."""
    with self._lock:
        self.probabilities = {}
        self.categories = BayesCategories()
[docs]
def calculate_category_probability(self) -> None:
"""
Caches the individual probabilities for each category
"""
with self._lock:
total_tally = 0.0
probs = {}
for category, bayes_category in \
self.categories.get_categories().items():
count = bayes_category.get_tally()
total_tally += count
probs[category] = count
# Calculating the probability
for category, count in probs.items():
if total_tally > 0:
probs[category] = float(count)/float(total_tally)
else:
probs[category] = 0.0
new_probabilities = {}
for category, probability in probs.items():
new_probabilities[category] = {
# Probability that any given token is of this category
'prc': probability,
# Probability that any given token is not of this category
'prnc': 1.0 - probability
}
self.probabilities = new_probabilities
[docs]
def train(self, category: str, text: str) -> None:
    """Add a sample of text to a category's token tallies.

    :param category: name of the category to train (validated/normalized)
    :param text: the sample text to learn from (coerced via str())
    """
    category = self.normalize_category(category)
    with self._lock:
        try:
            target = self.categories.get_category(category)
        except KeyError:
            # First sample for this category: create it on the fly.
            target = self.categories.add_category(category)
        token_counts = self.count_token_occurrences(self.tokenizer(str(text)))
        for token, count in token_counts.items():
            target.train_token(token, count)
        # Tallies changed, so refresh the cached per-category priors.
        self.calculate_category_probability()
[docs]
def untrain(self, category: str, text: str) -> None:
    """Remove a previously trained sample of text from a category.

    Does nothing when the category does not exist; deletes the category
    entirely once its token tally drops to zero.

    :param category: name of the category to untrain (validated/normalized)
    :param text: the sample text to unlearn (coerced via str())
    """
    category = self.normalize_category(category)
    with self._lock:
        try:
            target = self.categories.get_category(category)
        except KeyError:
            return
        token_counts = self.count_token_occurrences(self.tokenizer(str(text)))
        for token, count in token_counts.items():
            target.untrain_token(token, count)
        if target.get_tally() == 0:
            # Nothing left in this category; drop it entirely.
            self.categories.delete_category(category)
        # Tallies changed, so refresh the cached per-category priors.
        self.calculate_category_probability()
[docs]
def classify(self, text: str) -> Optional[str]:
"""
Chooses the highest scoring category for a sample of text
:param text: sample text to classify
:type text: str
:return: the "winning" category
:rtype: str
"""
with self._lock:
score = self.score(text)
highest_category, _ = self._find_highest_category(score)
return highest_category
[docs]
def classify_result(self, text: str) -> ClassificationResult:
    """Classify *text* and return structured output including the score.

    :param text: sample text to classify
    :return: ClassificationResult with the winning category (or None) and score
    """
    with self._lock:
        winner, winning_score = self._find_highest_category(self.score(text))
        return ClassificationResult(category=winner or None, score=winning_score)
@classmethod
def _find_highest_category(cls, scores: Dict[str, float]) -> tuple[Optional[str], float]:
if not scores:
return None, 0.0
highest_category = None
highest_score = 0.0
for category in sorted(scores.keys()):
category_score = float(scores[category])
if category_score > highest_score:
highest_score = category_score
highest_category = category
return highest_category, highest_score
[docs]
def score(self, text: str) -> Dict[str, float]:
"""
Scores a sample of text
:param text: sample text to score
:type text: str
:return: dict of scores per category
:rtype: dict
"""
with self._lock:
occurs = self.count_token_occurrences(self.tokenizer(text))
scores = {}
for category in self.categories.get_categories():
scores[category] = 0
categories = self.categories.get_categories().items()
for word, count in occurs.items():
token_scores = {}
# Adding up individual token scores
for category, bayes_category in categories:
token_scores[category] = \
float(bayes_category.get_token_count(word))
# We use this to get token-in-category probabilities
token_tally = sum(token_scores.values())
# If this token isn't found anywhere its probability is 0
if token_tally == 0.0:
continue
# Calculating bayes probability for this token
# http://en.wikipedia.org/wiki/Naive_Bayes_spam_filtering
for category, token_score in token_scores.items():
# Bayes probability * the number of occurrences of this token
scores[category] += count * \
self.calculate_bayesian_probability(
category,
token_score,
token_tally
)
# Removing empty categories from the results
final_scores = {}
for category, score in scores.items():
if score > 0:
final_scores[category] = score
return final_scores
[docs]
def calculate_bayesian_probability(
self, cat: str, token_score: float, token_tally: float
) -> float:
"""
Calculates the bayesian probability for a given token/category
:param cat: The category we're scoring for this token
:type cat: str
:param token_score: The tally of this token for this category
:type token_score: float
:param token_tally: The tally total for this token from all categories
:type token_tally: float
:return: bayesian probability
:rtype: float
"""
# P that any given token IS in this category
prc = self.probabilities[cat]['prc']
# P that any given token is NOT in this category
prnc = self.probabilities[cat]['prnc']
# Laplace smoothing: add alpha to avoid zero probabilities
# (token_in_cat, token_not_in_cat) -> k=2 for binary view per token
if self.alpha > 0:
prtc = (token_score + self.alpha) / (token_tally + 2.0 * self.alpha)
prtnc = (token_tally - token_score + self.alpha) / (
token_tally + 2.0 * self.alpha
)
else:
prtnc = (token_tally - token_score) / token_tally
prtc = token_score / token_tally
# Assembling the parts of the bayes equation
numerator = prtc * prc
denominator = numerator + (prtnc * prnc)
# Returning the calculated bayes probability unless the denom. is 0
return numerator / denominator if denominator != 0.0 else 0.0
[docs]
def tally(self, category: str) -> int:
"""
Gets the tally for a requested category
:param category: The category we want a tally for
:type category: str
:return: tally for a given category
:rtype: int
"""
with self._lock:
try:
bayes_category = self.categories.get_category(category)
except KeyError:
return 0
return bayes_category.get_tally()
[docs]
def get_summaries(self) -> Dict[str, CategorySummary]:
    """Build a per-category summary of tally and cached prior probabilities.

    Categories missing from the probability cache report 0.0 priors.
    """
    with self._lock:
        summaries: Dict[str, CategorySummary] = {}
        for name, category in self.categories.get_categories().items():
            priors = self.probabilities.get(name, {'prc': 0.0, 'prnc': 0.0})
            summaries[name] = CategorySummary(
                token_tally=category.get_tally(),
                prob_in_cat=float(priors['prc']),
                prob_not_in_cat=float(priors['prnc']),
            )
        return summaries
[docs]
def save(self, destination) -> None:
    """Serialize the classifier's current state onto a writable text stream."""
    with self._lock:
        state = self._export_model_state()
        dump_model_state(destination, state)
[docs]
def load(self, source) -> None:
    """Replace the classifier's state with one read from a text stream.

    The loaded state is validated before being applied.
    """
    with self._lock:
        loaded = load_model_state(source)
        validate_model_state(loaded)
        self._apply_model_state(loaded)
[docs]
def save_to_file(self, absolute_path: str = "") -> None:
    """Persist the classifier's state to a file via atomic replacement."""
    with self._lock:
        state = self._export_model_state()
        save_model_state_to_file(absolute_path, state)
[docs]
def load_from_file(self, absolute_path: str = "") -> None:
    """Replace the classifier's state with one loaded from a model file.

    The loaded state is validated before being applied.
    """
    with self._lock:
        loaded = load_model_state_from_file(absolute_path)
        validate_model_state(loaded)
        self._apply_model_state(loaded)
[docs]
@classmethod
def normalize_category(cls, category: str | None) -> str:
    """Validate and normalize a category name.

    :param category: raw category input; leading/trailing whitespace is stripped
    :return: the normalized category name
    :raises InvalidCategoryError: when *category* is None or fails validation
    """
    if category is None:
        raise InvalidCategoryError("category is required")
    cleaned = str(category).strip()
    if not CATEGORY_PATTERN.match(cleaned):
        raise InvalidCategoryError(
            "category must be 1-64 chars and only include letters, numbers, underscore, or hyphen",
        )
    return cleaned
def _export_model_state(self) -> Dict:
    """Serialize all categories into a plain-dict persisted model state.

    Only tokens with positive counts are exported; counts and tallies are
    coerced to int.
    """
    exported = {}
    for name, category in self.categories.get_categories().items():
        tokens = {}
        for token, count in category.tokens.items():
            if count > 0:
                tokens[token] = int(count)
        exported[name] = {
            "tally": int(category.get_tally()),
            "tokens": tokens,
        }
    return {
        "version": PERSISTED_MODEL_VERSION,
        "categories": exported,
    }
def _apply_model_state(self, state: Dict) -> None:
    """Replace all current categories with those described by *state*."""
    self.categories = BayesCategories()
    for name, category_state in state.get("categories", {}).items():
        target = self.categories.add_category(name)
        for token, count in category_state["tokens"].items():
            target.train_token(token, count)
    # Rebuild the cached priors for the freshly loaded tallies.
    self.calculate_category_probability()