DocUA commited on
Commit
aaec566
·
1 Parent(s): cd3968b

добавлення функціонала для підключення моделей для локального ембедінга

Browse files
Files changed (4) hide show
  1. app.py +132 -41
  2. local_embedder.py +113 -0
  3. requirements.txt +3 -1
  4. sdc_classifier.py +163 -81
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  from sdc_classifier import SDCClassifier
3
  from dotenv import load_dotenv
 
4
  import json
5
  import os
6
 
@@ -35,6 +36,24 @@ def initialize_environment():
35
 
36
  return True
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def main():
39
  # Константи файлів
40
  DEFAULT_CLASSES_FILE = "classes.json"
@@ -58,7 +77,7 @@ def main():
58
  try:
59
  classifier.load_initial_state(DEFAULT_CLASSES_FILE, DEFAULT_SIGNATURES_FILE)
60
  result = classifier.initialize_signatures(
61
- force_rebuild=True, # Примусово будуємо нові signatures
62
  signatures_file=DEFAULT_SIGNATURES_FILE
63
  )
64
  print(f"Результат ініціалізації: {result}")
@@ -101,11 +120,33 @@ def main():
101
  # Налаштування моделі
102
  with gr.Accordion("Налаштування моделі", open=False):
103
  with gr.Row():
 
 
 
 
 
104
  model_choice = gr.Dropdown(
105
- choices=["text-embedding-3-large","text-embedding-3-small"],
 
 
 
106
  value="text-embedding-3-large",
107
- label="OpenAI model"
 
 
 
 
 
 
 
 
 
 
 
 
108
  )
 
 
109
  json_file = gr.File(
110
  label="Завантажити новий JSON з класами",
111
  file_types=[".json"]
@@ -114,6 +155,7 @@ def main():
114
  label="Примусово перебудувати signatures",
115
  value=False
116
  )
 
117
  with gr.Row():
118
  build_btn = gr.Button("Оновити signatures")
119
  build_out = gr.Label(label="Статус signatures")
@@ -156,80 +198,129 @@ def main():
156
  gr.Markdown("""
157
  ### Інструкція:
158
  1. У вкладці "Налаштування моделі" можна:
 
 
159
  - Завантажити новий JSON файл з класами
160
- - Вибрати модель для embeddings
161
  - Примусово перебудувати signatures
162
- 2. Після зміни класів натисніть "Оновити signatures"
163
  3. Використовуйте повзунок "Поріг впевненості" для фільтрації результатів
164
  4. На вкладці "Пакетна обробка" можна аналізувати багато повідомлень
165
  5. Результати можна зберегти в CSV файл
166
  """)
167
 
168
  # Підключення обробників подій
169
- def update_with_file(file, model_name, force):
170
- if file is None:
171
- # Відновлюємо базовий стан якщо файл видалено
172
- classifier.restore_base_state()
173
- return ("Відновлено базовий набір класів", classifier.get_cache_stats())
174
-
 
 
 
 
 
175
  try:
176
- # Для роботи з gradio File компонентом
177
- if hasattr(file, 'name'): # Якщо це файловий об'єкт
178
- with open(file.name, 'r', encoding='utf-8') as f:
179
- new_classes = json.load(f)
180
- else: # Якщо це строка
181
- new_classes = json.loads(file)
182
-
183
- if not isinstance(new_classes, dict):
184
- return ("Помилка: JSON повинен містити словник класів", classifier.get_cache_stats())
185
 
186
- # Завантажуємо нові класи без перезапису файлу за замовчуванням
187
- classifier.load_classes(new_classes)
 
 
 
 
 
188
 
189
- # Створюємо тимчасові signatures
190
  result = classifier.initialize_signatures(
191
- model_name=model_name,
192
- signatures_file=None, # Не зберігаємо у файл
193
- force_rebuild=True # Завжди перебудовуємо для нових класів
194
  )
195
 
196
- return (f"Тимчасові класи завантажено. {result}", classifier.get_cache_stats())
197
-
198
- except json.JSONDecodeError:
199
- return ("Помилка: Неправильний формат JSON файлу", classifier.get_cache_stats())
200
  except Exception as e:
201
- return (f"Помилка при оновленні: {str(e)}", classifier.get_cache_stats())
202
-
203
- single_process_btn.click(
204
- fn=lambda text, threshold: classifier.process_single_text(text, threshold),
205
- inputs=[text_input, threshold_slider],
206
- outputs=result_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  )
208
 
209
  build_btn.click(
210
- fn=update_with_file,
211
- inputs=[json_file, model_choice, force_rebuild],
 
 
 
 
 
 
 
212
  outputs=[build_out, cache_stats]
213
  )
214
 
 
 
 
 
 
 
215
  load_btn.click(
216
- fn=lambda csv, emb: classifier.load_data(csv, emb),
217
  inputs=[csv_input, emb_input],
218
  outputs=load_output
219
  )
220
 
221
  classify_btn.click(
222
- fn=lambda filter_str, threshold: classifier.classify_rows(filter_str, threshold),
223
  inputs=[filter_in, batch_threshold],
224
  outputs=classify_out
225
  )
226
 
227
  save_btn.click(
228
- fn=lambda: classifier.save_results("messages_with_labels.csv"),
229
  inputs=[],
230
  outputs=save_out
231
  )
232
 
 
233
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
234
 
235
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  from sdc_classifier import SDCClassifier
3
  from dotenv import load_dotenv
4
+ import torch
5
  import json
6
  import os
7
 
 
36
 
37
  return True
38
 
39
+ def create_classifier(model_type, openai_model=None, local_model=None, device=None):
40
+ """
41
+ Створення класифікатора з відповідними параметрами
42
+
43
+ Args:
44
+ model_type: тип моделі ("OpenAI" або "Local")
45
+ openai_model: назва моделі OpenAI
46
+ local_model: шлях до локальної моделі
47
+ device: пристрій для локальної моделі
48
+
49
+ Returns:
50
+ SDCClassifier: налаштований класифікатор
51
+ """
52
+ if model_type == "OpenAI":
53
+ return SDCClassifier()
54
+ else:
55
+ return SDCClassifier(local_model=local_model, device=device)
56
+
57
  def main():
58
  # Константи файлів
59
  DEFAULT_CLASSES_FILE = "classes.json"
 
77
  try:
78
  classifier.load_initial_state(DEFAULT_CLASSES_FILE, DEFAULT_SIGNATURES_FILE)
79
  result = classifier.initialize_signatures(
80
+ force_rebuild=True,
81
  signatures_file=DEFAULT_SIGNATURES_FILE
82
  )
83
  print(f"Результат ініціалізації: {result}")
 
120
  # Налаштування моделі
121
  with gr.Accordion("Налаштування моделі", open=False):
122
  with gr.Row():
123
+ model_type = gr.Radio(
124
+ choices=["OpenAI", "Local"],
125
+ value="OpenAI",
126
+ label="Тип моделі"
127
+ )
128
  model_choice = gr.Dropdown(
129
+ choices=[
130
+ "text-embedding-3-large",
131
+ "text-embedding-3-small"
132
+ ],
133
  value="text-embedding-3-large",
134
+ label="OpenAI model",
135
+ visible=True
136
+ )
137
+ local_model_path = gr.Textbox(
138
+ value="cambridgeltl/SapBERT-from-PubMedBERT-fulltext",
139
+ label="Шлях до локальної моделі",
140
+ visible=False
141
+ )
142
+ device_choice = gr.Radio(
143
+ choices=["cuda", "cpu"],
144
+ value="cuda" if torch.cuda.is_available() else "cpu",
145
+ label="Пристрій для локальної моделі",
146
+ visible=False
147
  )
148
+
149
+ with gr.Row():
150
  json_file = gr.File(
151
  label="Завантажити новий JSON з класами",
152
  file_types=[".json"]
 
155
  label="Примусово перебудувати signatures",
156
  value=False
157
  )
158
+
159
  with gr.Row():
160
  build_btn = gr.Button("Оновити signatures")
161
  build_out = gr.Label(label="Статус signatures")
 
198
  gr.Markdown("""
199
  ### Інструкція:
200
  1. У вкладці "Налаштування моделі" можна:
201
+ - Вибрати тип моделі (OpenAI або Local)
202
+ - Налаштувати параметри вибраної моделі
203
  - Завантажити новий JSON файл з класами
 
204
  - Примусово перебудувати signatures
205
+ 2. Після зміни налаштувань натисніть "Оновити signatures"
206
  3. Використовуйте повзунок "Поріг впевненості" для фільтрації результатів
207
  4. На вкладці "Пакетна обробка" можна аналізувати багато повідомлень
208
  5. Результати можна зберегти в CSV файл
209
  """)
210
 
211
  # Підключення обробників подій
212
+ def update_model_inputs(model_type):
213
+ """Оновлення видимості полів в залежності від типу моделі"""
214
+ return {
215
+ model_choice: gr.update(visible=model_type == "OpenAI"),
216
+ local_model_path: gr.update(visible=model_type == "Local"),
217
+ device_choice: gr.update(visible=model_type == "Local")
218
+ }
219
+
220
+ def update_classifier_settings(json_file, model_type, openai_model,
221
+ local_model, device, force_rebuild):
222
+ """Оновлення налаштувань класифікатора"""
223
  try:
224
+ # Створюємо новий класифікатор з вибраними параметрами
225
+ nonlocal classifier
226
+ classifier = create_classifier(
227
+ model_type=model_type,
228
+ openai_model=openai_model if model_type == "OpenAI" else None,
229
+ local_model=local_model if model_type == "Local" else None,
230
+ device=device if model_type == "Local" else None
231
+ )
 
232
 
233
+ # Завантажуємо класи
234
+ if json_file is not None:
235
+ with open(json_file.name, 'r', encoding='utf-8') as f:
236
+ new_classes = json.load(f)
237
+ classifier.load_classes(new_classes)
238
+ else:
239
+ classifier.restore_base_state()
240
 
241
+ # Ініціалізуємо signatures
242
  result = classifier.initialize_signatures(
243
+ force_rebuild=force_rebuild,
244
+ signatures_file=DEFAULT_SIGNATURES_FILE if not force_rebuild else None
 
245
  )
246
 
247
+ return result, classifier.get_cache_stats()
 
 
 
248
  except Exception as e:
249
+ return f"Помилка: {str(e)}", classifier.get_cache_stats()
250
+
251
+ def process_single_text(text, threshold):
252
+ """Обробка одного тексту"""
253
+ try:
254
+ return classifier.process_single_text(text, threshold)
255
+ except Exception as e:
256
+ return {"error": str(e)}
257
+
258
+ def load_data(csv_path, emb_path):
259
+ """Завантаження даних для пакетної обробки"""
260
+ try:
261
+ return classifier.load_data(csv_path, emb_path)
262
+ except Exception as e:
263
+ return f"Помилка: {str(e)}"
264
+
265
+ def classify_batch(filter_str, threshold):
266
+ """Пакетна класифікація"""
267
+ try:
268
+ return classifier.classify_rows(filter_str, threshold)
269
+ except Exception as e:
270
+ return None
271
+
272
+ def save_results():
273
+ """Збереження результатів"""
274
+ try:
275
+ return classifier.save_results()
276
+ except Exception as e:
277
+ return f"Помилка: {str(e)}"
278
+
279
+ # Підключення подій
280
+ model_type.change(
281
+ fn=update_model_inputs,
282
+ inputs=[model_type],
283
+ outputs=[model_choice, local_model_path, device_choice]
284
  )
285
 
286
  build_btn.click(
287
+ fn=update_classifier_settings,
288
+ inputs=[
289
+ json_file,
290
+ model_type,
291
+ model_choice,
292
+ local_model_path,
293
+ device_choice,
294
+ force_rebuild
295
+ ],
296
  outputs=[build_out, cache_stats]
297
  )
298
 
299
+ single_process_btn.click(
300
+ fn=process_single_text,
301
+ inputs=[text_input, threshold_slider],
302
+ outputs=result_text
303
+ )
304
+
305
  load_btn.click(
306
+ fn=load_data,
307
  inputs=[csv_input, emb_input],
308
  outputs=load_output
309
  )
310
 
311
  classify_btn.click(
312
+ fn=classify_batch,
313
  inputs=[filter_in, batch_threshold],
314
  outputs=classify_out
315
  )
316
 
317
  save_btn.click(
318
+ fn=save_results,
319
  inputs=[],
320
  outputs=save_out
321
  )
322
 
323
+ # Запуск веб-інтерфейсу
324
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
325
 
326
  if __name__ == "__main__":
local_embedder.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from typing import List, Union, Dict
4
+ from transformers import AutoTokenizer, AutoModel
5
+ from pathlib import Path
6
+ import json
7
+
8
+ class LocalEmbedder:
9
+ def __init__(self, model_name: str, device: str = None, batch_size: int = 32):
10
+ """
11
+ Ініціалізація локальної моделі для ембедінгів
12
+
13
+ Args:
14
+ model_name: назва або шлях до моделі (з HuggingFace або локальна)
15
+ device: пристрій для обчислень ('cuda', 'cpu' або None - автовибір)
16
+ batch_size: розмір батчу для інференсу
17
+ """
18
+ self.model_name = model_name
19
+ self.batch_size = batch_size
20
+
21
+ # Визначення пристрою
22
+ if device is None:
23
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
24
+ else:
25
+ self.device = device
26
+
27
+ # Завантаження моделі та токенізатора
28
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
29
+ self.model = AutoModel.from_pretrained(model_name).to(self.device)
30
+ self.model.eval()
31
+
32
+ # Максимальна довжина послідовності
33
+ self.max_length = self.tokenizer.model_max_length
34
+ if self.max_length > 512:
35
+ self.max_length = 512
36
+
37
+ def _normalize_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
38
+ """
39
+ L2-нормалізація ембедінгів
40
+
41
+ Args:
42
+ embeddings: матриця ембедінгів
43
+
44
+ Returns:
45
+ np.ndarray: нормалізована матриця ембедінгів
46
+ """
47
+ norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
48
+ return embeddings / norms
49
+
50
+ def get_embeddings(self, texts: Union[str, List[str]]) -> np.ndarray:
51
+ """
52
+ Отримання ембедінгів для тексту або списку текстів
53
+
54
+ Args:
55
+ texts: текст або список текстів
56
+
57
+ Returns:
58
+ np.ndarray: матриця нормалізованих ембедінгів
59
+ """
60
+ if isinstance(texts, str):
61
+ texts = [texts]
62
+
63
+ all_embeddings = []
64
+
65
+ with torch.no_grad():
66
+ for i in range(0, len(texts), self.batch_size):
67
+ batch_texts = texts[i:i + self.batch_size]
68
+
69
+ # Токенізація
70
+ encoded = self.tokenizer.batch_encode_plus(
71
+ batch_texts,
72
+ padding=True,
73
+ truncation=True,
74
+ max_length=self.max_length,
75
+ return_tensors='pt'
76
+ )
77
+
78
+ # Переміщуємо тензори на потрібний пристрій
79
+ input_ids = encoded['input_ids'].to(self.device)
80
+ attention_mask = encoded['attention_mask'].to(self.device)
81
+
82
+ # Отримуємо ембедінги
83
+ outputs = self.model(
84
+ input_ids=input_ids,
85
+ attention_mask=attention_mask
86
+ )
87
+
88
+ # Використовуємо [CLS] токен як ембедінг
89
+ embeddings = outputs.last_hidden_state[:, 0, :]
90
+ all_embeddings.append(embeddings.cpu().numpy())
91
+
92
+ # Об'єднуємо всі батчі
93
+ embeddings = np.vstack(all_embeddings)
94
+
95
+ # Нормалізуємо ембедінги
96
+ normalized_embeddings = self._normalize_embeddings(embeddings)
97
+
98
+ return normalized_embeddings
99
+
100
+ def get_model_info(self) -> Dict[str, any]:
101
+ """
102
+ Отримання інформації про модель
103
+
104
+ Returns:
105
+ Dict: інформація про модель
106
+ """
107
+ return {
108
+ 'model_name': self.model_name,
109
+ 'device': self.device,
110
+ 'embedding_size': self.model.config.hidden_size,
111
+ 'max_length': self.max_length,
112
+ 'batch_size': self.batch_size
113
+ }
requirements.txt CHANGED
@@ -3,4 +3,6 @@ openai
3
  pandas
4
  numpy
5
  python-dotenv
6
- scikit-learn
 
 
 
3
  pandas
4
  numpy
5
  python-dotenv
6
+ scikit-learn
7
+ torch
8
+ transformers
sdc_classifier.py CHANGED
@@ -2,21 +2,35 @@ import os
2
  import numpy as np
3
  import pandas as pd
4
  import json
5
- from typing import Dict, List
6
  from openai import OpenAI
7
  from pathlib import Path
8
  from embedding_cache import EmbeddingCache
9
 
10
  class SDCClassifier:
11
- def __init__(self, openai_api_key: str = None, cache_path: str = "embeddings_cache.db"):
 
 
 
 
12
  """
13
  Ініціалізація класифікатора SDC
14
 
15
  Args:
16
  openai_api_key: API ключ для OpenAI (опціонально, можна взяти з env)
17
  cache_path: шлях до файлу кешу ембедінгів
 
 
18
  """
19
  self.client = OpenAI(api_key=openai_api_key or os.getenv("OPENAI_API_KEY"))
 
 
 
 
 
 
 
 
20
  self.classes_json = {}
21
  self.class_signatures = None
22
  self.df = None
@@ -24,13 +38,12 @@ class SDCClassifier:
24
  self.embeddings_mean = None
25
  self.embeddings_std = None
26
 
27
- # Створення директорії для кешу, якщо потрібно
28
  cache_dir = os.path.dirname(cache_path)
29
  if cache_dir and not os.path.exists(cache_dir):
30
  os.makedirs(cache_dir)
31
 
32
  # Ініціалізація кешу
33
- from embedding_cache import EmbeddingCache
34
  self.cache = EmbeddingCache(cache_path)
35
 
36
  # Базовий стан
@@ -66,60 +79,23 @@ class SDCClassifier:
66
  self.classes_json = self.base_classes_json.copy()
67
  self.class_signatures = self.base_signatures.copy() if self.base_signatures else None
68
 
69
- def load_initial_state(self, classes_file: str, signatures_file: str) -> str:
70
  """
71
- Завантаження початкового стану при старті застосунку
72
 
73
  Args:
74
- classes_file: шлях до файлу з класами
75
- signatures_file: шлях до файлу з signatures
76
 
77
  Returns:
78
- str: повідомлення про результат завантаження
79
  """
80
  try:
81
- self.base_classes_json = self.load_classes(classes_file)
82
- if os.path.exists(signatures_file):
83
- self.base_signatures = self.load_signatures(signatures_file)
84
-
85
- # Встановлюємо поточний стан як базовий
86
- self.classes_json = self.base_classes_json.copy()
87
- self.class_signatures = self.base_signatures.copy() if self.base_signatures else None
88
-
89
- return f"Завантажено {len(self.base_classes_json)} базових класів"
90
- except Exception as e:
91
- return f"Помилка при завантаженні базового стану: {str(e)}"
92
-
93
- def restore_base_state(self) -> None:
94
- """Відновлення базового стану"""
95
- self.classes_json = self.base_classes_json.copy()
96
- self.class_signatures = self.base_signatures.copy() if self.base_signatures else None
97
-
98
- def load_initial_state(self, classes_file: str, signatures_file: str):
99
- """Завантаження початкового стану при старті застосунку"""
100
- self.base_classes_json = self.load_classes(classes_file)
101
- self.base_signatures = self.load_signatures(signatures_file)
102
-
103
- # Встановлюємо поточний стан як базовий
104
- self.classes_json = self.base_classes_json.copy()
105
- self.class_signatures = self.base_signatures.copy() if self.base_signatures else None
106
-
107
- def restore_base_state(self):
108
- """Відновлення базового стану"""
109
- self.classes_json = self.base_classes_json.copy()
110
- self.class_signatures = self.base_signatures.copy() if self.base_signatures else None
111
-
112
- def load_classes(self, json_path: str) -> dict:
113
- """Завантаження класів та їх хінтів з JSON файлу"""
114
- try:
115
- # Якщо передано вміст файлу замість шляху
116
  if isinstance(json_path, dict):
117
  self.classes_json = json_path
118
  else:
119
  with open(json_path, 'r', encoding='utf-8') as f:
120
  self.classes_json = json.load(f)
121
 
122
- # Валідація структури
123
  if not all(isinstance(hints, list) for hints in self.classes_json.values()):
124
  raise ValueError("Кожен клас повинен мати список хінтів")
125
 
@@ -132,12 +108,25 @@ class SDCClassifier:
132
  return {}
133
 
134
  def save_signatures(self, filename: str = "signatures.npz") -> None:
135
- """Зберігає signatures у NPZ файл"""
 
 
 
 
 
136
  if self.class_signatures:
137
  np.savez(filename, **self.class_signatures)
138
 
139
  def load_signatures(self, filename: str = "signatures.npz") -> Dict[str, np.ndarray]:
140
- """Завантажує signatures з NPZ файлу"""
 
 
 
 
 
 
 
 
141
  try:
142
  with np.load(filename) as data:
143
  self.class_signatures = {key: data[key] for key in data.files}
@@ -145,31 +134,34 @@ class SDCClassifier:
145
  except (FileNotFoundError, IOError):
146
  return None
147
 
148
- def get_openai_embedding(self, text: str, model_name: str = "text-embedding-3-large") -> list:
149
  """
150
- Отримання ембедінгу тексту через OpenAI API з використанням кешу
151
 
152
  Args:
153
  text: текст для ембедінгу
154
- model_name: назва моделі OpenAI
155
 
156
  Returns:
157
  list: ембедінг тексту
158
  """
159
- # Спроба отримати з кешу
160
- cached_embedding = self.cache.get(text, model_name)
161
  if cached_embedding is not None:
162
  return cached_embedding.tolist()
163
-
164
- # Якщо нема в кеші - отримуємо через API
165
- response = self.client.embeddings.create(
166
- input=text,
167
- model=model_name
168
- )
169
- embedding = response.data[0].embedding
 
 
 
170
 
171
  # Зберігаємо в кеш
172
- self.cache.put(text, model_name, embedding)
173
 
174
  return embedding
175
 
@@ -181,15 +173,24 @@ class SDCClassifier:
181
  """Очищення старих записів з кешу"""
182
  return self.cache.clear_old(days)
183
 
184
- def embed_hints(self, hint_list: List[str], model_name: str) -> np.ndarray:
185
- """Створення ембедінгів для списку хінтів"""
 
 
 
 
 
 
 
 
 
186
  emb_list = []
187
  total_hints = len(hint_list)
188
 
189
  for idx, hint in enumerate(hint_list, 1):
190
  try:
191
  print(f" Отримання embedding {idx}/{total_hints}: '{hint}'")
192
- emb = self.get_openai_embedding(hint, model_name=model_name)
193
  emb_list.append(emb)
194
  except Exception as e:
195
  print(f" Помилка при отриманні embedding для '{hint}': {str(e)}")
@@ -200,10 +201,10 @@ class SDCClassifier:
200
 
201
  return np.array(emb_list, dtype=np.float32)
202
 
203
-
204
- def initialize_signatures(self, model_name: str = "text-embedding-3-large",
205
- signatures_file: str = "signatures.npz",
206
- force_rebuild: bool = False) -> str:
207
  """
208
  Ініціалізує signatures: завантажує існуючі або створює нові
209
 
@@ -211,13 +212,16 @@ class SDCClassifier:
211
  model_name: назва моделі для ембедінгів
212
  signatures_file: шлях до файлу для збереження (None - не зберігати)
213
  force_rebuild: примусово перебудувати signatures
 
 
 
214
  """
215
  if not self.classes_json:
216
  return "Помилка: Не знайдено жодного класу в classes.json"
217
 
218
  print(f"Знайдено {len(self.classes_json)} класів")
219
 
220
- # Завантажуємо існуючі signatures, якщо є файл і не примусове оновлення
221
  if not force_rebuild and signatures_file and os.path.exists(signatures_file):
222
  try:
223
  loaded_signatures = self.load_signatures(signatures_file)
@@ -250,7 +254,7 @@ class SDCClassifier:
250
  if not self.class_signatures:
251
  return "Помилка: Не вдалося створити жодного signature"
252
 
253
- # Зберігаємо signatures тільки якщо вказано шлях до файлу
254
  if signatures_file:
255
  try:
256
  self.save_signatures(signatures_file)
@@ -262,8 +266,17 @@ class SDCClassifier:
262
  except Exception as e:
263
  return f"Помилка при створенні signatures: {str(e)}"
264
 
265
- def load_data(self, csv_path: str = "messages.csv", emb_path: str = "embeddings.npy"):
266
- """Завантаження даних з CSV та NPY файлів"""
 
 
 
 
 
 
 
 
 
267
  self.df = pd.read_csv(csv_path)
268
  emb_local = np.load(emb_path)
269
  assert len(self.df) == len(emb_local), "CSV і embeddings різної довжини!"
@@ -277,7 +290,16 @@ class SDCClassifier:
277
  return f"Завантажено {len(self.df)} рядків"
278
 
279
  def predict_classes(self, text_embedding: np.ndarray, threshold: float = 0.0) -> Dict[str, float]:
280
- """Передбачення класів для одного тексту"""
 
 
 
 
 
 
 
 
 
281
  results = {}
282
  for cls, sign in self.class_signatures.items():
283
  score = float(np.dot(text_embedding, sign))
@@ -287,11 +309,20 @@ class SDCClassifier:
287
  return dict(sorted(results.items(), key=lambda x: x[1], reverse=True))
288
 
289
  def process_single_text(self, text: str, threshold: float = 0.3) -> dict:
290
- """Обробка одного тексту"""
 
 
 
 
 
 
 
 
 
291
  if self.class_signatures is None:
292
  return {"error": "Спочатку збудуйте signatures!"}
293
 
294
- emb = self.get_openai_embedding(text)
295
 
296
  if self.embeddings_mean is not None and self.embeddings_std is not None:
297
  emb = (emb - self.embeddings_mean) / self.embeddings_std
@@ -310,17 +341,26 @@ class SDCClassifier:
310
  "result": "\n".join(formatted_results)
311
  }
312
 
313
- def classify_rows(self, filter_substring: str = "", threshold: float = 0.3):
314
- """Класифікація всіх або відфільтрованих рядків"""
 
 
 
 
 
 
 
 
 
315
  if self.class_signatures is None:
316
- return "Спочатку збудуйте signatures!"
317
 
318
  if self.df is None or self.embeddings is None:
319
- return "Дані не завантажені! Спочатку викличте load_data."
320
 
321
  if filter_substring:
322
  filtered_idx = self.df[self.df["Message"].str.contains(filter_substring,
323
- case=False,
324
  na=False)].index
325
  else:
326
  filtered_idx = self.df.index
@@ -345,9 +385,51 @@ class SDCClassifier:
345
  return result_df.reset_index(drop=True)
346
 
347
  def save_results(self, output_path: str = "messages_with_labels.csv") -> str:
348
- """Зберігання результатів класифікації"""
 
 
 
 
 
 
 
 
349
  if self.df is None:
350
  return "Дані відсутні!"
351
 
352
  self.df.to_csv(output_path, index=False)
353
- return f"Дані збережено у файл {output_path}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import numpy as np
3
  import pandas as pd
4
  import json
5
+ from typing import Dict, List, Optional, Union
6
  from openai import OpenAI
7
  from pathlib import Path
8
  from embedding_cache import EmbeddingCache
9
 
10
  class SDCClassifier:
11
+ def __init__(self,
12
+ openai_api_key: str = None,
13
+ cache_path: str = "embeddings_cache.db",
14
+ local_model: str = "cambridgeltl/SapBERT-from-PubMedBERT-fulltext",
15
+ device: str = None):
16
  """
17
  Ініціалізація класифікатора SDC
18
 
19
  Args:
20
  openai_api_key: API ключ для OpenAI (опціонально, можна взяти з env)
21
  cache_path: шлях до файлу кешу ембедінгів
22
+ local_model: назва або шлях до локальної моделі
23
+ device: пристрій для локальної моделі ('cuda', 'cpu' або None)
24
  """
25
  self.client = OpenAI(api_key=openai_api_key or os.getenv("OPENAI_API_KEY"))
26
+ self.local_embedder = None
27
+ self.using_local = False
28
+
29
+ if local_model:
30
+ from local_embedder import LocalEmbedder
31
+ self.local_embedder = LocalEmbedder(local_model, device)
32
+ self.using_local = True
33
+
34
  self.classes_json = {}
35
  self.class_signatures = None
36
  self.df = None
 
38
  self.embeddings_mean = None
39
  self.embeddings_std = None
40
 
41
+ # Створення директорії для кешу
42
  cache_dir = os.path.dirname(cache_path)
43
  if cache_dir and not os.path.exists(cache_dir):
44
  os.makedirs(cache_dir)
45
 
46
  # Ініціалізація кешу
 
47
  self.cache = EmbeddingCache(cache_path)
48
 
49
  # Базовий стан
 
79
  self.classes_json = self.base_classes_json.copy()
80
  self.class_signatures = self.base_signatures.copy() if self.base_signatures else None
81
 
82
+ def load_classes(self, json_path: Union[str, dict]) -> dict:
83
  """
84
+ Завантаження класів та їх хінтів з JSON файлу або словника
85
 
86
  Args:
87
+ json_path: шлях до JSON файлу або словник з класами
 
88
 
89
  Returns:
90
+ dict: словник класів та їх хінтів
91
  """
92
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  if isinstance(json_path, dict):
94
  self.classes_json = json_path
95
  else:
96
  with open(json_path, 'r', encoding='utf-8') as f:
97
  self.classes_json = json.load(f)
98
 
 
99
  if not all(isinstance(hints, list) for hints in self.classes_json.values()):
100
  raise ValueError("Кожен клас повинен мати список хінтів")
101
 
 
108
  return {}
109
 
110
  def save_signatures(self, filename: str = "signatures.npz") -> None:
111
+ """
112
+ Зберігає signatures у NPZ файл
113
+
114
+ Args:
115
+ filename: шлях до файлу для збереження
116
+ """
117
  if self.class_signatures:
118
  np.savez(filename, **self.class_signatures)
119
 
120
  def load_signatures(self, filename: str = "signatures.npz") -> Dict[str, np.ndarray]:
121
+ """
122
+ Завантажує signatures з NPZ файлу
123
+
124
+ Args:
125
+ filename: шлях до файлу з signatures
126
+
127
+ Returns:
128
+ Dict[str, np.ndarray]: словник signatures
129
+ """
130
  try:
131
  with np.load(filename) as data:
132
  self.class_signatures = {key: data[key] for key in data.files}
 
134
  except (FileNotFoundError, IOError):
135
  return None
136
 
137
+ def get_embedding(self, text: str, model_name: str = None) -> list:
138
  """
139
+ Отримання ембедінгу тексту
140
 
141
  Args:
142
  text: текст для ембедінгу
143
+ model_name: назва моделі (OpenAI) або None для локальної
144
 
145
  Returns:
146
  list: ембедінг тексту
147
  """
148
+ # Перевіряємо кеш
149
+ cached_embedding = self.cache.get(text, model_name or "local")
150
  if cached_embedding is not None:
151
  return cached_embedding.tolist()
152
+
153
+ # Отримуємо ембедінг
154
+ if self.using_local and model_name is None:
155
+ embedding = self.local_embedder.get_embeddings(text)[0]
156
+ else:
157
+ response = self.client.embeddings.create(
158
+ input=text,
159
+ model=model_name or "text-embedding-3-large"
160
+ )
161
+ embedding = response.data[0].embedding
162
 
163
  # Зберігаємо в кеш
164
+ self.cache.put(text, model_name or "local", embedding)
165
 
166
  return embedding
167
 
 
173
  """Очищення старих записів з кешу"""
174
  return self.cache.clear_old(days)
175
 
176
+ def embed_hints(self, hint_list: List[str], model_name: str = None) -> np.ndarray:
177
+ """
178
+ Створення ембедінгів для списку хінтів
179
+
180
+ Args:
181
+ hint_list: список хінтів
182
+ model_name: назва моделі для ембедінгів
183
+
184
+ Returns:
185
+ np.ndarray: матриця ембедінгів
186
+ """
187
  emb_list = []
188
  total_hints = len(hint_list)
189
 
190
  for idx, hint in enumerate(hint_list, 1):
191
  try:
192
  print(f" Отримання embedding {idx}/{total_hints}: '{hint}'")
193
+ emb = self.get_embedding(hint, model_name=model_name)
194
  emb_list.append(emb)
195
  except Exception as e:
196
  print(f" Помилка при отриманні embedding для '{hint}': {str(e)}")
 
201
 
202
  return np.array(emb_list, dtype=np.float32)
203
 
204
+ def initialize_signatures(self,
205
+ model_name: str = None,
206
+ signatures_file: str = "signatures.npz",
207
+ force_rebuild: bool = False) -> str:
208
  """
209
  Ініціалізує signatures: завантажує існуючі або створює нові
210
 
 
212
  model_name: назва моделі для ембедінгів
213
  signatures_file: шлях до файлу для збереження (None - не зберігати)
214
  force_rebuild: примусово перебудувати signatures
215
+
216
+ Returns:
217
+ str: повідомлення про результат
218
  """
219
  if not self.classes_json:
220
  return "Помилка: Не знайдено жодного класу в classes.json"
221
 
222
  print(f"Знайдено {len(self.classes_json)} класів")
223
 
224
+ # Завантажуємо існуючі signatures
225
  if not force_rebuild and signatures_file and os.path.exists(signatures_file):
226
  try:
227
  loaded_signatures = self.load_signatures(signatures_file)
 
254
  if not self.class_signatures:
255
  return "Помилка: Не вдалося створити жодного signature"
256
 
257
+ # Зберігаємо signatures
258
  if signatures_file:
259
  try:
260
  self.save_signatures(signatures_file)
 
266
  except Exception as e:
267
  return f"Помилка при створенні signatures: {str(e)}"
268
 
269
+ def load_data(self, csv_path: str = "messages.csv", emb_path: str = "embeddings.npy") -> str:
270
+ """
271
+ Завантаження даних з CSV та NPY файлів
272
+
273
+ Args:
274
+ csv_path: шлях до CSV файлу
275
+ emb_path: шлях до NPY файлу з ембедінгами
276
+
277
+ Returns:
278
+ str: повідомлення про результат
279
+ """
280
  self.df = pd.read_csv(csv_path)
281
  emb_local = np.load(emb_path)
282
  assert len(self.df) == len(emb_local), "CSV і embeddings різної довжини!"
 
290
  return f"Завантажено {len(self.df)} рядків"
291
 
292
  def predict_classes(self, text_embedding: np.ndarray, threshold: float = 0.0) -> Dict[str, float]:
293
+ """
294
+ Передбачення класів для одного тексту
295
+
296
+ Args:
297
+ text_embedding: ембедінг тексту
298
+ threshold: поріг впевненості
299
+
300
+ Returns:
301
+ Dict[str, float]: словник класів та їх scores
302
+ """
303
  results = {}
304
  for cls, sign in self.class_signatures.items():
305
  score = float(np.dot(text_embedding, sign))
 
309
  return dict(sorted(results.items(), key=lambda x: x[1], reverse=True))
310
 
311
  def process_single_text(self, text: str, threshold: float = 0.3) -> dict:
312
+ """
313
+ Обробка одного тексту
314
+
315
+ Args:
316
+ text: текст для класифікації
317
+ threshold: поріг впевненості
318
+
319
+ Returns:
320
+ dict: результати класифікації
321
+ """
322
  if self.class_signatures is None:
323
  return {"error": "Спочатку збудуйте signatures!"}
324
 
325
+ emb = self.get_embedding(text)
326
 
327
  if self.embeddings_mean is not None and self.embeddings_std is not None:
328
  emb = (emb - self.embeddings_mean) / self.embeddings_std
 
341
  "result": "\n".join(formatted_results)
342
  }
343
 
344
+ def classify_rows(self, filter_substring: str = "", threshold: float = 0.3) -> pd.DataFrame:
345
+ """
346
+ Класифікація всіх або відфільтрованих рядків
347
+
348
+ Args:
349
+ filter_substring: підрядок для фільтрації
350
+ threshold: поріг впевненості
351
+
352
+ Returns:
353
+ pd.DataFrame: результати класифікації
354
+ """
355
  if self.class_signatures is None:
356
+ raise ValueError("Спочатку збудуйте signatures!")
357
 
358
  if self.df is None or self.embeddings is None:
359
+ raise ValueError("Дані не завантажені! Спочатку викличте load_data.")
360
 
361
  if filter_substring:
362
  filtered_idx = self.df[self.df["Message"].str.contains(filter_substring,
363
+ case=False,
364
  na=False)].index
365
  else:
366
  filtered_idx = self.df.index
 
385
  return result_df.reset_index(drop=True)
386
 
387
  def save_results(self, output_path: str = "messages_with_labels.csv") -> str:
388
+ """
389
+ Зберігання результатів класифікації
390
+
391
+ Args:
392
+ output_path: шлях для збереження результатів
393
+
394
+ Returns:
395
+ str: повідомлення про результат
396
+ """
397
  if self.df is None:
398
  return "Дані відсутні!"
399
 
400
  self.df.to_csv(output_path, index=False)
401
+ return f"Дані збережено у файл {output_path}"
402
+
403
+ def save_model_info(self, path: str = "model_info.json") -> None:
404
+ """
405
+ Зберігання інформації про поточний стан моделі
406
+
407
+ Args:
408
+ path: шлях для збереження
409
+ """
410
+ info = {
411
+ "using_local": self.using_local,
412
+ "classes_count": len(self.classes_json),
413
+ "signatures_count": len(self.class_signatures) if self.class_signatures else 0,
414
+ "cache_stats": self.get_cache_stats(),
415
+ }
416
+
417
+ if self.using_local:
418
+ info["local_model"] = self.local_embedder.get_model_info()
419
+
420
+ with open(path, 'w', encoding='utf-8') as f:
421
+ json.dump(info, f, indent=2)
422
+
423
+ @staticmethod
424
+ def load_model_info(path: str) -> dict:
425
+ """
426
+ Завантаження інформації про модель
427
+
428
+ Args:
429
+ path: шлях до файлу з інформацією
430
+
431
+ Returns:
432
+ dict: інформація про модель
433
+ """
434
+ with open(path, 'r', encoding='utf-8') as f:
435
+ return json.load(f)