juancamval committed
Commit ab15ee1 · verified · 1 Parent(s): 6f70ef7

Upload 2 files

Files where the GPT2 and Gemma3 tests are run
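Both scripts are Streamlit apps that read their Supabase credentials (and, for the Gemma 3 script, a Hugging Face token) from environment variables via python-dotenv. Below is a minimal pre-flight sketch, not part of this commit, that checks those variables before launching either app; the file name check_env.py is made up.

# check_env.py -- hypothetical pre-flight check, not part of this commit.
# Verifies the environment variables that gemma3test.py and gpt2test.py read at startup.
import os
from dotenv import load_dotenv

load_dotenv()  # both scripts load a .env file the same way

required = ["SUPABASE_URL", "SUPABASE_KEY"]  # read by both scripts
optional = ["HF_TOKEN"]                      # read by gemma3test.py to load google/gemma-3-1b-it

missing = [name for name in required if not os.getenv(name)]
if missing:
    raise SystemExit(f"Missing required environment variables: {', '.join(missing)}")

for name in optional:
    if not os.getenv(name):
        print(f"Warning: {name} is not set; loading the gated Gemma 3 model may fail.")

print("Environment looks OK. Launch with: streamlit run gemma3test.py (or gpt2test.py)")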

Files changed (2)
  1. gemma3test.py +272 -0
  2. gpt2test.py +332 -0
gemma3test.py ADDED
@@ -0,0 +1,272 @@
+ import streamlit as st
+ import os
+ import time
+ import pandas as pd
+ from datetime import datetime
+ from dotenv import load_dotenv
+ from supabase import create_client, Client
+ from transformers import pipeline
+ from sentence_transformers import SentenceTransformer
+ import plotly.graph_objects as go
+ import numpy as np
+ from sklearn.metrics.pairwise import cosine_similarity
+ import re
+
+
+ # ---------------------------------------------------------------------------------
+ # Helper functions
+ # ---------------------------------------------------------------------------------
+
+ def extract_country_and_dates(prompt, countries, latest_year=None):
+     country = None
+     start_date = None
+     end_date = None
+
+     # Look for the country (case-insensitive)
+     for c in countries:
+         if re.search(r'\b' + re.escape(c) + r'\b', prompt, re.IGNORECASE):
+             country = c
+             break
+
+     # Look for year ranges with different separators (-, to, until, from ... to, between ... and)
+     date_ranges = re.findall(r'(\d{4})\s*(?:-|to|until|from.*?to|between.*?and)\s*(\d{4})', prompt, re.IGNORECASE)
+     if date_ranges:
+         start_date = date_ranges[0][0]
+         end_date = date_ranges[0][1]
+     else:
+         # Look for a single year
+         single_years = re.findall(r'\b(\d{4})\b', prompt)
+         if single_years:
+             start_date = single_years[0]
+             end_date = single_years[0]
+
+     return country, start_date, end_date
+
+
+ def generate_plotly_graph(df, user_query, country=None, start_date=None, end_date=None):
+     relevant_data = df.copy()
+
+     if 'geo' in relevant_data.columns and country:
+         relevant_data = relevant_data[relevant_data['geo'].str.lower() == country.lower()]
+
+     if 'year' in relevant_data.columns:
+         relevant_data['year'] = pd.to_numeric(relevant_data['year'], errors='coerce')
+         relevant_data = relevant_data.dropna(subset=['year'])
+         relevant_data['year'] = relevant_data['year'].astype(int)
+         if start_date and end_date:
+             relevant_data = relevant_data[
+                 (relevant_data['year'] >= int(start_date)) & (relevant_data['year'] <= int(end_date))
+             ]
+         elif start_date:
+             relevant_data = relevant_data[relevant_data['year'] >= int(start_date)]
+         elif end_date:
+             relevant_data = relevant_data[relevant_data['year'] <= int(end_date)]
+
+     numeric_cols = relevant_data.select_dtypes(include=['number']).columns.tolist()
+     if 'year' in relevant_data.columns and numeric_cols:
+         fig = go.Figure()
+         for col in numeric_cols:
+             if col != 'year':
+                 fig.add_trace(go.Scatter(x=relevant_data['year'], y=relevant_data[col], mode='lines+markers', name=col))
+
+         title = f"Data for {country if country else 'All Regions'}"
+         if start_date and end_date:
+             title += f" ({start_date}-{end_date})"
+         elif start_date:
+             title += f" (from {start_date})"
+         elif end_date:
+             title += f" (up to {end_date})"
+
+         # Add a title and axis labels
+         fig.update_layout(
+             title=title,
+             xaxis_title="Year",
+             yaxis_title="Value"  # We would need to infer this, or have more descriptive column names
+         )
+         return fig
+     else:
+         return None
+
+
+ # ---------------------------------------------------------------------------------
+ # Supabase connection setup
+ # ---------------------------------------------------------------------------------
+ load_dotenv()
+ SUPABASE_URL = os.getenv("SUPABASE_URL")
+ SUPABASE_KEY = os.getenv("SUPABASE_KEY")
+ supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
+
+
+ # Load data from a Supabase table
+ def load_data(table):
+     try:
+         if supabase:
+             response = supabase.from_(table).select("*").execute()
+             if hasattr(response, 'data'):
+                 return pd.DataFrame(response.data)
+             elif hasattr(response, '_error'):
+                 st.error(f"Error fetching data: {response._error}")
+                 return pd.DataFrame()
+             else:
+                 st.info("Response object does not have 'data' or known error attributes. Check the logs.")
+                 return pd.DataFrame()
+         else:
+             st.error("Supabase client not initialized. Check environment variables.")
+             return pd.DataFrame()
+     except Exception as e:
+         st.error(f"An error occurred during data loading: {e}")
+         return pd.DataFrame()
+
+
+ # ---------------------------------------------------------------------------------
+ # Load initial data
+ # ---------------------------------------------------------------------------------
+ labor_data = load_data("labor")
+ fertility_data = load_data("fertility")
+
+ # ---------------------------------------------------------------------------------
+ # Initialize the models used for RAG
+ # ---------------------------------------------------------------------------------
+ embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
+ llm_pipeline = pipeline("text-generation", model="google/gemma-3-1b-it", token=os.getenv("HF_TOKEN"))
+
+ # ---------------------------------------------------------------------------------
+ # Embedding and metadata generation (in memory)
+ # ---------------------------------------------------------------------------------
+ embeddings_list = []
+ contents_list = []
+ metadatas_list = []
+ ids_list = []
+
+ for index, row in labor_data.iterrows():
+     doc = f"Country: {row['geo']}, Year: {row['year']}, Employment Rate: {row['labour_force'] if 'labour_force' in row else 'N/A'}"
+     embeddings_list.append(embedding_model.encode(doc))
+     contents_list.append(doc)
+     metadatas_list.append({'country': row['geo'], 'year': str(row['year']), 'source': 'labor'})
+     ids_list.append(f"labor_{index}")
+
+ for index, row in fertility_data.iterrows():
+     doc = f"Country: {row['geo']}, Year: {row['year']}, Fertility Rate: {row['fertility_rate'] if 'fertility_rate' in row else 'N/A'}"
+     embeddings_list.append(embedding_model.encode(doc))
+     contents_list.append(doc)
+     metadatas_list.append({'country': row['geo'], 'year': str(row['year']), 'source': 'fertility'})
+     ids_list.append(f"fertility_{index}")
+
+ embeddings_array = np.array(embeddings_list)
+
+
+ # ---------------------------------------------------------------------------------
+ # Retrieve relevant documents (in memory)
+ # ---------------------------------------------------------------------------------
+ def retrieve_relevant_documents_in_memory(query_embedding, stored_embeddings, contents, top_k=3):
+     similarities = cosine_similarity([query_embedding], stored_embeddings)[0]
+     sorted_indices = np.argsort(similarities)[::-1]
+     relevant_documents = [contents[i] for i in sorted_indices[:top_k]]
+     return relevant_documents
+
+
+ # ---------------------------------------------------------------------------------
+ # Generate the explanation using RAG
+ # ---------------------------------------------------------------------------------
+ def generate_rag_explanation(user_query, stored_embeddings, contents):
+     query_embedding = embedding_model.encode(user_query)
+     relevant_docs = retrieve_relevant_documents_in_memory(query_embedding, stored_embeddings, contents)
+     if relevant_docs:
+         context = "\n".join(relevant_docs)
+         augmented_prompt = f"Based on the following information:\n\n{context}\n\nAnswer the question related to: {user_query}"
+         output = llm_pipeline(augmented_prompt, max_new_tokens=250, num_return_sequences=1)
+         return output[0]['generated_text']
+     else:
+         return "No relevant information found to answer your query."
+
+
+ # ---------------------------------------------------------------------------------
+ # Build the list of available countries automatically
+ # ---------------------------------------------------------------------------------
+ available_countries_labor = labor_data['geo'].unique().tolist() if 'geo' in labor_data.columns else []
+ available_countries_fertility = fertility_data['geo'].unique().tolist() if 'geo' in fertility_data.columns else []
+ all_countries = list(set(available_countries_labor + available_countries_fertility))
+
+ # ---------------------------------------------------------------------------------
+ # Streamlit app setup
+ # ---------------------------------------------------------------------------------
+ st.set_page_config(page_title="GraphGen", page_icon="🇪🇺")
+ st.title("_Europe GraphGen_  :blue[Graph generator] :flag-eu:")
+ st.caption("Mapping Europe's data with insights")
+
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+     st.session_state.messages.append({"role": "assistant", "content": "What graphic and insights do you need?"})
+
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+ prompt = st.chat_input("Type your message here...", key="chat_input_bottom")
+
+ if prompt:
+     st.session_state.messages.append({"role": "user", "content": prompt})
+     with st.chat_message("user"):
+         st.markdown(prompt)
+
+     with st.spinner('Generating answer...'):
+
+         try:
+             # Determine the most recent year in the data
+             latest_year_labor = labor_data['year'].max() if 'year' in labor_data else datetime.now().year
+             latest_year_fertility = fertility_data['year'].max() if 'year' in fertility_data else datetime.now().year
+             latest_year = max(latest_year_labor, latest_year_fertility, datetime.now().year)
+
+             country, start_date, end_date = extract_country_and_dates(prompt, all_countries, latest_year)
+             graph_displayed = False
+
+             # Parse the prompt to determine the user's intent
+             if re.search(r'\b(labor|employment|job|workforce)\b', prompt, re.IGNORECASE):
+                 # Generate the labor data graphic
+                 labor_fig = generate_plotly_graph(labor_data, prompt, country, start_date, end_date)
+                 if labor_fig:
+                     st.session_state.messages.append(
+                         {"role": "assistant", "content": "Here is the labor data graphic:"})
+                     with st.chat_message("assistant"):
+                         st.plotly_chart(labor_fig)
+                     graph_displayed = True
+             elif re.search(r'\b(fertility|birth|population growth)\b', prompt, re.IGNORECASE):
+                 # Generate the fertility data graphic
+                 fertility_fig = generate_plotly_graph(fertility_data, prompt, country, start_date, end_date)
+                 if fertility_fig:
+                     st.session_state.messages.append(
+                         {"role": "assistant", "content": "Here is the fertility data graphic:"})
+                     with st.chat_message("assistant"):
+                         st.plotly_chart(fertility_fig)
+                     graph_displayed = True
+             else:
+                 # If no clear intent is detected, try the labor data graphic first
+                 labor_fig = generate_plotly_graph(labor_data, prompt, country, start_date, end_date)
+                 if labor_fig:
+                     st.session_state.messages.append(
+                         {"role": "assistant", "content": "Here is the labor data graphic:"})
+                     with st.chat_message("assistant"):
+                         st.plotly_chart(labor_fig)
+                     graph_displayed = True
+                 elif not graph_displayed:
+                     fertility_fig = generate_plotly_graph(fertility_data, prompt, country, start_date, end_date)
+                     if fertility_fig:
+                         st.session_state.messages.append(
+                             {"role": "assistant", "content": "Here is the fertility data graphic:"})
+                         with st.chat_message("assistant"):
+                             st.plotly_chart(fertility_fig)
+                         graph_displayed = True
+
+             # Generate the explanation using RAG
+             explanation = generate_rag_explanation(prompt, embeddings_array, contents_list)
+             st.session_state.messages.append({"role": "assistant", "content": f"Explanation: {explanation}"})
+             with st.chat_message("assistant"):
+                 st.markdown(f"**Explanation:** {explanation}")
+
+         except Exception as e:
+             st.session_state.messages.append({"role": "assistant", "content": f"Error generating answer: {e}"})
+             with st.chat_message("assistant"):
+                 st.error(f"Error generating answer: {e}")
+
+ if st.button("Clear chat"):
+     st.session_state.messages = []
+     st.session_state.messages.append(
+         {"role": "assistant", "content": "Chat has been cleared. What graphic and insights do you need now?"})
+     st.rerun()
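
The retrieval step in gemma3test.py is plain cosine similarity between sentence-transformer embeddings held in memory. A self-contained sketch of that same pattern follows; the documents and values here are made up for illustration.

# Standalone sketch of the in-memory retrieval used above; documents and values are made up.
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

model = SentenceTransformer('all-MiniLM-L6-v2')
docs = [
    "Country: Germany, Year: 2012, Employment Rate: 77.1",
    "Country: Spain, Year: 2012, Fertility Rate: 1.32",
    "Country: Germany, Year: 2015, Fertility Rate: 1.50",
]
doc_embeddings = np.array([model.encode(d) for d in docs])

query_embedding = model.encode("fertility in Germany around 2015")
scores = cosine_similarity([query_embedding], doc_embeddings)[0]

top_k = np.argsort(scores)[::-1][:2]  # indices of the two most similar documents
for i in top_k:
    print(f"{scores[i]:.3f}  {docs[i]}")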
gpt2test.py ADDED
@@ -0,0 +1,332 @@
+ import streamlit as st
+ import os
+ import re
+ import pandas as pd
+ from dotenv import load_dotenv
+ from supabase import create_client, Client
+ from transformers import pipeline
+ import plotly.express as px
+ import plotly.graph_objects as go
+ import time
+
+ # ---------------------------------------------------------------------------------
+ # Supabase Setup
+ # ---------------------------------------------------------------------------------
+ load_dotenv()
+ SUPABASE_URL = os.getenv("SUPABASE_URL")
+ SUPABASE_KEY = os.getenv("SUPABASE_KEY")
+ supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
+
+ # ---------------------------------------------------------------------------------
+ # Data Loading Function
+ # ---------------------------------------------------------------------------------
+ def load_data(table):
+     try:
+         if supabase:
+             response = supabase.from_(table).select("*").execute()
+             if hasattr(response, 'data'):
+                 return pd.DataFrame(response.data)
+             else:
+                 st.error(f"Error fetching data or no data returned for table '{table}'. Check Supabase logs.")
+                 return pd.DataFrame()
+         else:
+             st.error("Supabase client not initialized.")
+             return pd.DataFrame()
+     except Exception as e:
+         st.error(f"An error occurred during data loading from table '{table}': {e}")
+         return pd.DataFrame()
+
+ # ---------------------------------------------------------------------------------
+ # Helper Function Definitions
+ # ---------------------------------------------------------------------------------
+
+ def extract_country_from_prompt_regex(question, country_list):
+     """Extracts the first matching country from the list found in the question."""
+     for country in country_list:
+         # Use word boundaries (\b) for more accurate matching
+         if re.search(r"\b" + re.escape(country) + r"\b", question, re.IGNORECASE):
+             return country
+     return None  # Return None if no country in the list is found
+
+ def extract_years_from_prompt(question):
+     """Extracts a single year or a start/end year range from a question string."""
+     start_year, end_year = None, None
+     # Pattern 1: Single year (e.g., "in 2010", "year 2010")
+     single_year_match = re.search(r'\b(in|year|del)\s+(\d{4})\b', question, re.IGNORECASE)
+     if single_year_match:
+         year = int(single_year_match.group(2))
+         return year, year  # Return single year as start and end
+
+     # Pattern 2: Year range (e.g., "between 2000 and 2010", "from 2005 to 2015")
+     range_match = re.search(r'\b(between|from)\s+(\d{4})\s+(and|to)\s+(\d{4})\b', question, re.IGNORECASE)
+     if range_match:
+         s_year = int(range_match.group(2))
+         e_year = int(range_match.group(4))
+         return min(s_year, e_year), max(s_year, e_year)  # Ensure start <= end
+
+     # Pattern 3: Simple range like "2000-2010"
+     simple_range_match = re.search(r'\b(\d{4})-(\d{4})\b', question)
+     if simple_range_match:
+         s_year = int(simple_range_match.group(1))
+         e_year = int(simple_range_match.group(2))
+         return min(s_year, e_year), max(s_year, e_year)
+
+     # Pattern 4: After Year (e.g., "after 2015")
+     after_match = re.search(r'\b(after|since)\s+(\d{4})\b', question, re.IGNORECASE)
+     if after_match:
+         start_year = int(after_match.group(2))
+         # end_year remains None, signifying >= start_year
+
+     # Pattern 5: Before Year (e.g., "before 2005")
+     before_match = re.search(r'\b(before)\s+(\d{4})\b', question, re.IGNORECASE)
+     if before_match:
+         end_year = int(before_match.group(2))
+         # start_year remains None, signifying <= end_year
+         # Special case: if 'after' wasn't also found, return (None, end_year)
+         if start_year is None:
+             return None, end_year
+
+     # Return extracted years (could be None, None; start, None; None, end; or start, end)
+     # If single year patterns were matched first, they returned already.
+     return start_year, end_year
+
+
+ def filter_df_by_years(df, year_col, start_year, end_year):
+     """Filters a DataFrame based on a year column and a start/end year range."""
+     if year_col not in df.columns:
+         st.warning(f"Year column '{year_col}' not found.")
+         return df
+
+     try:
+         # Ensure year column is numeric, coerce errors to NaT/NaN
+         df[year_col] = pd.to_numeric(df[year_col], errors='coerce')
+         # Drop rows where conversion failed, essential for comparison
+         df_filtered = df.dropna(subset=[year_col]).copy()
+         # Convert to integer only AFTER dropping NaN, avoids potential float issues
+         df_filtered[year_col] = df_filtered[year_col].astype(int)
+     except Exception as e:
+         st.error(f"Could not convert year column '{year_col}' to numeric: {e}")
+         return df  # Return original on error
+
+     original_count = len(df_filtered)  # Count after potential NaNs are dropped
+
+     if start_year is None and end_year is None:
+         # No year filtering needed
+         return df_filtered
+
+     st.info(f"Filtering by years: Start={start_year}, End={end_year} on column '{year_col}'")
+
+     # Apply filters based on provided start/end years
+     if start_year is not None and end_year is not None:
+         # Specific range or single year (where start_year == end_year)
+         df_filtered = df_filtered[(df_filtered[year_col] >= start_year) & (df_filtered[year_col] <= end_year)]
+     elif start_year is not None:
+         # Only start year ("after X")
+         df_filtered = df_filtered[df_filtered[year_col] >= start_year]
+     elif end_year is not None:
+         # Only end year ("before Y")
+         df_filtered = df_filtered[df_filtered[year_col] <= end_year]
+
+     filtered_count = len(df_filtered)
+     if filtered_count == 0 and original_count > 0:  # Check if filtering removed all data
+         st.warning(f"No data found for the specified year(s): {start_year if start_year else ''}-{end_year if end_year else ''}")
+     elif filtered_count < original_count:
+         st.write(f"Filtered data by year. Rows reduced from {original_count} to {filtered_count}.")
+
+     return df_filtered
+
+
+ # ---------------------------------------------------------------------------------
+ # Load Model
+ # ---------------------------------------------------------------------------------
+ @st.cache_resource
+ def load_gpt2():
+     try:
+         generator = pipeline('text-generation', model='openai-community/gpt2')
+         return generator
+     except Exception as e:
+         st.error(f"Failed to load GPT-2 model: {e}")
+         return None
+
+ generator = load_gpt2()
+
+ # ---------------------------------------------------------------------------------
+ # Load Initial Data
+ # ---------------------------------------------------------------------------------
+ if 'data_labor' not in st.session_state:
+     st.session_state['data_labor'] = load_data("labor")  # Or your default table
+
+ # ---------------------------------------------------------------------------------
+ # Streamlit App UI Starts Here
+ # ---------------------------------------------------------------------------------
+ st.title("Data Analysis with GPT-2 and Automatic Visualization")
+
+ # Get the dataframe from session state
+ df = st.session_state.get('data_labor')
+
+ # --- Check if DataFrame is loaded ---
+ if df is None or df.empty:
+     st.error("Failed to load data or data is empty. Please check Supabase connection and table 'labor'.")
+     # Optionally add a button to retry loading
+     if st.button("Retry Loading Data"):
+         st.session_state['data_labor'] = load_data("labor")
+         st.rerun()  # Rerun the script after attempting reload
+ else:
+     # --- Section for the user question ---
+     st.subheader("Ask me something about the 'labor' data")
+     question = st.text_input("Example: 'What was the labor force in Germany between 2010 and 2015?'")
+
+     if question:
+         # --- Main processing logic ---
+         st.write("--- Question analysis ---")  # Debug separator
+
+         # Filter by Country
+         unique_countries = df['geo'].unique().tolist() if 'geo' in df.columns else []
+         extracted_country = extract_country_from_prompt_regex(question, unique_countries)
+
+         filtered_df = df.copy()
+         if extracted_country:
+             if 'geo' in filtered_df.columns:
+                 filtered_df = filtered_df[filtered_df['geo'] == extracted_country]
+                 st.success(f"Filtering data for country: {extracted_country}")
+             else:
+                 st.warning("Column 'geo' not found; cannot filter by country.")
+         else:
+             st.info("No country was specified or found. Showing data for all available countries.")
+
+         # Identify Columns
+         numerical_cols = [col for col in filtered_df.columns if pd.api.types.is_numeric_dtype(filtered_df[col])]
+         year_col_names = ['year', 'time', 'period', 'año']
+         year_cols = [col for col in filtered_df.columns if col.lower() in year_col_names and col in numerical_cols]
+         categorical_cols = [col for col in filtered_df.columns if pd.api.types.is_object_dtype(filtered_df[col]) and col != 'geo']
+
+         # Extract Years and Filter DataFrame
+         start_year, end_year = extract_years_from_prompt(question)
+         year_col_to_use = None
+         if year_cols:
+             year_col_to_use = year_cols[0]
+             filtered_df = filter_df_by_years(filtered_df, year_col_to_use, start_year, end_year)
+         else:
+             st.warning("Could not identify a numeric year column to filter by.")
+
+
+         # --- GPT-2 Description Generation ---
+         if generator:  # Check if model loaded successfully
+             st.subheader("Automatic Description (GPT-2)")
+             # Create a concise context
+             context_description = "The dataset contains labor data"
+             context_info = f"Data for {extracted_country or 'all countries'}"
+
+             if extracted_country:
+                 # If a specific country is filtered, mention it clearly
+                 context_description += f" specifically for {extracted_country}"
+             else:
+                 # Otherwise, mention the broader scope if known (e.g., Europe)
+                 # If you load data for multiple countries by default, state that
+                 context_description += " covering multiple countries"  # Adjust if needed
+
+             if year_col_to_use and (start_year is not None or end_year is not None):
+                 context_info += f" between years {start_year if start_year else 'start'} and {end_year if end_year else 'end'}"
+             context_info += f". Columns include: {', '.join(filtered_df.columns.tolist())}."
+
+             prompt = f"{context_description}. {context_info}\n\nQuestion: {question}\nAnswer based ONLY on the provided context:"
+             try:
+                 st.info("Generating description...")  # Let the user know it's working
+                 description = generator(prompt, max_new_tokens=200, num_return_sequences=1)[0]['generated_text']
+                 # Clean up the output to show only the answer part
+                 answer_part = description.split(prompt)[-1]  # Split by the prompt itself
+                 st.success("Description generated:")
+                 st.write(answer_part.strip())
+             except Exception as e:
+                 st.error(f"Error generating description with GPT-2: {e}")
+         else:
+             st.warning("The GPT-2 model is not loaded; cannot generate a description.")
+
+
+         # --- Visualization Section ---
+         st.subheader("Automatic Visualization")
+
+         if filtered_df.empty:
+             st.warning("No data to display after applying the filters.")
+
+         # --- Logic for LINE PLOT ---
+         elif year_col_to_use and numerical_cols:
+             start_time_graph = time.time()
+             potential_y_cols = [col for col in numerical_cols if col != year_col_to_use]
+             y_col = None
+             if not potential_y_cols:
+                 st.warning(f"No numeric data columns (other than '{year_col_to_use}') were found to plot against the year.")
+             else:
+                 labor_keywords = ['labor', 'labour', 'workforce', 'employment', 'lfpr', 'fuerza']  # Added 'fuerza'
+                 found_labor_col = False
+                 for col in potential_y_cols:
+                     if any(keyword in col.lower() for keyword in labor_keywords):
+                         y_col = col
+                         st.info(f"Found a relevant column: '{y_col}'. Using it for the Y axis.")
+                         found_labor_col = True
+                         break
+                 if not found_labor_col:
+                     y_col = potential_y_cols[0]
+                     st.info(f"No specific column found. Using the first available numeric column ('{y_col}') for the Y axis.")
+
+             if y_col:
+                 x_col = year_col_to_use
+                 fig = go.Figure()
+                 title = f"{y_col} vs {x_col}"
+                 if extracted_country:
+                     title += f" in {extracted_country}"
+                 if start_year is not None or end_year is not None:
+                     year_range_str = ""
+                     if start_year is not None:
+                         year_range_str += str(start_year)
+                     if end_year is not None:
+                         year_range_str += f"-{end_year}" if start_year is not None else str(end_year)
+                     if year_range_str:
+                         title += f" ({year_range_str})"
+
+                 df_plot = filtered_df.sort_values(by=x_col)
+                 if y_col in df_plot.columns and x_col in df_plot.columns:
+                     # Add color based on 'sex' if available
+                     if 'sex' in df_plot.columns:
+                         for sex_val in df_plot['sex'].unique():
+                             df_subset = df_plot[df_plot['sex'] == sex_val]
+                             fig.add_trace(go.Scatter(x=df_subset[x_col], y=df_subset[y_col], mode='lines+markers', name=str(sex_val)))
+                         fig.update_layout(title=title, xaxis_title=x_col, yaxis_title=y_col)
+                     else:
+                         fig.add_trace(go.Scatter(x=df_plot[x_col], y=df_plot[y_col], mode='lines+markers', name=y_col))
+                         fig.update_layout(title=title, xaxis_title=x_col, yaxis_title=y_col)
+                     st.plotly_chart(fig)
+                     end_time_graph = time.time()
+                     st.write(f"Graph generated in: {end_time_graph - start_time_graph:.4f} seconds")
+                 else:
+                     st.warning("The selected X or Y columns do not exist in the filtered data.")
+
+         # --- Logic for SCATTER PLOT ---
+         elif numerical_cols and len(numerical_cols) >= 2:
+             start_time_graph = time.time()
+             st.subheader("Suggested Scatter Plot")
+             col1 = st.selectbox("Select the first numeric column for the scatter plot:", numerical_cols)
+             col2 = st.selectbox("Select the second numeric column for the scatter plot:", [c for c in numerical_cols if c != col1])
+             if col1 and col2:
+                 fig = px.scatter(filtered_df, x=col1, y=col2, title=f"Scatter Plot: {col1} vs {col2}")
+                 st.plotly_chart(fig)
+                 end_time_graph = time.time()
+                 st.write(f"Graph generated in: {end_time_graph - start_time_graph:.4f} seconds")
+             else:
+                 st.warning("The selected X or Y columns do not exist in the filtered data.")
+
+         # --- Logic for SCATTER PLOT ---
+         # (Your scatter plot logic here...)
+         elif numerical_cols and len(numerical_cols) >= (2 + (1 if year_col_to_use else 0)):
+             # ... (scatter plot code, ensuring cols exist) ...
+             pass  # Placeholder
+
+         # --- Logic for BAR CHART ---
+         # (Your bar chart logic here...)
+         elif numerical_cols and categorical_cols:
+             # ... (bar chart code, ensuring cols exist and aggregating if needed) ...
+             pass  # Placeholder
+         else:
+             # Only show this if no plots were generated above
+             st.info("No suitable columns or insufficient data were found after filtering to generate a graph automatically.")
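
For quick manual checks, the year-range handling in gpt2test.py boils down to a few regular expressions. A standalone illustration of the range pattern follows; the regex is copied from extract_years_from_prompt, and the example question is made up.

# Illustration only: regex copied from extract_years_from_prompt in gpt2test.py above.
import re

question = "What was the labor force in Germany between 2010 and 2015?"  # made-up example
range_match = re.search(r'\b(between|from)\s+(\d{4})\s+(and|to)\s+(\d{4})\b', question, re.IGNORECASE)
if range_match:
    start_year, end_year = int(range_match.group(2)), int(range_match.group(4))
    print(start_year, end_year)  # prints: 2010 2015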