Spaces:
Running
Running
Upload 2 files
Browse filesDocs donde se corren los test de GPT2 y Gemma3
- gemma3test.py +272 -0
- gpt2test.py +332 -0
gemma3test.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import pandas as pd
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
from supabase import create_client, Client
|
7 |
+
from transformers import pipeline
|
8 |
+
from sentence_transformers import SentenceTransformer
|
9 |
+
import plotly.graph_objects as go
|
10 |
+
import numpy as np
|
11 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
12 |
+
import re
|
13 |
+
|
14 |
+
|
15 |
+
# ---------------------------------------------------------------------------------
|
16 |
+
# Funciones auxiliares
|
17 |
+
# ---------------------------------------------------------------------------------
|
18 |
+
|
19 |
+
def extract_country_and_dates(prompt, countries):
|
20 |
+
country = None
|
21 |
+
start_date = None
|
22 |
+
end_date = None
|
23 |
+
|
24 |
+
# Buscar el país (insensible a mayúsculas y minúsculas)
|
25 |
+
for c in countries:
|
26 |
+
if re.search(r'\b' + re.escape(c) + r'\b', prompt, re.IGNORECASE):
|
27 |
+
country = c
|
28 |
+
break
|
29 |
+
|
30 |
+
# Buscar rangos de años con diferentes separadores (-, to, until, from ... to, between ... and)
|
31 |
+
date_ranges = re.findall(r'(\d{4})\s*(?:-|to|until|from.*?to|between.*?and)\s*(\d{4})', prompt, re.IGNORECASE)
|
32 |
+
if date_ranges:
|
33 |
+
start_date = date_ranges[0][0]
|
34 |
+
end_date = date_ranges[0][1]
|
35 |
+
else:
|
36 |
+
# Buscar un solo año
|
37 |
+
single_years = re.findall(r'\b(\d{4})\b', prompt)
|
38 |
+
if single_years:
|
39 |
+
start_date = single_years[0]
|
40 |
+
end_date = single_years[0]
|
41 |
+
|
42 |
+
return country, start_date, end_date
|
43 |
+
|
44 |
+
|
45 |
+
def generate_plotly_graph(df, user_query, country=None, start_date=None, end_date=None):
|
46 |
+
relevant_data = df.copy()
|
47 |
+
|
48 |
+
if 'geo' in relevant_data.columns and country:
|
49 |
+
relevant_data = relevant_data[relevant_data['geo'].str.lower() == country.lower()]
|
50 |
+
|
51 |
+
if 'year' in relevant_data.columns:
|
52 |
+
relevant_data['year'] = pd.to_numeric(relevant_data['year'], errors='coerce').dropna().astype(int)
|
53 |
+
if start_date and end_date:
|
54 |
+
relevant_data = relevant_data[
|
55 |
+
(relevant_data['year'] >= int(start_date)) & (relevant_data['year'] <= int(end_date))
|
56 |
+
]
|
57 |
+
elif start_date:
|
58 |
+
relevant_data = relevant_data[relevant_data['year'] >= int(start_date)]
|
59 |
+
elif end_date:
|
60 |
+
relevant_data = relevant_data[relevant_data['year'] <= int(end_date)]
|
61 |
+
|
62 |
+
numeric_cols = relevant_data.select_dtypes(include=['number']).columns.tolist()
|
63 |
+
if 'year' in relevant_data.columns and numeric_cols:
|
64 |
+
fig = go.Figure()
|
65 |
+
for col in numeric_cols:
|
66 |
+
if col != 'year':
|
67 |
+
fig.add_trace(go.Scatter(x=relevant_data['year'], y=relevant_data[col], mode='lines+markers', name=col))
|
68 |
+
|
69 |
+
title = f"Data for {country if country else 'All Regions'}"
|
70 |
+
if start_date and end_date:
|
71 |
+
title += f" ({start_date}-{end_date})"
|
72 |
+
elif start_date:
|
73 |
+
title += f" (from {start_date})"
|
74 |
+
elif end_date:
|
75 |
+
title += f" (up to {end_date})"
|
76 |
+
|
77 |
+
# Añadir título y etiquetas de los ejes
|
78 |
+
fig.update_layout(
|
79 |
+
title=title,
|
80 |
+
xaxis_title="Year",
|
81 |
+
yaxis_title="Value" # Necesitaremos inferir o tener nombres de columnas más descriptivos
|
82 |
+
)
|
83 |
+
return fig
|
84 |
+
else:
|
85 |
+
return None
|
86 |
+
|
87 |
+
|
88 |
+
# ---------------------------------------------------------------------------------
|
89 |
+
# Configuración de conexión a Supabase
|
90 |
+
# ---------------------------------------------------------------------------------
|
91 |
+
load_dotenv()
|
92 |
+
SUPABASE_URL = os.getenv("SUPABASE_URL")
|
93 |
+
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
|
94 |
+
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
|
95 |
+
|
96 |
+
|
97 |
+
# Función para cargar datos de una tabla de Supabase
|
98 |
+
def load_data(table):
|
99 |
+
try:
|
100 |
+
if supabase:
|
101 |
+
response = supabase.from_(table).select("*").execute()
|
102 |
+
if hasattr(response, 'data'):
|
103 |
+
return pd.DataFrame(response.data)
|
104 |
+
elif hasattr(response, '_error'):
|
105 |
+
st.error(f"Error fetching data: {response._error}")
|
106 |
+
return pd.DataFrame()
|
107 |
+
else:
|
108 |
+
st.info("Response object does not have 'data' or known error attributes. Check the logs.")
|
109 |
+
return pd.DataFrame()
|
110 |
+
else:
|
111 |
+
st.error("Supabase client not initialized. Check environment variables.")
|
112 |
+
return pd.DataFrame()
|
113 |
+
except Exception as e:
|
114 |
+
st.error(f"An error occurred during data loading: {e}")
|
115 |
+
return pd.DataFrame()
|
116 |
+
|
117 |
+
|
118 |
+
# ---------------------------------------------------------------------------------
|
119 |
+
# Cargar datos iniciales
|
120 |
+
# ---------------------------------------------------------------------------------
|
121 |
+
labor_data = load_data("labor")
|
122 |
+
fertility_data = load_data("fertility")
|
123 |
+
|
124 |
+
# ---------------------------------------------------------------------------------
|
125 |
+
# Inicialización de modelos para RAG
|
126 |
+
# ---------------------------------------------------------------------------------
|
127 |
+
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
|
128 |
+
llm_pipeline = pipeline("text-generation", model="google/gemma-3-1b-it", token=os.getenv("HF_TOKEN"))
|
129 |
+
|
130 |
+
# ---------------------------------------------------------------------------------
|
131 |
+
# Generación de Embeddings y Metadatos (en memoria)
|
132 |
+
# ---------------------------------------------------------------------------------
|
133 |
+
embeddings_list = []
|
134 |
+
contents_list = []
|
135 |
+
metadatas_list = []
|
136 |
+
ids_list = []
|
137 |
+
|
138 |
+
for index, row in labor_data.iterrows():
|
139 |
+
doc = f"Country: {row['geo']}, Year: {row['year']}, Employment Rate: {row['labour_force'] if 'labour_force' in row else 'N/A'}"
|
140 |
+
embeddings_list.append(embedding_model.encode(doc))
|
141 |
+
contents_list.append(doc)
|
142 |
+
metadatas_list.append({'country': row['geo'], 'year': str(row['year']), 'source': 'labor'})
|
143 |
+
ids_list.append(f"labor_{index}")
|
144 |
+
|
145 |
+
for index, row in fertility_data.iterrows():
|
146 |
+
doc = f"Country: {row['geo']}, Year: {row['year']}, Fertility Rate: {row['fertility_rate'] if 'fertility_rate' in row else 'N/A'}"
|
147 |
+
embeddings_list.append(embedding_model.encode(doc))
|
148 |
+
contents_list.append(doc)
|
149 |
+
metadatas_list.append({'country': row['geo'], 'year': str(row['year']), 'source': 'fertility'})
|
150 |
+
ids_list.append(f"fertility_{index}")
|
151 |
+
|
152 |
+
embeddings_array = np.array(embeddings_list)
|
153 |
+
|
154 |
+
|
155 |
+
# ---------------------------------------------------------------------------------
|
156 |
+
# Función para recuperar documentos relevantes (en memoria)
|
157 |
+
# ---------------------------------------------------------------------------------
|
158 |
+
def retrieve_relevant_documents_in_memory(query_embedding, stored_embeddings, contents, top_k=3):
|
159 |
+
similarities = cosine_similarity([query_embedding], stored_embeddings)[0]
|
160 |
+
sorted_indices = np.argsort(similarities)[::-1]
|
161 |
+
relevant_documents = [contents[i] for i in sorted_indices[:top_k]]
|
162 |
+
return relevant_documents
|
163 |
+
|
164 |
+
|
165 |
+
# ---------------------------------------------------------------------------------
|
166 |
+
# Generación de la explicación usando RAG
|
167 |
+
# ---------------------------------------------------------------------------------
|
168 |
+
def generate_rag_explanation(user_query, stored_embeddings, contents):
|
169 |
+
query_embedding = embedding_model.encode(user_query)
|
170 |
+
relevant_docs = retrieve_relevant_documents_in_memory(query_embedding, stored_embeddings, contents)
|
171 |
+
if relevant_docs:
|
172 |
+
context = "\n".join(relevant_docs)
|
173 |
+
augmented_prompt = f"Based on the following information:\n\n{context}\n\nAnswer the question related to: {user_query}"
|
174 |
+
output = llm_pipeline(augmented_prompt, max_length=250, num_return_sequences=1)
|
175 |
+
return output[0]['generated_text']
|
176 |
+
else:
|
177 |
+
return "No relevant information found to answer your query."
|
178 |
+
|
179 |
+
|
180 |
+
# ---------------------------------------------------------------------------------
|
181 |
+
# Generar la lista de países automáticamente
|
182 |
+
# ---------------------------------------------------------------------------------
|
183 |
+
available_countries_labor = labor_data['geo'].unique().tolist() if 'geo' in labor_data.columns else []
|
184 |
+
available_countries_fertility = fertility_data['geo'].unique().tolist() if 'geo' in fertility_data.columns else []
|
185 |
+
all_countries = list(set(available_countries_labor + available_countries_fertility))
|
186 |
+
|
187 |
+
# ---------------------------------------------------------------------------------
|
188 |
+
# Configuración de la app en Streamlit
|
189 |
+
# ---------------------------------------------------------------------------------
|
190 |
+
st.set_page_config(page_title="GraphGen", page_icon="🇪🇺")
|
191 |
+
st.title("_Europe GraphGen_ :blue[Graph generator] :flag-eu:")
|
192 |
+
st.caption("Mapping Europe's data with insights")
|
193 |
+
|
194 |
+
if "messages" not in st.session_state:
|
195 |
+
st.session_state.messages = []
|
196 |
+
st.session_state.messages.append({"role": "assistant", "content": "What graphic and insights do you need?"})
|
197 |
+
|
198 |
+
for message in st.session_state.messages:
|
199 |
+
with st.chat_message(message["role"]):
|
200 |
+
st.markdown(message["content"])
|
201 |
+
|
202 |
+
prompt = st.chat_input("Type your message here...", key="chat_input_bottom")
|
203 |
+
|
204 |
+
if prompt:
|
205 |
+
st.session_state.messages.append({"role": "user", "content": prompt})
|
206 |
+
with st.chat_message("user"):
|
207 |
+
st.markdown(prompt)
|
208 |
+
|
209 |
+
with st.spinner('Generating answer...'):
|
210 |
+
|
211 |
+
try:
|
212 |
+
# Determinar el año más reciente en los datos
|
213 |
+
latest_year_labor = labor_data['year'].max() if 'year' in labor_data else datetime.now().year
|
214 |
+
latest_year_fertility = fertility_data['year'].max() if 'year' in fertility_data else datetime.now().year
|
215 |
+
latest_year = max(latest_year_labor, latest_year_fertility, datetime.now().year)
|
216 |
+
|
217 |
+
country, start_date, end_date = extract_country_and_dates(prompt, all_countries, latest_year)
|
218 |
+
graph_displayed = False
|
219 |
+
|
220 |
+
# Analizar el prompt para determinar la intención del usuario
|
221 |
+
if re.search(r'\b(labor|employment|job|workforce)\b', prompt, re.IGNORECASE):
|
222 |
+
# Generar gráfica de datos laborales
|
223 |
+
labor_fig = generate_plotly_graph(labor_data, prompt, country, start_date, end_date)
|
224 |
+
if labor_fig:
|
225 |
+
st.session_state.messages.append(
|
226 |
+
{"role": "assistant", "content": "Here is the labor data graphic:"})
|
227 |
+
with st.chat_message("assistant"):
|
228 |
+
st.plotly_chart(labor_fig)
|
229 |
+
graph_displayed = True
|
230 |
+
elif re.search(r'\b(fertility|birth|population growth)\b', prompt, re.IGNORECASE):
|
231 |
+
# Generar gráfica de datos de fertilidad
|
232 |
+
fertility_fig = generate_plotly_graph(fertility_data, prompt, country, start_date, end_date)
|
233 |
+
if fertility_fig:
|
234 |
+
st.session_state.messages.append(
|
235 |
+
{"role": "assistant", "content": "Here is the fertility data graphic:"})
|
236 |
+
with st.chat_message("assistant"):
|
237 |
+
st.plotly_chart(fertility_fig)
|
238 |
+
graph_displayed = True
|
239 |
+
else:
|
240 |
+
# Si no se identifica una intención clara, intentar mostrar la gráfica de datos laborales primero
|
241 |
+
labor_fig = generate_plotly_graph(labor_data, prompt, country, start_date, end_date)
|
242 |
+
if labor_fig:
|
243 |
+
st.session_state.messages.append(
|
244 |
+
{"role": "assistant", "content": "Here is the labor data graphic:"})
|
245 |
+
with st.chat_message("assistant"):
|
246 |
+
st.plotly_chart(labor_fig)
|
247 |
+
graph_displayed = True
|
248 |
+
elif not graph_displayed:
|
249 |
+
fertility_fig = generate_plotly_graph(fertility_data, prompt, country, start_date, end_date)
|
250 |
+
if fertility_fig:
|
251 |
+
st.session_state.messages.append(
|
252 |
+
{"role": "assistant", "content": "Here is the fertility data graphic:"})
|
253 |
+
with st.chat_message("assistant"):
|
254 |
+
st.plotly_chart(fertility_fig)
|
255 |
+
graph_displayed = True
|
256 |
+
|
257 |
+
# Generar explicación usando RAG
|
258 |
+
explanation = generate_rag_explanation(prompt, embeddings_array, contents_list)
|
259 |
+
st.session_state.messages.append({"role": "assistant", "content": f"Explanation: {explanation}"})
|
260 |
+
with st.chat_message("assistant"):
|
261 |
+
st.markdown(f"**Explanation:** {explanation}")
|
262 |
+
|
263 |
+
except Exception as e:
|
264 |
+
st.session_state.messages.append({"role": "assistant", "content": f"Error generating answer: {e}"})
|
265 |
+
with st.chat_message("assistant"):
|
266 |
+
st.error(f"Error generating answer: {e}")
|
267 |
+
|
268 |
+
if st.button("Clear chat"):
|
269 |
+
st.session_state.messages = []
|
270 |
+
st.session_state.messages.append(
|
271 |
+
{"role": "assistant", "content": "Chat has been cleared. What graphic and insights do you need now?"})
|
272 |
+
st.rerun()
|
gpt2test.py
ADDED
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import pandas as pd
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
from supabase import create_client, Client
|
7 |
+
from transformers import pipeline
|
8 |
+
import plotly.express as px
|
9 |
+
import plotly.graph_objects as go
|
10 |
+
import time
|
11 |
+
|
12 |
+
# ---------------------------------------------------------------------------------
|
13 |
+
# Supabase Setup
|
14 |
+
# ---------------------------------------------------------------------------------
|
15 |
+
load_dotenv()
|
16 |
+
SUPABASE_URL = os.getenv("SUPABASE_URL")
|
17 |
+
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
|
18 |
+
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
|
19 |
+
|
20 |
+
# ---------------------------------------------------------------------------------
|
21 |
+
# Data Loading Function
|
22 |
+
# ---------------------------------------------------------------------------------
|
23 |
+
def load_data(table):
|
24 |
+
try:
|
25 |
+
if supabase:
|
26 |
+
response = supabase.from_(table).select("*").execute()
|
27 |
+
if hasattr(response, 'data'):
|
28 |
+
return pd.DataFrame(response.data)
|
29 |
+
else:
|
30 |
+
st.error(f"Error fetching data or no data returned for table '{table}'. Check Supabase logs.")
|
31 |
+
return pd.DataFrame()
|
32 |
+
else:
|
33 |
+
st.error("Supabase client not initialized.")
|
34 |
+
return pd.DataFrame()
|
35 |
+
except Exception as e:
|
36 |
+
st.error(f"An error occurred during data loading from table '{table}': {e}")
|
37 |
+
return pd.DataFrame()
|
38 |
+
|
39 |
+
# ---------------------------------------------------------------------------------
|
40 |
+
# Helper Function Definitions
|
41 |
+
# ---------------------------------------------------------------------------------
|
42 |
+
|
43 |
+
def extract_country_from_prompt_regex(question, country_list):
|
44 |
+
"""Extracts the first matching country from the list found in the question."""
|
45 |
+
for country in country_list:
|
46 |
+
# Use word boundaries (\b) for more accurate matching
|
47 |
+
if re.search(r"\b" + re.escape(country) + r"\b", question, re.IGNORECASE):
|
48 |
+
return country
|
49 |
+
return None # Return None if no country in the list is found
|
50 |
+
|
51 |
+
def extract_years_from_prompt(question):
|
52 |
+
"""Extracts a single year or a start/end year range from a question string."""
|
53 |
+
start_year, end_year = None, None
|
54 |
+
# Pattern 1: Single year (e.g., "in 2010", "year 2010")
|
55 |
+
single_year_match = re.search(r'\b(in|year|del)\s+(\d{4})\b', question, re.IGNORECASE)
|
56 |
+
if single_year_match:
|
57 |
+
year = int(single_year_match.group(2))
|
58 |
+
return year, year # Return single year as start and end
|
59 |
+
|
60 |
+
# Pattern 2: Year range (e.g., "between 2000 and 2010", "from 2005 to 2015")
|
61 |
+
range_match = re.search(r'\b(between|from)\s+(\d{4})\s+(and|to)\s+(\d{4})\b', question, re.IGNORECASE)
|
62 |
+
if range_match:
|
63 |
+
s_year = int(range_match.group(2))
|
64 |
+
e_year = int(range_match.group(4))
|
65 |
+
return min(s_year, e_year), max(s_year, e_year) # Ensure start <= end
|
66 |
+
|
67 |
+
# Pattern 3: Simple range like "2000-2010"
|
68 |
+
simple_range_match = re.search(r'\b(\d{4})-(\d{4})\b', question)
|
69 |
+
if simple_range_match:
|
70 |
+
s_year = int(simple_range_match.group(1))
|
71 |
+
e_year = int(simple_range_match.group(2))
|
72 |
+
return min(s_year, e_year), max(s_year, e_year)
|
73 |
+
|
74 |
+
# Pattern 4: After Year (e.g., "after 2015")
|
75 |
+
after_match = re.search(r'\b(after|since)\s+(\d{4})\b', question, re.IGNORECASE)
|
76 |
+
if after_match:
|
77 |
+
start_year = int(after_match.group(2))
|
78 |
+
# end_year remains None, signifying >= start_year
|
79 |
+
|
80 |
+
# Pattern 5: Before Year (e.g., "before 2005")
|
81 |
+
before_match = re.search(r'\b(before)\s+(\d{4})\b', question, re.IGNORECASE)
|
82 |
+
if before_match:
|
83 |
+
end_year = int(before_match.group(2))
|
84 |
+
# start_year remains None, signifying <= end_year
|
85 |
+
# Special case: if 'after' wasn't also found, return (None, end_year)
|
86 |
+
if start_year is None:
|
87 |
+
return None, end_year
|
88 |
+
|
89 |
+
# Return extracted years (could be None, None; start, None; None, end; or start, end)
|
90 |
+
# If single year patterns were matched first, they returned already.
|
91 |
+
return start_year, end_year
|
92 |
+
|
93 |
+
|
94 |
+
def filter_df_by_years(df, year_col, start_year, end_year):
|
95 |
+
"""Filters a DataFrame based on a year column and a start/end year range."""
|
96 |
+
if year_col not in df.columns:
|
97 |
+
st.warning(f"Year column '{year_col}' not found.")
|
98 |
+
return df
|
99 |
+
|
100 |
+
try:
|
101 |
+
# Ensure year column is numeric, coerce errors to NaT/NaN
|
102 |
+
df[year_col] = pd.to_numeric(df[year_col], errors='coerce')
|
103 |
+
# Drop rows where conversion failed, essential for comparison
|
104 |
+
df_filtered = df.dropna(subset=[year_col]).copy()
|
105 |
+
# Convert to integer only AFTER dropping NaN, avoids potential float issues
|
106 |
+
df_filtered[year_col] = df_filtered[year_col].astype(int)
|
107 |
+
except Exception as e:
|
108 |
+
st.error(f"Could not convert year column '{year_col}' to numeric: {e}")
|
109 |
+
return df # Return original on error
|
110 |
+
|
111 |
+
original_count = len(df_filtered) # Count after potential NaNs are dropped
|
112 |
+
|
113 |
+
if start_year is None and end_year is None:
|
114 |
+
# No year filtering needed
|
115 |
+
return df_filtered
|
116 |
+
|
117 |
+
st.info(f"Filtering by years: Start={start_year}, End={end_year} on column '{year_col}'")
|
118 |
+
|
119 |
+
# Apply filters based on provided start/end years
|
120 |
+
if start_year is not None and end_year is not None:
|
121 |
+
# Specific range or single year (where start_year == end_year)
|
122 |
+
df_filtered = df_filtered[(df_filtered[year_col] >= start_year) & (df_filtered[year_col] <= end_year)]
|
123 |
+
elif start_year is not None:
|
124 |
+
# Only start year ("after X")
|
125 |
+
df_filtered = df_filtered[df_filtered[year_col] >= start_year]
|
126 |
+
elif end_year is not None:
|
127 |
+
# Only end year ("before Y")
|
128 |
+
df_filtered = df_filtered[df_filtered[year_col] <= end_year]
|
129 |
+
|
130 |
+
filtered_count = len(df_filtered)
|
131 |
+
if filtered_count == 0 and original_count > 0: # Check if filtering removed all data
|
132 |
+
st.warning(f"No data found for the specified year(s): {start_year if start_year else ''}-{end_year if end_year else ''}")
|
133 |
+
elif filtered_count < original_count:
|
134 |
+
st.write(f"Filtered data by year. Rows reduced from {original_count} to {filtered_count}.")
|
135 |
+
|
136 |
+
return df_filtered
|
137 |
+
|
138 |
+
|
139 |
+
# ---------------------------------------------------------------------------------
|
140 |
+
# Load Model
|
141 |
+
# ---------------------------------------------------------------------------------
|
142 |
+
@st.cache_resource
|
143 |
+
def load_gpt2():
|
144 |
+
try:
|
145 |
+
generator = pipeline('text-generation', model='openai-community/gpt2')
|
146 |
+
return generator
|
147 |
+
except Exception as e:
|
148 |
+
st.error(f"Failed to load GPT-2 model: {e}")
|
149 |
+
return None
|
150 |
+
|
151 |
+
generator = load_gpt2()
|
152 |
+
|
153 |
+
# ---------------------------------------------------------------------------------
|
154 |
+
# Load Initial Data
|
155 |
+
# ---------------------------------------------------------------------------------
|
156 |
+
if 'data_labor' not in st.session_state:
|
157 |
+
st.session_state['data_labor'] = load_data("labor") # Or your default table
|
158 |
+
|
159 |
+
# ---------------------------------------------------------------------------------
|
160 |
+
# Streamlit App UI Starts Here
|
161 |
+
# ---------------------------------------------------------------------------------
|
162 |
+
st.title("Análisis de Datos con GPT-2 y Visualización Automática")
|
163 |
+
|
164 |
+
# Get the dataframe from session state
|
165 |
+
df = st.session_state.get('data_labor')
|
166 |
+
|
167 |
+
# --- Check if DataFrame is loaded ---
|
168 |
+
if df is None or df.empty:
|
169 |
+
st.error("Failed to load data or data is empty. Please check Supabase connection and table 'labor'.")
|
170 |
+
# Optionally add a button to retry loading
|
171 |
+
if st.button("Retry Loading Data"):
|
172 |
+
st.session_state['data_labor'] = load_data("labor")
|
173 |
+
st.rerun() # Rerun the script after attempting reload
|
174 |
+
else:
|
175 |
+
# --- Section for the user question ---
|
176 |
+
st.subheader("Pregúntame algo sobre los datos de 'labor'")
|
177 |
+
question = st.text_input("Ejemplo: 'Cuál fue la fuerza laboral (labor force) en Germany entre 2010 y 2015?'")
|
178 |
+
|
179 |
+
if question:
|
180 |
+
# --- Main processing logic ---
|
181 |
+
st.write("--- Análisis de la pregunta ---") # Debug separator
|
182 |
+
|
183 |
+
# Filter by Country
|
184 |
+
unique_countries = df['geo'].unique().tolist() if 'geo' in df.columns else []
|
185 |
+
extracted_country = extract_country_from_prompt_regex(question, unique_countries)
|
186 |
+
|
187 |
+
filtered_df = df.copy()
|
188 |
+
if extracted_country:
|
189 |
+
if 'geo' in filtered_df.columns:
|
190 |
+
filtered_df = filtered_df[filtered_df['geo'] == extracted_country]
|
191 |
+
st.success(f"Filtrando datos para el país: {extracted_country}")
|
192 |
+
else:
|
193 |
+
st.warning("Columna 'geo' no encontrada para filtrar por país.")
|
194 |
+
else:
|
195 |
+
st.info("No se especificó un país o no se encontró. Mostrando datos para todos los países disponibles.")
|
196 |
+
|
197 |
+
# Identify Columns
|
198 |
+
numerical_cols = [col for col in filtered_df.columns if pd.api.types.is_numeric_dtype(filtered_df[col])]
|
199 |
+
year_col_names = ['year', 'time', 'period', 'año']
|
200 |
+
year_cols = [col for col in filtered_df.columns if col.lower() in year_col_names and col in numerical_cols]
|
201 |
+
categorical_cols = [col for col in filtered_df.columns if pd.api.types.is_object_dtype(filtered_df[col]) and col != 'geo']
|
202 |
+
|
203 |
+
# Extract Years and Filter DataFrame
|
204 |
+
start_year, end_year = extract_years_from_prompt(question)
|
205 |
+
year_col_to_use = None
|
206 |
+
if year_cols:
|
207 |
+
year_col_to_use = year_cols[0]
|
208 |
+
filtered_df = filter_df_by_years(filtered_df, year_col_to_use, start_year, end_year)
|
209 |
+
else:
|
210 |
+
st.warning("No se pudo identificar una columna de año numérica para filtrar.")
|
211 |
+
|
212 |
+
|
213 |
+
# --- GPT-2 Description Generation ---
|
214 |
+
if generator: # Check if model loaded successfully
|
215 |
+
st.subheader("Descripción Automática (GPT-2)")
|
216 |
+
# Create a concise context
|
217 |
+
context_description = "The dataset contains labor data"
|
218 |
+
context_info = f"Data for {extracted_country or 'all countries'}"
|
219 |
+
|
220 |
+
if extracted_country:
|
221 |
+
# If a specific country is filtered, mention it clearly
|
222 |
+
context_description += f" specifically for {extracted_country}"
|
223 |
+
else:
|
224 |
+
# Otherwise, mention the broader scope if known (e.g., Europe)
|
225 |
+
# If you load data for multiple countries by default, state that
|
226 |
+
context_description += " covering multiple countries" # Adjust if needed
|
227 |
+
|
228 |
+
if year_col_to_use and (start_year is not None or end_year is not None):
|
229 |
+
context_info += f" between years {start_year if start_year else 'start'} and {end_year if end_year else 'end'}"
|
230 |
+
context_info += f". Columns include: {', '.join(filtered_df.columns.tolist())}."
|
231 |
+
|
232 |
+
prompt = f"{context_info}\n\nQuestion: {question}\nAnswer based ONLY on the provided context:"
|
233 |
+
try:
|
234 |
+
st.info("Generando descripción...") # Let user know it's working
|
235 |
+
description = generator(prompt, max_new_tokens=200, num_return_sequences=1)[0]['generated_text']
|
236 |
+
# Clean up the output to show only the answer part
|
237 |
+
answer_part = description.split(prompt)[-1] # Split by the prompt itself
|
238 |
+
st.success("Descripción generada:")
|
239 |
+
st.write(answer_part.strip())
|
240 |
+
except Exception as e:
|
241 |
+
st.error(f"Error generando descripción con GPT-2: {e}")
|
242 |
+
else:
|
243 |
+
st.warning("El modelo GPT-2 no está cargado. No se puede generar descripción.")
|
244 |
+
|
245 |
+
|
246 |
+
# --- Visualization Section ---
|
247 |
+
st.subheader("Visualización Automática")
|
248 |
+
|
249 |
+
if filtered_df.empty:
|
250 |
+
st.warning("No hay datos para mostrar después de aplicar los filtros.")
|
251 |
+
|
252 |
+
# --- Logic for LINE PLOT ---
|
253 |
+
elif year_col_to_use and numerical_cols:
|
254 |
+
start_time_graph = time.time()
|
255 |
+
potential_y_cols = [col for col in numerical_cols if col != year_col_to_use]
|
256 |
+
y_col = None
|
257 |
+
if not potential_y_cols:
|
258 |
+
st.warning(f"No se encontraron columnas numéricas de datos (aparte de '{year_col_to_use}') para graficar contra el año.")
|
259 |
+
else:
|
260 |
+
labor_keywords = ['labor', 'labour', 'workforce', 'employment', 'lfpr', 'fuerza'] # Added 'fuerza'
|
261 |
+
found_labor_col = False
|
262 |
+
for col in potential_y_cols:
|
263 |
+
if any(keyword in col.lower() for keyword in labor_keywords):
|
264 |
+
y_col = col
|
265 |
+
st.info(f"Se encontró columna relevante: '{y_col}'. Usándola para el eje Y.")
|
266 |
+
found_labor_col = True
|
267 |
+
break
|
268 |
+
if not found_labor_col:
|
269 |
+
y_col = potential_y_cols[0]
|
270 |
+
st.info(f"No se encontró columna específica. Usando la primera columna numérica disponible ('{y_col}') para el eje Y.")
|
271 |
+
|
272 |
+
if y_col:
|
273 |
+
x_col = year_col_to_use
|
274 |
+
fig = go.Figure()
|
275 |
+
title = f"{y_col} vs {x_col}"
|
276 |
+
if extracted_country:
|
277 |
+
title += f" en {extracted_country}"
|
278 |
+
if start_year is not None or end_year is not None:
|
279 |
+
year_range_str = ""
|
280 |
+
if start_year is not None:
|
281 |
+
year_range_str += str(start_year)
|
282 |
+
if end_year is not None:
|
283 |
+
year_range_str += f"-{end_year}" if start_year is not None else str(end_year)
|
284 |
+
if year_range_str:
|
285 |
+
title += f" ({year_range_str})"
|
286 |
+
|
287 |
+
df_plot = filtered_df.sort_values(by=x_col)
|
288 |
+
if y_col in df_plot.columns and x_col in df_plot.columns:
|
289 |
+
# Add color based on 'sex' if available
|
290 |
+
if 'sex' in df_plot.columns:
|
291 |
+
for sex_val in df_plot['sex'].unique():
|
292 |
+
df_subset = df_plot[df_plot['sex'] == sex_val]
|
293 |
+
fig.add_trace(go.Scatter(x=df_subset[x_col], y=df_subset[y_col], mode='lines+markers', name=str(sex_val)))
|
294 |
+
fig.update_layout(title=title, xaxis_title=x_col, yaxis_title=y_col)
|
295 |
+
else:
|
296 |
+
fig.add_trace(go.Scatter(x=df_plot[x_col], y=df_plot[y_col], mode='lines+markers', name=y_col))
|
297 |
+
fig.update_layout(title=title, xaxis_title=x_col, yaxis_title=y_col)
|
298 |
+
st.plotly_chart(fig)
|
299 |
+
end_time_graph = time.time()
|
300 |
+
st.write(f"Gráfico generado en: {end_time_graph - start_time_graph:.4f} segundos")
|
301 |
+
else:
|
302 |
+
st.warning("Las columnas X o Y seleccionadas no existen en los datos filtrados.")
|
303 |
+
|
304 |
+
# --- Logic for SCATTER PLOT ---
|
305 |
+
elif numerical_cols and len(numerical_cols) >= 2:
|
306 |
+
start_time_graph = time.time()
|
307 |
+
st.subheader("Gráfico de Dispersión Sugerido")
|
308 |
+
col1 = st.selectbox("Selecciona la primera columna numérica para el gráfico de dispersión:", numerical_cols)
|
309 |
+
col2 = st.selectbox("Selecciona la segunda columna numérica para el gráfico de dispersión:", [c for c in numerical_cols if c != col1])
|
310 |
+
if col1 and col2:
|
311 |
+
fig = px.scatter(filtered_df, x=col1, y=col2, title=f"Gráfico de Dispersión: {col1} vs {col2}")
|
312 |
+
st.plotly_chart(fig)
|
313 |
+
end_time_graph = time.time()
|
314 |
+
st.write(f"Gráfico generado en: {end_time_graph - start_time_graph:.4f} segundos")
|
315 |
+
else:
|
316 |
+
st.warning("Las columnas X o Y seleccionadas no existen en los datos filtrados.")
|
317 |
+
|
318 |
+
# --- Logic for SCATTER PLOT ---
|
319 |
+
# (Your scatter plot logic here...)
|
320 |
+
elif numerical_cols and len(numerical_cols) >= (2 + (1 if year_col_to_use else 0)) :
|
321 |
+
# ... (scatter plot code, ensuring cols exist) ...
|
322 |
+
pass # Placeholder
|
323 |
+
|
324 |
+
# --- Logic for BAR CHART ---
|
325 |
+
# (Your bar chart logic here...)
|
326 |
+
elif numerical_cols and categorical_cols:
|
327 |
+
# ... (bar chart code, ensuring cols exist and aggregating if needed) ...
|
328 |
+
pass # Placeholder
|
329 |
+
else:
|
330 |
+
# Only show this if no plots were generated above
|
331 |
+
if not (year_col_to_use and y_col): # Check if line plot was attempted
|
332 |
+
st.info("No se encontraron columnas adecuadas o suficientes datos después del filtrado para generar un gráfico automáticamente.")
|