SDC-multi-classifier / embedding_cache.py
DocUA's picture
пофіксив завантаження JSON
9c959a8
raw
history blame
4.71 kB
import sqlite3
import numpy as np
from typing import Optional, List
from pathlib import Path
import time
class EmbeddingCache:
def __init__(self, db_path: str = "embeddings_cache.db"):
"""
Ініціалізація кешу ембедінгів
Args:
db_path: шлях до файлу SQLite бази даних
"""
self.db_path = db_path
self._init_db()
self.hits = 0
self.misses = 0
def _init_db(self):
"""Ініціалізація структури бази даних"""
with sqlite3.connect(self.db_path) as conn:
conn.execute("""
CREATE TABLE IF NOT EXISTS embeddings (
text_hash TEXT PRIMARY KEY,
text TEXT NOT NULL,
model TEXT NOT NULL,
embedding BLOB NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# Індекс для швидкого пошуку за хешем
conn.execute("""
CREATE INDEX IF NOT EXISTS idx_text_hash
ON embeddings(text_hash)
""")
def _get_hash(self, text: str, model: str) -> str:
"""Створення унікального хешу для тексту та моделі"""
return str(hash(f"{text}:{model}"))
def get(self, text: str, model: str) -> Optional[np.ndarray]:
"""
Отримання ембедінгу з кешу
Args:
text: текст для пошуку
model: назва моделі ембедінгів
Returns:
np.ndarray якщо знайдено, None якщо не знайдено
"""
text_hash = self._get_hash(text, model)
with sqlite3.connect(self.db_path) as conn:
result = conn.execute(
"SELECT embedding FROM embeddings WHERE text_hash = ?",
(text_hash,)
).fetchone()
if result:
self.hits += 1
return np.frombuffer(result[0], dtype=np.float32)
self.misses += 1
return None
def put(self, text: str, model: str, embedding: np.ndarray) -> None:
"""
Збереження ембедінгу в кеш
Args:
text: вхідний текст
model: назва моделі
embedding: ембедінг для збереження
"""
text_hash = self._get_hash(text, model)
with sqlite3.connect(self.db_path) as conn:
conn.execute(
"""
INSERT OR REPLACE INTO embeddings
(text_hash, text, model, embedding)
VALUES (?, ?, ?, ?)
""",
(
text_hash,
text,
model,
np.array(embedding, dtype=np.float32).tobytes()
)
)
def clear_old(self, days: int = 30) -> int:
"""
Очищення старих записів з кешу
Args:
days: кількість днів, старіші записи будуть видалені
Returns:
Кількість видалених записів
"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.execute(
"""
DELETE FROM embeddings
WHERE created_at < datetime('now', ?)
""",
(f"-{days} days",)
)
return cursor.rowcount
def get_stats(self) -> dict:
"""Отримання статистики використання кешу"""
with sqlite3.connect(self.db_path) as conn:
total = conn.execute(
"SELECT COUNT(*) FROM embeddings"
).fetchone()[0]
size = Path(self.db_path).stat().st_size / (1024 * 1024) # Size in MB
if self.hits + self.misses > 0:
hit_rate = self.hits / (self.hits + self.misses) * 100
else:
hit_rate = 0
return {
"total_entries": total,
"cache_size_mb": round(size, 2),
"hits": self.hits,
"misses": self.misses,
"hit_rate_percent": round(hit_rate, 2)
}