"""Streamlit app: ask questions about the Supabase 'labor' table, get a GPT-2
description and an automatically chosen Plotly chart of the filtered data."""

import os
import re
import time
from typing import Optional

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from dotenv import load_dotenv
from supabase import Client, create_client
from transformers import pipeline

# ---------------------------------------------------------------------------------
# Supabase Setup
# ---------------------------------------------------------------------------------
load_dotenv()
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")


def _create_supabase_client() -> Optional[Client]:
    """Build the Supabase client, or return None if it cannot be created.

    Returning None (instead of letting ``create_client`` raise at import time)
    keeps the app running so ``load_data`` can report the problem in the UI.
    """
    if not SUPABASE_URL or not SUPABASE_KEY:
        st.error("SUPABASE_URL / SUPABASE_KEY are not set in the environment.")
        return None
    try:
        return create_client(SUPABASE_URL, SUPABASE_KEY)
    except Exception as e:  # create_client validates the URL/key format
        st.error(f"Could not create Supabase client: {e}")
        return None


supabase: Optional[Client] = _create_supabase_client()


# ---------------------------------------------------------------------------------
# Data Loading Function
# ---------------------------------------------------------------------------------
def load_data(table):
    """Fetch every row of a Supabase table into a DataFrame.

    Args:
        table: Name of the Supabase table to read.

    Returns:
        A DataFrame with the table's rows. An empty DataFrame is returned when
        the client is not initialized or the query fails; the error is shown in
        the Streamlit UI instead of being raised.
    """
    if supabase is None:
        st.error("Supabase client not initialized.")
        return pd.DataFrame()
    try:
        response = supabase.from_(table).select("*").execute()
    except Exception as e:
        st.error(f"An error occurred during data loading from table '{table}': {e}")
        return pd.DataFrame()
    if not hasattr(response, "data"):
        st.error(f"Error fetching data or no data returned for table '{table}'. Check Supabase logs.")
        return pd.DataFrame()
    return pd.DataFrame(response.data)


# ---------------------------------------------------------------------------------
# Helper Function Definitions
# ---------------------------------------------------------------------------------
def extract_country_from_prompt_regex(question, country_list):
    """Return the country from ``country_list`` that is mentioned in ``question``.

    Matching is case-insensitive and only on whole names. When several names
    match (e.g. "Ireland" and "Northern Ireland"), the longest one wins so the
    more specific name is preferred.

    Args:
        question: Free-text user question.
        country_list: Candidate country names. Non-string and empty entries
            (e.g. NaN values coming from a DataFrame column) are ignored.

    Returns:
        The matching country name, or None if no listed country is mentioned.
    """
    candidates = [c for c in country_list if isinstance(c, str) and c.strip()]
    # Longest first; sorted() is stable (also with reverse=True), so names of
    # equal length keep their original order.
    for country in sorted(candidates, key=len, reverse=True):
        # Lookarounds instead of \b: \b never matches next to a name that starts
        # or ends with punctuation, e.g. "Germany (until 1990 ...)".
        pattern = r"(?<!\w)" + re.escape(country) + r"(?!\w)"
        if re.search(pattern, question, re.IGNORECASE):
            return country
    return None


# Year patterns, English and Spanish (the UI and its example question are Spanish).
_YEAR_RANGE_PATTERNS = (
    # "between 2000 and 2010", "from 2005 to 2015", "entre 2010 y 2015", "del 2010 al 2015"
    re.compile(
        r"\b(?:between|from|entre|desde|del|de)\s+(\d{4})\s+(?:and|to|y|hasta|al|a)\s+(\d{4})\b",
        re.IGNORECASE,
    ),
    # "2000-2010" (spaces and an en dash around the separator are also accepted)
    re.compile(r"\b(\d{4})\s*[-–]\s*(\d{4})\b"),
)
_SINGLE_YEAR_RE = re.compile(r"\b(?:in|year|del|en|año)\s+(\d{4})\b", re.IGNORECASE)
_AFTER_YEAR_RE = re.compile(r"\b(?:after|since|desde|despu[eé]s\s+de)\s+(\d{4})\b", re.IGNORECASE)
_BEFORE_YEAR_RE = re.compile(r"\b(?:before|antes\s+de)\s+(\d{4})\b", re.IGNORECASE)


def extract_years_from_prompt(question):
    """Extract a single year or a (start, end) year range from a question.

    Returns:
        ``(start_year, end_year)`` where each item is an int or None:
        a single year gives ``(y, y)``; a range gives ``(min, max)``;
        "after X" gives ``(X, None)``; "before Y" gives ``(None, Y)``;
        both together give ``(X, Y)``; nothing found gives ``(None, None)``.
    """
    # Ranges are checked before single years, otherwise "in 2000-2010" would
    # be read as the single year 2000.
    for pattern in _YEAR_RANGE_PATTERNS:
        match = pattern.search(question)
        if match:
            first, second = int(match.group(1)), int(match.group(2))
            return min(first, second), max(first, second)  # ensure start <= end

    match = _SINGLE_YEAR_RE.search(question)
    if match:
        year = int(match.group(1))
        return year, year

    after_match = _AFTER_YEAR_RE.search(question)
    before_match = _BEFORE_YEAR_RE.search(question)
    start_year = int(after_match.group(1)) if after_match else None
    end_year = int(before_match.group(1)) if before_match else None
    return start_year, end_year


def _format_year_range(start_year, end_year):
    """Render a possibly open-ended year range for display.

    Examples: ``(2010, 2010)`` -> "2010", ``(2010, 2015)`` -> "2010-2015",
    ``(2010, None)`` -> "2010-", ``(None, 2005)`` -> "-2005".
    """
    if start_year is not None and start_year == end_year:
        return str(start_year)
    start = "" if start_year is None else str(start_year)
    end = "" if end_year is None else str(end_year)
    return f"{start}-{end}"


def filter_df_by_years(df, year_col, start_year, end_year):
    """Return the rows of ``df`` whose ``year_col`` lies in ``[start_year, end_year]``.

    Either bound may be None (open-ended). The year column of the result is
    coerced to int and rows whose year cannot be parsed are dropped. ``df``
    itself is never modified.

    Returns:
        The filtered copy, or ``df`` unchanged when the column is missing or
        cannot be converted (a message is shown in the UI in both cases).
    """
    if year_col not in df.columns:
        st.warning(f"Year column '{year_col}' not found.")
        return df
    try:
        df_filtered = df.copy()  # work on a copy: never mutate the caller's frame
        df_filtered[year_col] = pd.to_numeric(df_filtered[year_col], errors="coerce")
        # Drop unparseable years before the int cast (NaN cannot be cast to int).
        df_filtered = df_filtered.dropna(subset=[year_col]).astype({year_col: int})
    except (TypeError, ValueError, OverflowError) as e:
        st.error(f"Could not convert year column '{year_col}' to numeric: {e}")
        return df

    original_count = len(df_filtered)  # counted after unparseable years are dropped
    if start_year is None and end_year is None:
        return df_filtered

    st.info(f"Filtering by years: Start={start_year}, End={end_year} on column '{year_col}'")
    mask = pd.Series(True, index=df_filtered.index)
    if start_year is not None:
        mask &= df_filtered[year_col] >= start_year
    if end_year is not None:
        mask &= df_filtered[year_col] <= end_year
    df_filtered = df_filtered[mask]

    filtered_count = len(df_filtered)
    if filtered_count == 0 and original_count > 0:
        st.warning(f"No data found for the specified year(s): {_format_year_range(start_year, end_year)}")
    elif filtered_count < original_count:
        st.write(f"Filtered data by year. Rows reduced from {original_count} to {filtered_count}.")
    return df_filtered


# ---------------------------------------------------------------------------------
# Load Model
# ---------------------------------------------------------------------------------
@st.cache_resource
def load_gpt2():
    """Load (once per server process) the GPT-2 text-generation pipeline, or None on failure."""
    try:
        return pipeline("text-generation", model="openai-community/gpt2")
    except Exception as e:
        st.error(f"Failed to load GPT-2 model: {e}")
        return None


generator = load_gpt2()


# ---------------------------------------------------------------------------------
# Rendering helpers
# ---------------------------------------------------------------------------------
def _build_gpt2_prompt(question, columns, extracted_country, year_col, start_year, end_year):
    """Build a short context + question prompt describing the filtered data."""
    if extracted_country:
        context_description = f"The dataset contains labor data specifically for {extracted_country}"
    else:
        context_description = "The dataset contains labor data covering multiple countries"

    context_info = f"Data for {extracted_country or 'all countries'}"
    if year_col and (start_year is not None or end_year is not None):
        start_txt = start_year if start_year is not None else "start"
        end_txt = end_year if end_year is not None else "end"
        context_info += f" between years {start_txt} and {end_txt}"
    context_info += f". Columns include: {', '.join(map(str, columns))}."

    return (
        f"{context_description}. {context_info}\n\n"
        f"Question: {question}\nAnswer based ONLY on the provided context:"
    )


def _render_gpt2_description(question, filtered_df, extracted_country, year_col, start_year, end_year):
    """Generate and display a GPT-2 answer for ``question`` about ``filtered_df``."""
    if generator is None:
        st.warning("El modelo GPT-2 no está cargado. No se puede generar descripción.")
        return

    st.subheader("Descripción Automática (GPT-2)")
    prompt = _build_gpt2_prompt(
        question, filtered_df.columns.tolist(), extracted_country, year_col, start_year, end_year
    )
    try:
        with st.spinner("Generando descripción..."):
            # return_full_text=False makes the pipeline return only the new
            # tokens, so the prompt does not have to be stripped afterwards.
            outputs = generator(prompt, max_new_tokens=200, num_return_sequences=1, return_full_text=False)
        st.success("Descripción generada:")
        st.write(outputs[0]["generated_text"].strip())
    except Exception as e:
        st.error(f"Error generando descripción con GPT-2: {e}")


_LABOR_KEYWORDS = ("labor", "labour", "workforce", "employment", "lfpr", "fuerza")


def _render_line_plot(filtered_df, numerical_cols, x_col, extracted_country, start_year, end_year):
    """Plot the most labor-related numeric column against the year column.

    One trace per value of the 'sex' column is drawn when that column exists.
    """
    start_time_graph = time.perf_counter()
    potential_y_cols = [col for col in numerical_cols if col != x_col]
    if not potential_y_cols:
        st.warning(f"No se encontraron columnas numéricas de datos (aparte de '{x_col}') para graficar contra el año.")
        return

    y_col = next(
        (col for col in potential_y_cols if any(k in str(col).lower() for k in _LABOR_KEYWORDS)),
        None,
    )
    if y_col is not None:
        st.info(f"Se encontró columna relevante: '{y_col}'. Usándola para el eje Y.")
    else:
        y_col = potential_y_cols[0]
        st.info(f"No se encontró columna específica. Usando la primera columna numérica disponible ('{y_col}') para el eje Y.")

    if y_col not in filtered_df.columns or x_col not in filtered_df.columns:
        st.warning("Las columnas X o Y seleccionadas no existen en los datos filtrados.")
        return

    title = f"{y_col} vs {x_col}"
    if extracted_country:
        title += f" en {extracted_country}"
    if start_year is not None or end_year is not None:
        title += f" ({_format_year_range(start_year, end_year)})"

    df_plot = filtered_df.sort_values(by=x_col)
    fig = go.Figure()
    if "sex" in df_plot.columns:
        for sex_val in df_plot["sex"].unique():
            df_subset = df_plot[df_plot["sex"] == sex_val]
            fig.add_trace(go.Scatter(x=df_subset[x_col], y=df_subset[y_col], mode="lines+markers", name=str(sex_val)))
    else:
        fig.add_trace(go.Scatter(x=df_plot[x_col], y=df_plot[y_col], mode="lines+markers", name=y_col))
    fig.update_layout(title=title, xaxis_title=x_col, yaxis_title=y_col)

    st.plotly_chart(fig)
    st.write(f"Gráfico generado en: {time.perf_counter() - start_time_graph:.4f} segundos")


def _render_scatter_plot(filtered_df, numerical_cols):
    """Let the user pick two numeric columns and draw a scatter plot of them."""
    start_time_graph = time.perf_counter()
    st.subheader("Gráfico de Dispersión Sugerido")
    col1 = st.selectbox("Selecciona la primera columna numérica para el gráfico de dispersión:", numerical_cols)
    col2 = st.selectbox(
        "Selecciona la segunda columna numérica para el gráfico de dispersión:",
        [c for c in numerical_cols if c != col1],
    )
    if col1 and col2:
        fig = px.scatter(filtered_df, x=col1, y=col2, title=f"Gráfico de Dispersión: {col1} vs {col2}")
        st.plotly_chart(fig)
        st.write(f"Gráfico generado en: {time.perf_counter() - start_time_graph:.4f} segundos")
    else:
        st.warning("Las columnas X o Y seleccionadas no existen en los datos filtrados.")


def _render_bar_chart(filtered_df, numerical_cols, categorical_cols):
    """Bar chart of the mean of the first numeric column per category of the first categorical column."""
    start_time_graph = time.perf_counter()
    cat_col, num_col = categorical_cols[0], numerical_cols[0]
    # Aggregate so each category is a single bar even with many rows per category.
    df_agg = filtered_df.groupby(cat_col, as_index=False)[num_col].mean()
    fig = px.bar(df_agg, x=cat_col, y=num_col, title=f"Promedio de {num_col} por {cat_col}")
    st.plotly_chart(fig)
    st.write(f"Gráfico generado en: {time.perf_counter() - start_time_graph:.4f} segundos")


# ---------------------------------------------------------------------------------
# Load Initial Data
# ---------------------------------------------------------------------------------
if "data_labor" not in st.session_state:
    st.session_state["data_labor"] = load_data("labor")  # default table

# ---------------------------------------------------------------------------------
# Streamlit App UI Starts Here
# ---------------------------------------------------------------------------------
st.title("Análisis de Datos con GPT-2 y Visualización Automática")

df = st.session_state.get("data_labor")

if df is None or df.empty:
    st.error("Failed to load data or data is empty. Please check Supabase connection and table 'labor'.")
    if st.button("Retry Loading Data"):
        st.session_state["data_labor"] = load_data("labor")
        st.rerun()  # rerun the script so the reloaded data is used
else:
    # --- Section for the user question ---
    st.subheader("Pregúntame algo sobre los datos de 'labor'")
    question = st.text_input("Ejemplo: 'Cuál fue la fuerza laboral (labor force) en Germany entre 2010 y 2015?'")

    if question:
        st.write("--- Análisis de la pregunta ---")  # debug separator

        # Filter by country (only possible when the 'geo' column exists).
        unique_countries = df["geo"].unique().tolist() if "geo" in df.columns else []
        extracted_country = extract_country_from_prompt_regex(question, unique_countries)

        filtered_df = df.copy()  # the session-state frame is shared across reruns
        if extracted_country:
            filtered_df = filtered_df[filtered_df["geo"] == extracted_country]
            st.success(f"Filtrando datos para el país: {extracted_country}")
        else:
            st.info("No se especificó un país o no se encontró. Mostrando datos para todos los países disponibles.")

        # Identify column roles.
        numerical_cols = [col for col in filtered_df.columns if pd.api.types.is_numeric_dtype(filtered_df[col])]
        year_col_names = {"year", "time", "period", "año"}
        year_cols = [col for col in numerical_cols if str(col).lower() in year_col_names]
        categorical_cols = [
            col for col in filtered_df.columns
            if pd.api.types.is_object_dtype(filtered_df[col]) and col != "geo"
        ]

        # Extract years and filter.
        start_year, end_year = extract_years_from_prompt(question)
        year_col_to_use = year_cols[0] if year_cols else None
        if year_col_to_use:
            filtered_df = filter_df_by_years(filtered_df, year_col_to_use, start_year, end_year)
        else:
            st.warning("No se pudo identificar una columna de año numérica para filtrar.")

        # --- GPT-2 Description Generation ---
        _render_gpt2_description(question, filtered_df, extracted_country, year_col_to_use, start_year, end_year)

        # --- Visualization Section ---
        st.subheader("Visualización Automática")
        if filtered_df.empty:
            st.warning("No hay datos para mostrar después de aplicar los filtros.")
        elif year_col_to_use and numerical_cols:
            _render_line_plot(filtered_df, numerical_cols, year_col_to_use, extracted_country, start_year, end_year)
        elif len(numerical_cols) >= 2:
            _render_scatter_plot(filtered_df, numerical_cols)
        elif numerical_cols and categorical_cols:
            _render_bar_chart(filtered_df, numerical_cols, categorical_cols)
        else:
            st.info("No se encontraron columnas adecuadas o suficientes datos después del filtrado para generar un gráfico automáticamente.")