# graph_generator / gpt2test.py
# Uploaded by juancamval ("Upload 2 files", commit ab15ee1, 16.8 kB).
import streamlit as st
import os
import re
import pandas as pd
from dotenv import load_dotenv
from supabase import create_client, Client
from transformers import pipeline
import plotly.express as px
import plotly.graph_objects as go
import time
# ---------------------------------------------------------------------------------
# Supabase Setup
# ---------------------------------------------------------------------------------
load_dotenv()
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")

# Create the client only when credentials are available. On any failure the
# client stays ``None``, so load_data() reports "Supabase client not
# initialized" instead of the whole script crashing at start-up.
supabase: "Client | None" = None
if SUPABASE_URL and SUPABASE_KEY:
    try:
        supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
    except Exception as e:
        st.error(f"Could not create Supabase client: {e}")
else:
    st.error("SUPABASE_URL and/or SUPABASE_KEY are not set. Check your environment or .env file.")
# ---------------------------------------------------------------------------------
# Data Loading Function
# ---------------------------------------------------------------------------------
def load_data(table):
    """Run ``select("*")`` on a Supabase table and return the result as a DataFrame.

    Any problem is reported in the UI with ``st.error`` and an empty
    DataFrame is returned, so callers only need to check ``.empty``.
    NOTE(review): the server may cap the number of rows returned by a plain
    select — confirm against the Supabase/PostgREST configuration.

    Args:
        table: Name of the Supabase table to query.

    Returns:
        pd.DataFrame built from ``response.data``, or an empty DataFrame on failure.
    """
    try:
        if not supabase:
            st.error("Supabase client not initialized.")
            return pd.DataFrame()

        response = supabase.from_(table).select("*").execute()
        if not hasattr(response, 'data'):
            st.error(f"Error fetching data or no data returned for table '{table}'. Check Supabase logs.")
            return pd.DataFrame()

        return pd.DataFrame(response.data)
    except Exception as e:
        st.error(f"An error occurred during data loading from table '{table}': {e}")
        return pd.DataFrame()
# ---------------------------------------------------------------------------------
# Helper Function Definitions
# ---------------------------------------------------------------------------------
def extract_country_from_prompt_regex(question, country_list):
    """Return the entry of ``country_list`` that is mentioned in ``question``.

    Matching is case-insensitive and the name must stand on its own (it may
    not be glued to surrounding letters/digits). When several names match,
    the longest wins, so "Germany (until 1990 former territory of the FRG)"
    beats "Germany" when the question spells out the full label. Names of
    equal length keep their order from ``country_list``.

    Args:
        question: Free-text user question.
        country_list: Candidate country names. Non-string entries (e.g. NaN
            or None coming from ``Series.unique()``) and blank strings are ignored.

    Returns:
        The matching entry from ``country_list``, or None if nothing matches.
    """
    if not question:
        return None
    candidates = [c for c in country_list if isinstance(c, str) and c.strip()]
    # sorted() is stable even with reverse=True, so ties keep list order.
    for country in sorted(candidates, key=len, reverse=True):
        # Lookarounds instead of \b: \b next to a non-word character such as
        # ")" would require a word character on the other side and never match
        # at the end of the question.
        pattern = r"(?<!\w)" + re.escape(country) + r"(?!\w)"
        if re.search(pattern, question, re.IGNORECASE):
            return country
    return None  # No country in the list was found
def extract_years_from_prompt(question):
    """Extract a year or a year range mentioned in an English or Spanish question.

    Recognised forms, checked in this order:
      * ranges: "between 2000 and 2010", "from 2005 to 2015",
        "entre 2010 y 2015", "desde 2000 hasta 2010", "del 2000 al 2010",
        "de 2000 a 2010", "2000-2010";
      * a single year: "in 2010", "year 2010", "en 2010", "del 2010", "año 2010";
      * open bounds: "after/since/desde/después de 2015" (start) and
        "before/until/hasta/antes de 2005" (end); both may appear together.

    Ranges are tried before single years so that "in 2000-2010" yields the
    whole range rather than just 2000.

    Args:
        question: Free-text user question (None/empty is allowed).

    Returns:
        ``(start_year, end_year)``. Each is an int or None. A single year is
        returned as ``(year, year)``, a range is ordered so start <= end, and
        ``(None, None)`` means no year was found.
    """
    if not question:
        return None, None

    # 1. Explicit ranges (keyword form first, then "2000-2010").
    range_match = re.search(
        r'\b(?:between|from|entre|desde|del|de)\s+(\d{4})\s+(?:and|to|y|hasta|al|a)\s+(\d{4})\b',
        question,
        re.IGNORECASE,
    )
    if range_match is None:
        range_match = re.search(r'\b(\d{4})\s*[-–]\s*(\d{4})\b', question)
    if range_match:
        first, second = int(range_match.group(1)), int(range_match.group(2))
        return min(first, second), max(first, second)  # Ensure start <= end

    # 2. Single year.
    single_match = re.search(r'\b(?:in|year|del|en|año)\s+(\d{4})\b', question, re.IGNORECASE)
    if single_match:
        year = int(single_match.group(1))
        return year, year

    # 3. Open-ended bounds; either or both may be present.
    start_year, end_year = None, None
    after_match = re.search(r'\b(?:after|since|desde|después\s+de)\s+(\d{4})\b', question, re.IGNORECASE)
    if after_match:
        start_year = int(after_match.group(1))
    before_match = re.search(r'\b(?:before|until|hasta|antes\s+de)\s+(\d{4})\b', question, re.IGNORECASE)
    if before_match:
        end_year = int(before_match.group(1))
    return start_year, end_year
def filter_df_by_years(df, year_col, start_year, end_year):
    """Return the rows of ``df`` whose ``year_col`` lies inside the requested bounds.

    ``year_col`` is coerced to numbers. Rows whose value cannot be parsed are
    dropped, even when no bounds are given, and the column is cast to int.
    Bounds are inclusive, and None means unbounded on that side. Progress and
    warnings are shown through Streamlit. ``df`` itself is never modified;
    a new DataFrame is returned.

    Args:
        df: Source data.
        year_col: Name of the column holding the year.
        start_year: Inclusive lower bound, or None.
        end_year: Inclusive upper bound, or None.

    Returns:
        The filtered DataFrame. If ``year_col`` is missing or cannot be
        converted, ``df`` is returned unchanged.
    """
    if year_col not in df.columns:
        st.warning(f"Year column '{year_col}' not found.")
        return df
    try:
        years = pd.to_numeric(df[year_col], errors='coerce')
        valid = years.notna()
        # Copy the surviving rows instead of assigning into ``df``. This keeps
        # the caller's frame untouched and avoids SettingWithCopyWarning when
        # ``df`` is itself a filtered slice.
        df_filtered = df.loc[valid].copy()
        # Cast only after dropping NaN; .to_numpy() skips index alignment,
        # which would fail on duplicate index labels.
        df_filtered[year_col] = years[valid].astype(int).to_numpy()
    except Exception as e:
        st.error(f"Could not convert year column '{year_col}' to numeric: {e}")
        return df  # Return original on error

    original_count = len(df_filtered)  # Count after unparseable years are dropped
    if start_year is None and end_year is None:
        return df_filtered  # No year filtering needed

    st.info(f"Filtering by years: Start={start_year}, End={end_year} on column '{year_col}'")
    if start_year is not None:
        df_filtered = df_filtered[df_filtered[year_col] >= start_year]
    if end_year is not None:
        df_filtered = df_filtered[df_filtered[year_col] <= end_year]

    filtered_count = len(df_filtered)
    if filtered_count == 0 and original_count > 0:  # Filtering removed all data
        start_label = start_year if start_year is not None else ''
        end_label = end_year if end_year is not None else ''
        st.warning(f"No data found for the specified year(s): {start_label}-{end_label}")
    elif filtered_count < original_count:
        st.write(f"Filtered data by year. Rows reduced from {original_count} to {filtered_count}.")
    return df_filtered
# ---------------------------------------------------------------------------------
# Load Model
# ---------------------------------------------------------------------------------
@st.cache_resource
def _load_gpt2_pipeline():
    """Build the GPT-2 text-generation pipeline, cached by st.cache_resource across reruns.

    Exceptions are deliberately not caught here, so a failed load is not
    stored in the cache and will be attempted again on the next rerun.
    """
    return pipeline('text-generation', model='openai-community/gpt2')


def load_gpt2():
    """Return the cached GPT-2 pipeline, or None if it could not be loaded.

    The try/except lives outside the cached function on purpose:
    st.cache_resource keeps whatever the function returns, so returning None
    from inside it would pin the failure until the cache is cleared.
    """
    try:
        return _load_gpt2_pipeline()
    except Exception as e:
        st.error(f"Failed to load GPT-2 model: {e}")
        return None


generator = load_gpt2()
# ---------------------------------------------------------------------------------
# Load Initial Data
# ---------------------------------------------------------------------------------
# Supabase table loaded into session state when a session has no data yet.
_INITIAL_TABLE = "labor"  # Or your default table
_needs_initial_load = 'data_labor' not in st.session_state
if _needs_initial_load:
    # Kept in session_state, so later reruns reuse the DataFrame instead of
    # querying Supabase again.
    st.session_state['data_labor'] = load_data(_INITIAL_TABLE)
# ---------------------------------------------------------------------------------
# Streamlit App UI Starts Here
# ---------------------------------------------------------------------------------
# Substrings that mark a numeric column as the "labor" measure for the Y axis.
LABOR_KEYWORDS = ('labor', 'labour', 'workforce', 'employment', 'lfpr', 'fuerza')
# Column names (compared lower-case) treated as the time axis.
YEAR_COL_NAMES = ('year', 'time', 'period', 'año')


def _format_year_range(start_year, end_year):
    """Return a compact label for a year filter: "2010", "2010-2015", "2010-", "-2015" or ""."""
    if start_year is not None and end_year is not None:
        return str(start_year) if start_year == end_year else f"{start_year}-{end_year}"
    if start_year is not None:
        return f"{start_year}-"
    if end_year is not None:
        return f"-{end_year}"
    return ""


def _build_gpt2_prompt(question, country, year_col, start_year, end_year, columns):
    """Return the GPT-2 prompt: a short description of the filtered data followed by the question."""
    scope = f" specifically for {country}" if country else " covering multiple countries"
    context = f"The dataset contains labor data{scope}. Data for {country or 'all countries'}"
    if year_col and (start_year is not None or end_year is not None):
        start_label = start_year if start_year is not None else 'start'
        end_label = end_year if end_year is not None else 'end'
        context += f" between years {start_label} and {end_label}"
    context += f". Columns include: {', '.join(str(c) for c in columns)}."
    return f"{context}\n\nQuestion: {question}\nAnswer based ONLY on the provided context:"


def _pick_y_column(candidates):
    """Return ``(column, matched)`` for the first candidate whose name contains a labor keyword.

    Falls back to ``(candidates[0], False)`` when no name matches; ``candidates`` must be non-empty.
    """
    for col in candidates:
        if any(keyword in str(col).lower() for keyword in LABOR_KEYWORDS):
            return col, True
    return candidates[0], False


def _build_line_figure(df_plot, x_col, y_col, title):
    """Return a lines+markers figure of ``y_col`` against ``x_col``.

    One trace is drawn per ``sex`` value (when that column exists) and per
    ``geo`` value (when more than one country remains). Separate series are
    therefore not joined into a single zig-zag line.
    """
    series_cols = [
        col for col in ('geo', 'sex')
        if col in df_plot.columns and (col == 'sex' or df_plot[col].nunique(dropna=False) > 1)
    ]
    fig = go.Figure()
    if series_cols:
        # sort=False keeps series in order of first appearance; dropna=False
        # keeps rows whose label is missing instead of silently dropping them.
        for key, group in df_plot.groupby(series_cols, sort=False, dropna=False):
            labels = key if isinstance(key, tuple) else (key,)
            fig.add_trace(go.Scatter(
                x=group[x_col],
                y=group[y_col],
                mode='lines+markers',
                name=" / ".join(str(label) for label in labels),
            ))
    else:
        fig.add_trace(go.Scatter(x=df_plot[x_col], y=df_plot[y_col], mode='lines+markers', name=y_col))
    fig.update_layout(title=title, xaxis_title=x_col, yaxis_title=y_col)
    return fig


st.title("Análisis de Datos con GPT-2 y Visualización Automática")
# Get the dataframe from session state
df = st.session_state.get('data_labor')

if df is None or df.empty:
    st.error("Failed to load data or data is empty. Please check Supabase connection and table 'labor'.")
    if st.button("Retry Loading Data"):
        st.session_state['data_labor'] = load_data("labor")
        st.rerun()  # Rerun the script after attempting reload
else:
    st.subheader("Pregúntame algo sobre los datos de 'labor'")
    question = st.text_input("Ejemplo: 'Cuál fue la fuerza laboral (labor force) en Germany entre 2010 y 2015?'")
    if question:
        st.write("--- Análisis de la pregunta ---")  # Debug separator

        # --- Country filter ---
        unique_countries = df['geo'].unique().tolist() if 'geo' in df.columns else []
        extracted_country = extract_country_from_prompt_regex(question, unique_countries)
        filtered_df = df.copy()
        if extracted_country:
            # A country can only be extracted when 'geo' exists (the candidate list is empty otherwise).
            filtered_df = filtered_df[filtered_df['geo'] == extracted_country]
            st.success(f"Filtrando datos para el país: {extracted_country}")
        else:
            st.info("No se especificó un país o no se encontró. Mostrando datos para todos los países disponibles.")

        # --- Column roles ---
        numerical_cols = [col for col in filtered_df.columns if pd.api.types.is_numeric_dtype(filtered_df[col])]
        year_cols = [col for col in numerical_cols if str(col).lower() in YEAR_COL_NAMES]
        categorical_cols = [
            col for col in filtered_df.columns
            if pd.api.types.is_object_dtype(filtered_df[col]) and col != 'geo'
        ]

        # --- Year filter ---
        start_year, end_year = extract_years_from_prompt(question)
        year_col_to_use = year_cols[0] if year_cols else None
        if year_col_to_use:
            filtered_df = filter_df_by_years(filtered_df, year_col_to_use, start_year, end_year)
        else:
            st.warning("No se pudo identificar una columna de año numérica para filtrar.")

        # --- GPT-2 description ---
        if generator is not None:
            st.subheader("Descripción Automática (GPT-2)")
            prompt = _build_gpt2_prompt(
                question, extracted_country, year_col_to_use, start_year, end_year,
                filtered_df.columns.tolist(),
            )
            try:
                st.info("Generando descripción...")
                # return_full_text=False makes the pipeline return only the newly
                # generated continuation, so the prompt never has to be split off.
                outputs = generator(prompt, max_new_tokens=200, num_return_sequences=1, return_full_text=False)
                st.success("Descripción generada:")
                st.write(outputs[0]['generated_text'].strip())
            except Exception as e:
                st.error(f"Error generando descripción con GPT-2: {e}")
        else:
            st.warning("El modelo GPT-2 no está cargado. No se puede generar descripción.")

        # --- Visualization ---
        st.subheader("Visualización Automática")
        if filtered_df.empty:
            st.warning("No hay datos para mostrar después de aplicar los filtros.")
        elif year_col_to_use and numerical_cols:
            # Line plot: a numeric measure over time.
            start_time_graph = time.time()
            potential_y_cols = [col for col in numerical_cols if col != year_col_to_use]
            if not potential_y_cols:
                st.warning(f"No se encontraron columnas numéricas de datos (aparte de '{year_col_to_use}') para graficar contra el año.")
            else:
                y_col, matched_keyword = _pick_y_column(potential_y_cols)
                if matched_keyword:
                    st.info(f"Se encontró columna relevante: '{y_col}'. Usándola para el eje Y.")
                else:
                    st.info(f"No se encontró columna específica. Usando la primera columna numérica disponible ('{y_col}') para el eje Y.")
                x_col = year_col_to_use
                title = f"{y_col} vs {x_col}"
                if extracted_country:
                    title += f" en {extracted_country}"
                year_range_str = _format_year_range(start_year, end_year)
                if year_range_str:
                    title += f" ({year_range_str})"
                fig = _build_line_figure(filtered_df.sort_values(by=x_col), x_col, y_col, title)
                st.plotly_chart(fig)
                st.write(f"Gráfico generado en: {time.time() - start_time_graph:.4f} segundos")
        elif len(numerical_cols) >= 2:
            # Scatter plot: the user picks two numeric columns.
            start_time_graph = time.time()
            st.subheader("Gráfico de Dispersión Sugerido")
            col1 = st.selectbox("Selecciona la primera columna numérica para el gráfico de dispersión:", numerical_cols)
            col2 = st.selectbox("Selecciona la segunda columna numérica para el gráfico de dispersión:", [c for c in numerical_cols if c != col1])
            if col1 and col2:
                fig = px.scatter(filtered_df, x=col1, y=col2, title=f"Gráfico de Dispersión: {col1} vs {col2}")
                st.plotly_chart(fig)
                st.write(f"Gráfico generado en: {time.time() - start_time_graph:.4f} segundos")
            else:
                st.warning("Las columnas X o Y seleccionadas no existen en los datos filtrados.")
        elif numerical_cols and categorical_cols:
            # Bar chart: here there is exactly one numeric column (two or more
            # were handled above), so plot its mean per category.
            start_time_graph = time.time()
            value_col = numerical_cols[0]
            category_col = categorical_cols[0]
            df_bar = filtered_df.groupby(category_col, as_index=False)[value_col].mean()
            fig = px.bar(df_bar, x=category_col, y=value_col, title=f"Promedio de {value_col} por {category_col}")
            st.plotly_chart(fig)
            st.write(f"Gráfico generado en: {time.time() - start_time_graph:.4f} segundos")
        else:
            st.info("No se encontraron columnas adecuadas o suficientes datos después del filtrado para generar un gráfico automáticamente.")