DinoFrog commited on
Commit
a56dc0f
·
verified ·
1 Parent(s): 813fe83

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -10
app.py CHANGED
@@ -7,6 +7,7 @@ import os
7
  import asyncio
8
  import nltk
9
  from nltk.tokenize import sent_tokenize
 
10
 
11
  # Téléchargement de punkt_tab avec gestion d'erreur
12
  try:
@@ -20,15 +21,23 @@ if not HF_TOKEN:
20
  st.error("Erreur : Clé API Hugging Face (HF_TOKEN) manquante. Veuillez configurer HF_TOKEN dans les variables d'environnement.")
21
  st.stop()
22
 
23
- # Initialisation des modèles dans st.session_state pour éviter le rechargement
24
- if 'classifier' not in st.session_state:
25
- st.session_state.classifier = pipeline("sentiment-analysis", model="mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis", device="cpu")
 
26
 
27
- if 'translator_to_en' not in st.session_state:
28
- st.session_state.translator_to_en = pipeline("translation", model="Helsinki-NLP/opus-mt-mul-en", device="cpu")
 
29
 
30
- if 'translator_to_fr' not in st.session_state:
31
- st.session_state.translator_to_fr = pipeline("translation", model="Helsinki-NLP/opus-mt-en-fr", device="cpu")
 
 
 
 
 
 
32
 
33
  # Fonction pour appeler l'API Zephyr avec des paramètres ajustés
34
  async def call_zephyr_api(prompt, mode, hf_token=HF_TOKEN):
@@ -55,7 +64,7 @@ def safe_translate_to_fr(text, max_length=512):
55
  sentences = sent_tokenize(text)
56
  translated_sentences = []
57
  for sentence in sentences:
58
- translated = st.session_state.translator_to_fr(sentence, max_length=max_length)[0]['translation_text']
59
  translated_sentences.append(translated)
60
  return " ".join(translated_sentences)
61
  except Exception as e:
@@ -110,13 +119,13 @@ async def full_analysis(text, mode, detail_mode, history):
110
  progress_bar.progress(25)
111
 
112
  if lang != "en":
113
- text_en = st.session_state.translator_to_en(text, max_length=512)[0]['translation_text']
114
  else:
115
  text_en = text
116
 
117
  # Étape 2 : Analyse du sentiment
118
  status_text.write("Analyse en cours... (Étape 2 : Analyse du sentiment)")
119
- result = st.session_state.classifier(text_en) # Utilisation du modèle depuis st.session_state
120
  result = result[0]
121
  sentiment_output = f"Sentiment prédictif : {result['label']} (Score: {result['score']:.2f})"
122
  sentiment_gauge = create_sentiment_gauge(result['label'], result['score'])
 
7
  import asyncio
8
  import nltk
9
  from nltk.tokenize import sent_tokenize
10
+ import torch
11
 
12
  # Téléchargement de punkt_tab avec gestion d'erreur
13
  try:
 
21
  st.error("Erreur : Clé API Hugging Face (HF_TOKEN) manquante. Veuillez configurer HF_TOKEN dans les variables d'environnement.")
22
  st.stop()
23
 
24
+ # Fonctions pour charger les modèles avec st.cache_resource
25
+ @st.cache_resource
26
+ def load_classifier():
27
+ return pipeline("sentiment-analysis", model="mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis", device="cpu", map_location="cpu")
28
 
29
+ @st.cache_resource
30
+ def load_translator_to_en():
31
+ return pipeline("translation", model="Helsinki-NLP/opus-mt-mul-en", device="cpu", map_location="cpu")
32
 
33
+ @st.cache_resource
34
+ def load_translator_to_fr():
35
+ return pipeline("translation", model="Helsinki-NLP/opus-mt-en-fr", device="cpu", map_location="cpu")
36
+
37
+ # Charger les modèles une seule fois
38
+ classifier = load_classifier()
39
+ translator_to_en = load_translator_to_en()
40
+ translator_to_fr = load_translator_to_fr()
41
 
42
  # Fonction pour appeler l'API Zephyr avec des paramètres ajustés
43
  async def call_zephyr_api(prompt, mode, hf_token=HF_TOKEN):
 
64
  sentences = sent_tokenize(text)
65
  translated_sentences = []
66
  for sentence in sentences:
67
+ translated = translator_to_fr(sentence, max_length=max_length)[0]['translation_text']
68
  translated_sentences.append(translated)
69
  return " ".join(translated_sentences)
70
  except Exception as e:
 
119
  progress_bar.progress(25)
120
 
121
  if lang != "en":
122
+ text_en = translator_to_en(text, max_length=512)[0]['translation_text']
123
  else:
124
  text_en = text
125
 
126
  # Étape 2 : Analyse du sentiment
127
  status_text.write("Analyse en cours... (Étape 2 : Analyse du sentiment)")
128
+ result = classifier(text_en)
129
  result = result[0]
130
  sentiment_output = f"Sentiment prédictif : {result['label']} (Score: {result['score']:.2f})"
131
  sentiment_gauge = create_sentiment_gauge(result['label'], result['score'])