Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ import os
|
|
7 |
import asyncio
|
8 |
import nltk
|
9 |
from nltk.tokenize import sent_tokenize
|
|
|
10 |
|
11 |
# Téléchargement de punkt_tab avec gestion d'erreur
|
12 |
try:
|
@@ -20,15 +21,23 @@ if not HF_TOKEN:
|
|
20 |
st.error("Erreur : Clé API Hugging Face (HF_TOKEN) manquante. Veuillez configurer HF_TOKEN dans les variables d'environnement.")
|
21 |
st.stop()
|
22 |
|
23 |
-
#
|
24 |
-
|
25 |
-
|
|
|
26 |
|
27 |
-
|
28 |
-
|
|
|
29 |
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
# Fonction pour appeler l'API Zephyr avec des paramètres ajustés
|
34 |
async def call_zephyr_api(prompt, mode, hf_token=HF_TOKEN):
|
@@ -55,7 +64,7 @@ def safe_translate_to_fr(text, max_length=512):
|
|
55 |
sentences = sent_tokenize(text)
|
56 |
translated_sentences = []
|
57 |
for sentence in sentences:
|
58 |
-
translated =
|
59 |
translated_sentences.append(translated)
|
60 |
return " ".join(translated_sentences)
|
61 |
except Exception as e:
|
@@ -110,13 +119,13 @@ async def full_analysis(text, mode, detail_mode, history):
|
|
110 |
progress_bar.progress(25)
|
111 |
|
112 |
if lang != "en":
|
113 |
-
text_en =
|
114 |
else:
|
115 |
text_en = text
|
116 |
|
117 |
# Étape 2 : Analyse du sentiment
|
118 |
status_text.write("Analyse en cours... (Étape 2 : Analyse du sentiment)")
|
119 |
-
result =
|
120 |
result = result[0]
|
121 |
sentiment_output = f"Sentiment prédictif : {result['label']} (Score: {result['score']:.2f})"
|
122 |
sentiment_gauge = create_sentiment_gauge(result['label'], result['score'])
|
|
|
7 |
import asyncio
|
8 |
import nltk
|
9 |
from nltk.tokenize import sent_tokenize
|
10 |
+
import torch
|
11 |
|
12 |
# Téléchargement de punkt_tab avec gestion d'erreur
|
13 |
try:
|
|
|
21 |
st.error("Erreur : Clé API Hugging Face (HF_TOKEN) manquante. Veuillez configurer HF_TOKEN dans les variables d'environnement.")
|
22 |
st.stop()
|
23 |
|
24 |
+
# Fonctions pour charger les modèles avec st.cache_resource
|
25 |
+
@st.cache_resource
|
26 |
+
def load_classifier():
|
27 |
+
return pipeline("sentiment-analysis", model="mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis", device="cpu", map_location="cpu")
|
28 |
|
29 |
+
@st.cache_resource
|
30 |
+
def load_translator_to_en():
|
31 |
+
return pipeline("translation", model="Helsinki-NLP/opus-mt-mul-en", device="cpu", map_location="cpu")
|
32 |
|
33 |
+
@st.cache_resource
|
34 |
+
def load_translator_to_fr():
|
35 |
+
return pipeline("translation", model="Helsinki-NLP/opus-mt-en-fr", device="cpu", map_location="cpu")
|
36 |
+
|
37 |
+
# Charger les modèles une seule fois
|
38 |
+
classifier = load_classifier()
|
39 |
+
translator_to_en = load_translator_to_en()
|
40 |
+
translator_to_fr = load_translator_to_fr()
|
41 |
|
42 |
# Fonction pour appeler l'API Zephyr avec des paramètres ajustés
|
43 |
async def call_zephyr_api(prompt, mode, hf_token=HF_TOKEN):
|
|
|
64 |
sentences = sent_tokenize(text)
|
65 |
translated_sentences = []
|
66 |
for sentence in sentences:
|
67 |
+
translated = translator_to_fr(sentence, max_length=max_length)[0]['translation_text']
|
68 |
translated_sentences.append(translated)
|
69 |
return " ".join(translated_sentences)
|
70 |
except Exception as e:
|
|
|
119 |
progress_bar.progress(25)
|
120 |
|
121 |
if lang != "en":
|
122 |
+
text_en = translator_to_en(text, max_length=512)[0]['translation_text']
|
123 |
else:
|
124 |
text_en = text
|
125 |
|
126 |
# Étape 2 : Analyse du sentiment
|
127 |
status_text.write("Analyse en cours... (Étape 2 : Analyse du sentiment)")
|
128 |
+
result = classifier(text_en)
|
129 |
result = result[0]
|
130 |
sentiment_output = f"Sentiment prédictif : {result['label']} (Score: {result['score']:.2f})"
|
131 |
sentiment_gauge = create_sentiment_gauge(result['label'], result['score'])
|