import streamlit as st
import os
import time
import pandas as pd
from datetime import datetime
from dotenv import load_dotenv
from supabase import create_client, Client
from transformers import pipeline
from sentence_transformers import SentenceTransformer
import plotly.graph_objects as go
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import re
# ---------------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------------
def extract_country_and_dates(prompt, countries, latest_year=None):
    country = None
    start_date = None
    end_date = None
    # Look for the country (case-insensitive)
    for c in countries:
        if re.search(r'\b' + re.escape(c) + r'\b', prompt, re.IGNORECASE):
            country = c
            break
    # Look for year ranges with different separators (-, to, until, from ... to, between ... and)
    date_ranges = re.findall(r'(\d{4})\s*(?:-|to|until|from.*?to|between.*?and)\s*(\d{4})', prompt, re.IGNORECASE)
    if date_ranges:
        start_date = date_ranges[0][0]
        end_date = date_ranges[0][1]
    else:
        # Look for a single year
        single_years = re.findall(r'\b(\d{4})\b', prompt)
        if single_years:
            start_date = single_years[0]
            end_date = single_years[0]
    # Clamp the end year to the most recent year available in the data, if provided
    if latest_year is not None and end_date is not None and int(end_date) > int(latest_year):
        end_date = str(int(latest_year))
    return country, start_date, end_date
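# Illustrative usage (hypothetical country list; the real list is built from the data below):
#   extract_country_and_dates("Show employment in Spain from 2010 to 2020", ["Spain", "France"])
#   -> ("Spain", "2010", "2020")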
def generate_plotly_graph(df, user_query, country=None, start_date=None, end_date=None):
    relevant_data = df.copy()
    if 'geo' in relevant_data.columns and country:
        relevant_data = relevant_data[relevant_data['geo'].str.lower() == country.lower()]
    if 'year' in relevant_data.columns:
        # Coerce years to numeric, drop rows whose year could not be parsed, then cast to int
        relevant_data['year'] = pd.to_numeric(relevant_data['year'], errors='coerce')
        relevant_data = relevant_data.dropna(subset=['year'])
        relevant_data['year'] = relevant_data['year'].astype(int)
        if start_date and end_date:
            relevant_data = relevant_data[
                (relevant_data['year'] >= int(start_date)) & (relevant_data['year'] <= int(end_date))
            ]
        elif start_date:
            relevant_data = relevant_data[relevant_data['year'] >= int(start_date)]
        elif end_date:
            relevant_data = relevant_data[relevant_data['year'] <= int(end_date)]
    numeric_cols = relevant_data.select_dtypes(include=['number']).columns.tolist()
    if 'year' in relevant_data.columns and numeric_cols:
        fig = go.Figure()
        for col in numeric_cols:
            if col != 'year':
                fig.add_trace(go.Scatter(x=relevant_data['year'], y=relevant_data[col], mode='lines+markers', name=col))
        title = f"Data for {country if country else 'All Regions'}"
        if start_date and end_date:
            title += f" ({start_date}-{end_date})"
        elif start_date:
            title += f" (from {start_date})"
        elif end_date:
            title += f" (up to {end_date})"
        # Add title and axis labels
        fig.update_layout(
            title=title,
            xaxis_title="Year",
            yaxis_title="Value"  # We would need to infer or have more descriptive column names
        )
        return fig
    else:
        return None
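# Illustrative usage (assumes the DataFrame has 'geo', 'year' and at least one numeric column):
#   generate_plotly_graph(labor_data, "employment in Spain", "Spain", "2010", "2020")
# returns a plotly Figure, or None when the data has no 'year' column or no numeric columns.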
# ---------------------------------------------------------------------------------
# Supabase connection setup
# ---------------------------------------------------------------------------------
load_dotenv()
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
# Only create the client when both variables are set; load_data() checks for a missing client
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY) if SUPABASE_URL and SUPABASE_KEY else None
# Function to load data from a Supabase table
def load_data(table):
    try:
        if supabase:
            response = supabase.from_(table).select("*").execute()
            if hasattr(response, 'data'):
                return pd.DataFrame(response.data)
            elif hasattr(response, '_error'):
                st.error(f"Error fetching data: {response._error}")
                return pd.DataFrame()
            else:
                st.info("Response object does not have 'data' or known error attributes. Check the logs.")
                return pd.DataFrame()
        else:
            st.error("Supabase client not initialized. Check environment variables.")
            return pd.DataFrame()
    except Exception as e:
        st.error(f"An error occurred during data loading: {e}")
        return pd.DataFrame()
# ---------------------------------------------------------------------------------
# Load initial data
# ---------------------------------------------------------------------------------
labor_data = load_data("labor")
fertility_data = load_data("fertility")
# ---------------------------------------------------------------------------------
# Initialize models for RAG
# ---------------------------------------------------------------------------------
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
llm_pipeline = pipeline("text-generation", model="google/gemma-3-1b-it", token=os.getenv("HF_TOKEN"))
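# Optional sketch (not part of the original flow): Streamlit reruns this script on every
# interaction, so the models above are reloaded each time. A common pattern is to wrap the
# loading in a cached factory, e.g.:
#
#   @st.cache_resource
#   def get_embedding_model():
#       return SentenceTransformer('all-MiniLM-L6-v2')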
# ---------------------------------------------------------------------------------
# Generate embeddings and metadata (in memory)
# ---------------------------------------------------------------------------------
embeddings_list = []
contents_list = []
metadatas_list = []
ids_list = []
for index, row in labor_data.iterrows():
    doc = f"Country: {row['geo']}, Year: {row['year']}, Employment Rate: {row['labour_force'] if 'labour_force' in row else 'N/A'}"
    embeddings_list.append(embedding_model.encode(doc))
    contents_list.append(doc)
    metadatas_list.append({'country': row['geo'], 'year': str(row['year']), 'source': 'labor'})
    ids_list.append(f"labor_{index}")
for index, row in fertility_data.iterrows():
    doc = f"Country: {row['geo']}, Year: {row['year']}, Fertility Rate: {row['fertility_rate'] if 'fertility_rate' in row else 'N/A'}"
    embeddings_list.append(embedding_model.encode(doc))
    contents_list.append(doc)
    metadatas_list.append({'country': row['geo'], 'year': str(row['year']), 'source': 'fertility'})
    ids_list.append(f"fertility_{index}")
embeddings_array = np.array(embeddings_list)
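# Note: all-MiniLM-L6-v2 produces 384-dimensional vectors, so embeddings_array has shape
# (num_documents, 384). Encoding row by row works but is slow for large tables; a faster
# variant would batch the calls, e.g. embedding_model.encode(contents_list).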
# ---------------------------------------------------------------------------------
# Function to retrieve relevant documents (in memory)
# ---------------------------------------------------------------------------------
def retrieve_relevant_documents_in_memory(query_embedding, stored_embeddings, contents, top_k=3):
    similarities = cosine_similarity([query_embedding], stored_embeddings)[0]
    sorted_indices = np.argsort(similarities)[::-1]
    relevant_documents = [contents[i] for i in sorted_indices[:top_k]]
    return relevant_documents
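# Given a query embedding from embedding_model.encode(...), this returns the top_k stored
# document strings ranked by cosine similarity, highest first.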
# ---------------------------------------------------------------------------------
# Generate the explanation using RAG
# ---------------------------------------------------------------------------------
def generate_rag_explanation(user_query, stored_embeddings, contents):
    query_embedding = embedding_model.encode(user_query)
    relevant_docs = retrieve_relevant_documents_in_memory(query_embedding, stored_embeddings, contents)
    if relevant_docs:
        context = "\n".join(relevant_docs)
        augmented_prompt = f"Based on the following information:\n\n{context}\n\nAnswer the question related to: {user_query}"
        # max_new_tokens caps the generated continuation regardless of how long the prompt is
        output = llm_pipeline(augmented_prompt, max_new_tokens=250, num_return_sequences=1)
        return output[0]['generated_text']
    else:
        return "No relevant information found to answer your query."
# ---------------------------------------------------------------------------------
# Build the list of countries automatically
# ---------------------------------------------------------------------------------
available_countries_labor = labor_data['geo'].unique().tolist() if 'geo' in labor_data.columns else []
available_countries_fertility = fertility_data['geo'].unique().tolist() if 'geo' in fertility_data.columns else []
all_countries = list(set(available_countries_labor + available_countries_fertility))
# ---------------------------------------------------------------------------------
# Streamlit app setup
# ---------------------------------------------------------------------------------
st.set_page_config(page_title="GraphGen", page_icon="🇪🇺")
st.title("_Europe GraphGen_ :blue[Graph generator] :flag-eu:")
st.caption("Mapping Europe's data with insights")
if "messages" not in st.session_state: | |
st.session_state.messages = [] | |
st.session_state.messages.append({"role": "assistant", "content": "What graphic and insights do you need?"}) | |
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
prompt = st.chat_input("Type your message here...", key="chat_input_bottom")
if prompt:
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.spinner('Generating answer...'):
        try:
            # Determine the most recent year available in the data
            latest_year_labor = int(pd.to_numeric(labor_data['year'], errors='coerce').max()) if 'year' in labor_data else datetime.now().year
            latest_year_fertility = int(pd.to_numeric(fertility_data['year'], errors='coerce').max()) if 'year' in fertility_data else datetime.now().year
            latest_year = max(latest_year_labor, latest_year_fertility, datetime.now().year)
            country, start_date, end_date = extract_country_and_dates(prompt, all_countries, latest_year)
            graph_displayed = False
            # Parse the prompt to determine the user's intent
            if re.search(r'\b(labor|employment|job|workforce)\b', prompt, re.IGNORECASE):
                # Generate the labor data chart
                labor_fig = generate_plotly_graph(labor_data, prompt, country, start_date, end_date)
                if labor_fig:
                    st.session_state.messages.append(
                        {"role": "assistant", "content": "Here is the labor data graphic:"})
                    with st.chat_message("assistant"):
                        st.plotly_chart(labor_fig)
                    graph_displayed = True
            elif re.search(r'\b(fertility|birth|population growth)\b', prompt, re.IGNORECASE):
                # Generate the fertility data chart
                fertility_fig = generate_plotly_graph(fertility_data, prompt, country, start_date, end_date)
                if fertility_fig:
                    st.session_state.messages.append(
                        {"role": "assistant", "content": "Here is the fertility data graphic:"})
                    with st.chat_message("assistant"):
                        st.plotly_chart(fertility_fig)
                    graph_displayed = True
            else:
                # If no clear intent is detected, try the labor data chart first
                labor_fig = generate_plotly_graph(labor_data, prompt, country, start_date, end_date)
                if labor_fig:
                    st.session_state.messages.append(
                        {"role": "assistant", "content": "Here is the labor data graphic:"})
                    with st.chat_message("assistant"):
                        st.plotly_chart(labor_fig)
                    graph_displayed = True
                else:
                    fertility_fig = generate_plotly_graph(fertility_data, prompt, country, start_date, end_date)
                    if fertility_fig:
                        st.session_state.messages.append(
                            {"role": "assistant", "content": "Here is the fertility data graphic:"})
                        with st.chat_message("assistant"):
                            st.plotly_chart(fertility_fig)
                        graph_displayed = True
            # Generate the explanation using RAG
            explanation = generate_rag_explanation(prompt, embeddings_array, contents_list)
            st.session_state.messages.append({"role": "assistant", "content": f"Explanation: {explanation}"})
            with st.chat_message("assistant"):
                st.markdown(f"**Explanation:** {explanation}")
        except Exception as e:
            st.session_state.messages.append({"role": "assistant", "content": f"Error generating answer: {e}"})
            with st.chat_message("assistant"):
                st.error(f"Error generating answer: {e}")
if st.button("Clear chat"):
    st.session_state.messages = []
    st.session_state.messages.append(
        {"role": "assistant", "content": "Chat has been cleared. What graphic and insights do you need now?"})
    st.rerun()