Spaces:
Runtime error
Runtime error
import gradio as gr | |
from sdc_classifier import SDCClassifier | |
from dotenv import load_dotenv | |
import json | |
# Load environment variables | |
load_dotenv() | |
def main(): | |
# Ініціалізуємо класифікатор | |
classifier = SDCClassifier() | |
# Спроба завантажити початкові класи та signatures | |
DEFAULT_CLASSES_FILE = "classes_short.json" | |
DEFAULT_SIGNATURES_FILE = "signatures.npz" | |
print("Завантаження початкових класів...") | |
classifier.load_initial_state(DEFAULT_CLASSES_FILE, DEFAULT_SIGNATURES_FILE) | |
with gr.Blocks() as demo: | |
gr.Markdown("# SDC Classifier з Gradio") | |
with gr.Tabs(): | |
# Вкладка 1: Single Text Testing | |
with gr.TabItem("Тестування одного тексту"): | |
with gr.Row(): | |
with gr.Column(): | |
text_input = gr.Textbox( | |
label="Введіть текст для аналізу", | |
lines=5, | |
placeholder="Введіть текст..." | |
) | |
threshold_slider = gr.Slider( | |
minimum=0.0, | |
maximum=1.0, | |
value=0.3, | |
step=0.05, | |
label="Поріг впевненості" | |
) | |
single_process_btn = gr.Button("Проаналізувати") | |
with gr.Column(): | |
result_text = gr.JSON(label="Результати аналізу") | |
# Налаштування моделі | |
with gr.Accordion("Налаштування моделі", open=False): | |
with gr.Row(): | |
model_choice = gr.Dropdown( | |
choices=["text-embedding-3-large","text-embedding-3-small"], | |
value="text-embedding-3-large", | |
label="OpenAI model" | |
) | |
json_file = gr.File( | |
label="Завантажити новий JSON з класами", | |
file_types=[".json"] | |
) | |
force_rebuild = gr.Checkbox( | |
label="Примусово перебудувати signatures", | |
value=False | |
) | |
with gr.Row(): | |
build_btn = gr.Button("Оновити signatures") | |
build_out = gr.Label(label="Статус signatures") | |
cache_stats = gr.JSON(label="Статистика кешу", value={}) | |
# Вкладка 2: Batch Processing | |
with gr.TabItem("Пакетна обробка"): | |
gr.Markdown("## 1) Завантаження даних") | |
with gr.Row(): | |
csv_input = gr.Textbox( | |
value="messages.csv", | |
label="CSV-файл" | |
) | |
emb_input = gr.Textbox( | |
value="embeddings.npy", | |
label="Numpy Embeddings" | |
) | |
load_btn = gr.Button("Завантажити дані") | |
load_output = gr.Label(label="Результат завантаження") | |
gr.Markdown("## 2) Класифікація") | |
with gr.Row(): | |
filter_in = gr.Textbox(label="Фільтр (опціонально)") | |
batch_threshold = gr.Slider( | |
minimum=0.0, | |
maximum=1.0, | |
value=0.3, | |
step=0.05, | |
label="Поріг впевненості" | |
) | |
classify_btn = gr.Button("Класифікувати") | |
classify_out = gr.Dataframe(label="Результат (Message / Target / Scores)") | |
gr.Markdown("## 3) Зберегти результати") | |
save_btn = gr.Button("Зберегти розмічені дані") | |
save_out = gr.Label() | |
gr.Markdown(""" | |
### Інструкція: | |
1. У вкладці "Налаштування моделі" можна: | |
- Завантажити новий JSON файл з класами | |
- Вибрати модель для embeddings | |
- Примусово перебудувати signatures | |
2. Після зміни класів натисніть "Оновити signatures" | |
3. Використовуйте повзунок "Поріг впевненості" для фільтрації результатів | |
4. На вкладці "Пакетна обробка" можна аналізувати багато повідомлень | |
5. Результати можна зберегти в CSV файл | |
""") | |
# Підключення обробників подій | |
def update_with_file(file, model_name, force): | |
if file is None: | |
# Відновлюємо базовий стан якщо файл видалено | |
classifier.restore_base_state() | |
return ("Відновлено базовий набір класів", classifier.get_cache_stats()) | |
try: | |
# Для роботи з gradio File компонентом | |
if hasattr(file, 'name'): # Якщо це файловий об'єкт | |
with open(file.name, 'r', encoding='utf-8') as f: | |
new_classes = json.load(f) | |
else: # Якщо це строка | |
new_classes = json.loads(file) | |
if not isinstance(new_classes, dict): | |
return ("Помилка: JSON повинен містити словник класів", classifier.get_cache_stats()) | |
# Завантажуємо нові класи без перезапису файлу за замовчуванням | |
classifier.load_classes(new_classes) | |
# Створюємо тимчасові signatures | |
result = classifier.initialize_signatures( | |
model_name=model_name, | |
signatures_file=None, # Не зберігаємо у файл | |
force_rebuild=True # Завжди перебудовуємо для нових класів | |
) | |
return (f"Тимчасові класи завантажено. {result}", classifier.get_cache_stats()) | |
except json.JSONDecodeError: | |
return ("Помилка: Неправильний формат JSON файлу", classifier.get_cache_stats()) | |
except Exception as e: | |
return (f"Помилка при оновленні: {str(e)}", classifier.get_cache_stats()) | |
single_process_btn.click( | |
fn=lambda text, threshold: classifier.process_single_text(text, threshold), | |
inputs=[text_input, threshold_slider], | |
outputs=result_text | |
) | |
build_btn.click( | |
fn=update_with_file, | |
inputs=[json_file, model_choice, force_rebuild], | |
outputs=[build_out, cache_stats] | |
) | |
load_btn.click( | |
fn=lambda csv, emb: classifier.load_data(csv, emb), | |
inputs=[csv_input, emb_input], | |
outputs=load_output | |
) | |
classify_btn.click( | |
fn=lambda filter_str, threshold: classifier.classify_rows(filter_str, threshold), | |
inputs=[filter_in, batch_threshold], | |
outputs=classify_out | |
) | |
save_btn.click( | |
fn=lambda: classifier.save_results("messages_with_labels.csv"), | |
inputs=[], | |
outputs=save_out | |
) | |
demo.launch(server_name="0.0.0.0", server_port=7860, share=True) | |
if __name__ == "__main__": | |
main() |