# Spaces: Running  (hosting-page status text captured along with the file; not part of the program)
import streamlit as st | |
import os | |
import re | |
import pandas as pd | |
from dotenv import load_dotenv | |
from supabase import create_client, Client | |
from transformers import pipeline | |
import plotly.express as px | |
import plotly.graph_objects as go | |
import time | |
# --------------------------------------------------------------------------------- | |
# Supabase Setup | |
# --------------------------------------------------------------------------------- | |
# Load variables from a local .env file (if any) into the process environment,
# then read the Supabase credentials from it.
load_dotenv()
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")

# Build the client only when both credentials are available. Otherwise leave it
# as None so that code checking `if supabase:` (see load_data) can show a clear
# "client not initialized" message instead of the script failing at startup.
supabase: "Client | None" = None
if SUPABASE_URL and SUPABASE_KEY:
    supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
# --------------------------------------------------------------------------------- | |
# Data Loading Function | |
# --------------------------------------------------------------------------------- | |
def load_data(table):
    """Fetch every row of a Supabase table as a pandas DataFrame.

    Args:
        table: Name of the Supabase table to read.

    Returns:
        A DataFrame built from the response's ``data`` attribute. An empty
        DataFrame is returned (after reporting the problem with ``st.error``)
        when the client is not initialized, the response has no ``data``
        attribute, or any exception is raised while querying.
    """
    try:
        # Guard clause: nothing to query without a client.
        if not supabase:
            st.error("Supabase client not initialized.")
            return pd.DataFrame()

        result = supabase.from_(table).select("*").execute()

        # Guard clause: the response object is expected to expose `.data`.
        if not hasattr(result, 'data'):
            st.error(f"Error fetching data or no data returned for table '{table}'. Check Supabase logs.")
            return pd.DataFrame()

        return pd.DataFrame(result.data)
    except Exception as e:
        # UI boundary: surface the failure to the user and degrade to "no data".
        st.error(f"An error occurred during data loading from table '{table}': {e}")
        return pd.DataFrame()
# --------------------------------------------------------------------------------- | |
# Helper Function Definitions | |
# --------------------------------------------------------------------------------- | |
def extract_country_from_prompt_regex(question, country_list):
    """Return the country from ``country_list`` that is mentioned in ``question``.

    Matching is case-insensitive and only accepts whole-name occurrences (the
    name must not be glued to other word characters on either side).

    Args:
        question: Free-text question typed by the user.
        country_list: Candidate country names. Entries that are not non-blank
            strings (e.g. ``None``/NaN coming from a DataFrame column) are
            ignored.

    Returns:
        The first candidate (in list order) found in the question, or ``None``
        when there is no match. If a matched name is merely a fragment of a
        longer name that also matched (e.g. "Guinea" inside
        "Equatorial Guinea"), the longer name is preferred.
    """
    if not question:
        return None

    matched = []
    for country in country_list:
        # Skip missing values and blank names: re.escape() rejects non-strings,
        # and an empty name would match any question and hide real matches.
        if not isinstance(country, str) or not country.strip():
            continue
        # Lookarounds instead of \b so that names starting or ending with
        # punctuation (e.g. "Korea (Rep.)") can still match; for plain names
        # this is equivalent to \b...\b.
        pattern = r"(?<!\w)" + re.escape(country) + r"(?!\w)"
        if re.search(pattern, question, re.IGNORECASE):
            matched.append(country)

    lowered = [name.lower() for name in matched]
    for name, low in zip(matched, lowered):
        # A name is "shadowed" when a different, longer matched name contains it.
        if not any(low != other and low in other for other in lowered):
            return name
    return None  # No country in the list is mentioned
def extract_years_from_prompt(question):
    """Extract a year or an inclusive year range from a question string.

    English and Spanish phrasings are recognized. Patterns are tried from the
    most to the least specific, so an explicit range always wins over a
    single-year keyword that happens to precede it (e.g. "in 2000-2010").

    Args:
        question: Free-text question typed by the user.

    Returns:
        A ``(start_year, end_year)`` tuple of ints/``None``:

        * ``(y, y)``      - a single year ("in 2010", "en 2010", "del 2010")
        * ``(a, b)``      - a range with ``a <= b`` ("between 2000 and 2010",
          "from 2005 to 2015", "entre 2010 y 2015", "2000-2010")
        * ``(a, None)``   - open-ended start ("after 2015", "since 2015")
        * ``(None, b)``   - open-ended end ("before 2005", "antes de 2005")
        * ``(a, b)``      - both "after a" and "before b" present (not reordered)
        * ``(None, None)`` - no year information found
    """
    if not question:
        return None, None

    # 1) Worded range: "between 2000 and 2010", "from 2005 to 2015",
    #    "entre 2010 y 2015", "desde 2005 hasta 2015", "de 2005 a 2015".
    worded_range = re.search(
        r'\b(?:between|from|entre|desde|del?)\s+(\d{4})\s+(?:and|to|y|hasta|al?)\s+(\d{4})\b',
        question,
        re.IGNORECASE,
    )
    if worded_range:
        first, second = int(worded_range.group(1)), int(worded_range.group(2))
        return min(first, second), max(first, second)  # Ensure start <= end

    # 2) Dashed range: "2000-2010" (hyphen or en dash, optional spaces).
    dashed_range = re.search(r'\b(\d{4})\s*[-–]\s*(\d{4})\b', question)
    if dashed_range:
        first, second = int(dashed_range.group(1)), int(dashed_range.group(2))
        return min(first, second), max(first, second)

    # 3) Single year: "in 2010", "year 2010", "del 2010", "en 2010", "año 2010".
    single_year = re.search(r'\b(?:in|year|del|en|año)\s+(\d{4})\b', question, re.IGNORECASE)
    if single_year:
        year = int(single_year.group(1))
        return year, year  # A single year is a range of length one

    # 4) Open-ended bounds; both may appear in the same question.
    start_year, end_year = None, None
    after_match = re.search(
        r'\b(?:after|since|desde|despu[eé]s\s+de)\s+(\d{4})\b', question, re.IGNORECASE
    )
    if after_match:
        start_year = int(after_match.group(1))  # end stays None => ">= start_year"

    before_match = re.search(
        r'\b(?:before|hasta|antes\s+de)\s+(\d{4})\b', question, re.IGNORECASE
    )
    if before_match:
        end_year = int(before_match.group(1))  # start stays None => "<= end_year"

    return start_year, end_year
def filter_df_by_years(df, year_col, start_year, end_year):
    """Return the rows of ``df`` whose ``year_col`` lies within the given bounds.

    The input DataFrame is never modified; all conversions happen on a copy.

    Args:
        df: DataFrame to filter.
        year_col: Name of the column holding the year.
        start_year: Inclusive lower bound, or ``None`` for no lower bound.
        end_year: Inclusive upper bound, or ``None`` for no upper bound.

    Returns:
        * ``df`` itself when ``year_col`` is missing or cannot be converted
          (a warning/error is shown through Streamlit).
        * Otherwise a new DataFrame in which ``year_col`` has been converted to
          integers, rows with a non-numeric year have been dropped, and the
          requested bounds (if any) have been applied.
    """
    if year_col not in df.columns:
        st.warning(f"Year column '{year_col}' not found.")
        return df

    try:
        # Convert into a separate Series so the caller's DataFrame keeps its
        # original values; unparseable entries become NaN.
        year_values = pd.to_numeric(df[year_col], errors='coerce')
        valid_rows = year_values.notna()
        # Drop rows where conversion failed, then store the years as integers
        # (only possible once NaN values are gone).
        df_filtered = df.loc[valid_rows].copy()
        df_filtered[year_col] = year_values[valid_rows].astype(int)
    except Exception as e:
        # Best effort: report the problem and hand back the data unfiltered.
        st.error(f"Could not convert year column '{year_col}' to numeric: {e}")
        return df

    original_count = len(df_filtered)  # Count after rows with invalid years are dropped

    if start_year is None and end_year is None:
        # No year filtering requested
        return df_filtered

    st.info(f"Filtering by years: Start={start_year}, End={end_year} on column '{year_col}'")

    years = df_filtered[year_col]
    if start_year is not None and end_year is not None:
        # Closed range, or a single year when start_year == end_year
        df_filtered = df_filtered[(years >= start_year) & (years <= end_year)]
    elif start_year is not None:
        # Only a lower bound ("after X")
        df_filtered = df_filtered[years >= start_year]
    else:
        # Only an upper bound ("before Y")
        df_filtered = df_filtered[years <= end_year]

    filtered_count = len(df_filtered)
    if filtered_count == 0 and original_count > 0:  # Filtering removed every row
        start_label = start_year if start_year is not None else ''
        end_label = end_year if end_year is not None else ''
        st.warning(f"No data found for the specified year(s): {start_label}-{end_label}")
    elif filtered_count < original_count:
        st.write(f"Filtered data by year. Rows reduced from {original_count} to {filtered_count}.")
    return df_filtered
# --------------------------------------------------------------------------------- | |
# Load Model | |
# --------------------------------------------------------------------------------- | |
@st.cache_resource
def _build_gpt2_pipeline():
    """Create the GPT-2 text-generation pipeline.

    Wrapped in ``st.cache_resource`` so the model is built once and reused,
    rather than being rebuilt each time Streamlit re-executes this script.
    """
    return pipeline('text-generation', model='openai-community/gpt2')


def load_gpt2():
    """Return the GPT-2 text-generation pipeline, or ``None`` if it cannot be loaded.

    Loading errors are reported with ``st.error``. The try/except lives here,
    outside the cached builder, so a failed attempt is not stored as a cached
    ``None`` and loading is attempted again on the next run.
    """
    try:
        return _build_gpt2_pipeline()
    except Exception as e:
        st.error(f"Failed to load GPT-2 model: {e}")
        return None


generator = load_gpt2()
# --------------------------------------------------------------------------------- | |
# Load Initial Data | |
# --------------------------------------------------------------------------------- | |
# Load the table once per browser session and keep it in session state so that
# later reruns of the script reuse it.
if 'data_labor' not in st.session_state:
    st.session_state['data_labor'] = load_data("labor")  # Or your default table

# ---------------------------------------------------------------------------------
# Streamlit App UI Starts Here
# ---------------------------------------------------------------------------------
st.title("Análisis de Datos con GPT-2 y Visualización Automática")

# Get the dataframe from session state
df = st.session_state.get('data_labor')

if df is None or df.empty:
    # --- Data could not be loaded: report it and offer a retry ---
    st.error("Failed to load data or data is empty. Please check Supabase connection and table 'labor'.")
    if st.button("Retry Loading Data"):
        st.session_state['data_labor'] = load_data("labor")
        st.rerun()  # Rerun the script after attempting reload
else:
    # --- Section for the user question ---
    st.subheader("Pregúntame algo sobre los datos de 'labor'")
    question = st.text_input("Ejemplo: 'Cuál fue la fuerza laboral (labor force) en Germany entre 2010 y 2015?'")

    if question:
        st.write("--- Análisis de la pregunta ---")  # Debug separator

        # ----- Filter by country (column 'geo') -----
        unique_countries = df['geo'].unique().tolist() if 'geo' in df.columns else []
        extracted_country = extract_country_from_prompt_regex(question, unique_countries)
        filtered_df = df.copy()
        if extracted_country:
            if 'geo' in filtered_df.columns:
                filtered_df = filtered_df[filtered_df['geo'] == extracted_country]
                st.success(f"Filtrando datos para el país: {extracted_country}")
            else:
                st.warning("Columna 'geo' no encontrada para filtrar por país.")
        else:
            st.info("No se especificó un país o no se encontró. Mostrando datos para todos los países disponibles.")

        # ----- Classify columns -----
        numerical_cols = [col for col in filtered_df.columns if pd.api.types.is_numeric_dtype(filtered_df[col])]
        year_col_names = ['year', 'time', 'period', 'año']
        # A year column must have a recognized name AND a numeric dtype.
        year_cols = [col for col in filtered_df.columns if col.lower() in year_col_names and col in numerical_cols]

        # ----- Extract years from the question and filter -----
        start_year, end_year = extract_years_from_prompt(question)
        year_col_to_use = None
        if year_cols:
            year_col_to_use = year_cols[0]
            filtered_df = filter_df_by_years(filtered_df, year_col_to_use, start_year, end_year)
        else:
            st.warning("No se pudo identificar una columna de año numérica para filtrar.")

        # ----- GPT-2 description generation -----
        if generator:  # Model loaded successfully
            st.subheader("Descripción Automática (GPT-2)")
            # Concise context placed in front of the user's question.
            context_info = f"Data for {extracted_country or 'all countries'}"
            if year_col_to_use and (start_year is not None or end_year is not None):
                context_info += f" between years {start_year if start_year else 'start'} and {end_year if end_year else 'end'}"
            context_info += f". Columns include: {', '.join(filtered_df.columns.tolist())}."
            prompt = f"{context_info}\n\nQuestion: {question}\nAnswer based ONLY on the provided context:"
            try:
                st.info("Generando descripción...")  # Let user know it's working
                description = generator(prompt, max_new_tokens=200, num_return_sequences=1)[0]['generated_text']
                # The generated text echoes the prompt; keep only what follows it.
                answer_part = description.split(prompt)[-1]
                st.success("Descripción generada:")
                st.write(answer_part.strip())
            except Exception as e:
                st.error(f"Error generando descripción con GPT-2: {e}")
        else:
            st.warning("El modelo GPT-2 no está cargado. No se puede generar descripción.")

        # ----- Visualization -----
        st.subheader("Visualización Automática")
        if filtered_df.empty:
            st.warning("No hay datos para mostrar después de aplicar los filtros.")

        # --- LINE PLOT: a year column exists, plot one numeric column against it ---
        elif year_col_to_use and numerical_cols:
            start_time_graph = time.time()
            potential_y_cols = [col for col in numerical_cols if col != year_col_to_use]
            y_col = None
            if not potential_y_cols:
                st.warning(f"No se encontraron columnas numéricas de datos (aparte de '{year_col_to_use}') para graficar contra el año.")
            else:
                # Prefer a column whose name looks labor-related; otherwise
                # fall back to the first numeric column.
                labor_keywords = ['labor', 'labour', 'workforce', 'employment', 'lfpr', 'fuerza']
                y_col = next(
                    (col for col in potential_y_cols
                     if any(keyword in col.lower() for keyword in labor_keywords)),
                    None,
                )
                if y_col is not None:
                    st.info(f"Se encontró columna relevante: '{y_col}'. Usándola para el eje Y.")
                else:
                    y_col = potential_y_cols[0]
                    st.info(f"No se encontró columna específica. Usando la primera columna numérica disponible ('{y_col}') para el eje Y.")

            if y_col:
                x_col = year_col_to_use
                fig = go.Figure()

                # Title: "<y> vs <x> [en <country>] [(<years>)]"
                title = f"{y_col} vs {x_col}"
                if extracted_country:
                    title += f" en {extracted_country}"
                if start_year is not None and end_year is not None:
                    # Show a single year once instead of "2010-2010".
                    year_range_str = str(start_year) if start_year == end_year else f"{start_year}-{end_year}"
                elif start_year is not None:
                    year_range_str = str(start_year)
                elif end_year is not None:
                    year_range_str = str(end_year)
                else:
                    year_range_str = ""
                if year_range_str:
                    title += f" ({year_range_str})"

                df_plot = filtered_df.sort_values(by=x_col)
                if y_col in df_plot.columns and x_col in df_plot.columns:
                    if 'sex' in df_plot.columns:
                        # One trace per value of 'sex' so each gets its own color.
                        for sex_val in df_plot['sex'].unique():
                            df_subset = df_plot[df_plot['sex'] == sex_val]
                            fig.add_trace(go.Scatter(x=df_subset[x_col], y=df_subset[y_col], mode='lines+markers', name=str(sex_val)))
                    else:
                        fig.add_trace(go.Scatter(x=df_plot[x_col], y=df_plot[y_col], mode='lines+markers', name=y_col))
                    fig.update_layout(title=title, xaxis_title=x_col, yaxis_title=y_col)
                    st.plotly_chart(fig)
                    end_time_graph = time.time()
                    st.write(f"Gráfico generado en: {end_time_graph - start_time_graph:.4f} segundos")
                else:
                    st.warning("Las columnas X o Y seleccionadas no existen en los datos filtrados.")

        # --- SCATTER PLOT: no year column, but at least two numeric columns ---
        elif numerical_cols and len(numerical_cols) >= 2:
            start_time_graph = time.time()
            st.subheader("Gráfico de Dispersión Sugerido")
            col1 = st.selectbox("Selecciona la primera columna numérica para el gráfico de dispersión:", numerical_cols)
            col2 = st.selectbox("Selecciona la segunda columna numérica para el gráfico de dispersión:", [c for c in numerical_cols if c != col1])
            if col1 and col2:
                fig = px.scatter(filtered_df, x=col1, y=col2, title=f"Gráfico de Dispersión: {col1} vs {col2}")
                st.plotly_chart(fig)
                end_time_graph = time.time()
                st.write(f"Gráfico generado en: {end_time_graph - start_time_graph:.4f} segundos")
            else:
                st.warning("Las columnas X o Y seleccionadas no existen en los datos filtrados.")

        # --- No automatic chart available ---
        # TODO: add a bar chart for the "one numeric column + categorical columns"
        # case; until then, tell the user instead of rendering nothing.
        else:
            st.info("No se encontraron columnas adecuadas o suficientes datos después del filtrado para generar un gráfico automáticamente.")