import os
import re
import time
from datetime import datetime

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from supabase import Client, create_client
from transformers import pipeline

# ---------------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------------


def extract_country_and_dates(prompt, countries, latest_year=None):
    """Extract a country name and a year range from a free-text prompt.

    Args:
        prompt: User's natural-language query.
        countries: Iterable of known country names to match against.
        latest_year: Optional upper-bound year. Accepted for compatibility with
            callers that pass the most recent year in the data; currently not
            used to alter extraction (NOTE(review): reserved for validating the
            extracted years against the data range).

    Returns:
        Tuple ``(country, start_date, end_date)`` where each element may be
        ``None``; the dates are 4-digit year strings.
    """
    country = None
    start_date = None
    end_date = None

    # Case-insensitive whole-word match against the known country list.
    for c in countries:
        if re.search(r'\b' + re.escape(c) + r'\b', prompt, re.IGNORECASE):
            country = c
            break

    # Year ranges with several separators (-, to, until, from ... to, between ... and).
    date_ranges = re.findall(
        r'(\d{4})\s*(?:-|to|until|from.*?to|between.*?and)\s*(\d{4})',
        prompt, re.IGNORECASE)
    if date_ranges:
        start_date, end_date = date_ranges[0]
    else:
        # Fall back to a single standalone year: treat it as a one-year range.
        single_years = re.findall(r'\b(\d{4})\b', prompt)
        if single_years:
            start_date = end_date = single_years[0]

    return country, start_date, end_date


def generate_plotly_graph(df, user_query, country=None, start_date=None, end_date=None):
    """Build a line chart of the numeric columns of *df* over 'year'.

    Filters rows by *country* (matched case-insensitively against the 'geo'
    column) and by the optional [start_date, end_date] year range, then plots
    every numeric column except 'year' as a lines+markers trace.

    Returns:
        A plotly ``Figure``, or ``None`` when the frame has no 'year' column
        or no numeric data to plot.
    """
    relevant_data = df.copy()

    if 'geo' in relevant_data.columns and country:
        relevant_data = relevant_data[relevant_data['geo'].str.lower() == country.lower()]

    if 'year' in relevant_data.columns:
        # BUGFIX: the previous code assigned a dropna()'d (shorter) Series back
        # onto the frame; pandas realigns on the index, reintroducing NaNs and
        # leaving the column as float. Drop the invalid rows from the frame
        # itself, then cast.
        relevant_data['year'] = pd.to_numeric(relevant_data['year'], errors='coerce')
        relevant_data = relevant_data.dropna(subset=['year'])
        relevant_data['year'] = relevant_data['year'].astype(int)

        if start_date and end_date:
            relevant_data = relevant_data[
                (relevant_data['year'] >= int(start_date)) & (relevant_data['year'] <= int(end_date))
            ]
        elif start_date:
            relevant_data = relevant_data[relevant_data['year'] >= int(start_date)]
        elif end_date:
            relevant_data = relevant_data[relevant_data['year'] <= int(end_date)]

    numeric_cols = relevant_data.select_dtypes(include=['number']).columns.tolist()
    if 'year' in relevant_data.columns and numeric_cols:
        fig = go.Figure()
        for col in numeric_cols:
            if col != 'year':
                fig.add_trace(go.Scatter(x=relevant_data['year'], y=relevant_data[col],
                                         mode='lines+markers', name=col))

        title = f"Data for {country if country else 'All Regions'}"
        if start_date and end_date:
            title += f" ({start_date}-{end_date})"
        elif start_date:
            title += f" (from {start_date})"
        elif end_date:
            title += f" (up to {end_date})"

        # Add the title and axis labels.
        fig.update_layout(
            title=title,
            xaxis_title="Year",
            yaxis_title="Value"  # TODO: infer more descriptive axis labels from the columns
        )
        return fig
    else:
        return None


# ---------------------------------------------------------------------------------
# Supabase connection setup
# ---------------------------------------------------------------------------------
load_dotenv()
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)


# Load every row of a Supabase table into a DataFrame.
def load_data(table):
    """Fetch all rows of *table* from Supabase.

    Returns an empty DataFrame (and surfaces the problem through the
    Streamlit UI) on any error, so callers never have to handle exceptions.
    """
    try:
        if supabase:
            response = supabase.from_(table).select("*").execute()
            if hasattr(response, 'data'):
                return pd.DataFrame(response.data)
            elif hasattr(response, '_error'):
                st.error(f"Error fetching data: {response._error}")
                return pd.DataFrame()
            else:
                st.info("Response object does not have 'data' or known error attributes. Check the logs.")
                return pd.DataFrame()
        else:
            st.error("Supabase client not initialized. Check environment variables.")
            return pd.DataFrame()
    except Exception as e:
        st.error(f"An error occurred during data loading: {e}")
        return pd.DataFrame()


# ---------------------------------------------------------------------------------
# Initial data load
# ---------------------------------------------------------------------------------
labor_data = load_data("labor")
fertility_data = load_data("fertility")

# ---------------------------------------------------------------------------------
# RAG model initialization
# ---------------------------------------------------------------------------------
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
llm_pipeline = pipeline("text-generation", model="google/gemma-3-1b-it", token=os.getenv("HF_TOKEN"))

# ---------------------------------------------------------------------------------
# Embedding and metadata generation (in memory)
# ---------------------------------------------------------------------------------
embeddings_list = []
contents_list = []
metadatas_list = []
ids_list = []

for index, row in labor_data.iterrows():
    doc = f"Country: {row['geo']}, Year: {row['year']}, Employment Rate: {row['labour_force'] if 'labour_force' in row else 'N/A'}"
    embeddings_list.append(embedding_model.encode(doc))
    contents_list.append(doc)
    metadatas_list.append({'country': row['geo'], 'year': str(row['year']), 'source': 'labor'})
    ids_list.append(f"labor_{index}")

for index, row in fertility_data.iterrows():
    doc = f"Country: {row['geo']}, Year: {row['year']}, Fertility Rate: {row['fertility_rate'] if 'fertility_rate' in row else 'N/A'}"
    embeddings_list.append(embedding_model.encode(doc))
    contents_list.append(doc)
    metadatas_list.append({'country': row['geo'], 'year': str(row['year']), 'source': 'fertility'})
    ids_list.append(f"fertility_{index}")

embeddings_array = np.array(embeddings_list)
# ---------------------------------------------------------------------------------
# Retrieve relevant documents (in memory)
# ---------------------------------------------------------------------------------


def retrieve_relevant_documents_in_memory(query_embedding, stored_embeddings, contents, top_k=3):
    """Return the *top_k* documents most cosine-similar to the query.

    Args:
        query_embedding: 1-D embedding vector of the user query.
        stored_embeddings: 2-D array of document embeddings (one row per doc).
        contents: Document texts, aligned row-for-row with *stored_embeddings*.
        top_k: Maximum number of documents to return.

    Returns:
        The texts of the best-matching documents, most similar first. An
        empty list when the corpus is empty (previously this raised inside
        sklearn's cosine_similarity).
    """
    stored = np.asarray(stored_embeddings, dtype=float)
    # Robustness fix: an empty corpus (e.g. both source tables failed to load)
    # must not crash the retrieval step.
    if stored.size == 0 or not contents:
        return []

    query = np.asarray(query_embedding, dtype=float)
    # Cosine similarity computed directly with numpy (same result as
    # sklearn's cosine_similarity, without the per-call overhead); the
    # np.where guard keeps zero-norm vectors at similarity 0 instead of NaN.
    norms = np.linalg.norm(stored, axis=1) * np.linalg.norm(query)
    similarities = (stored @ query) / np.where(norms == 0, 1e-12, norms)

    # Highest similarity first; slicing tolerates top_k > len(contents).
    ranked = np.argsort(similarities)[::-1]
    return [contents[i] for i in ranked[:top_k]]


# ---------------------------------------------------------------------------------
# Explanation generation using RAG
# ---------------------------------------------------------------------------------


def generate_rag_explanation(user_query, stored_embeddings, contents):
    """Answer *user_query* with the LLM, grounded in the retrieved documents.

    Embeds the query, retrieves the most similar documents from the in-memory
    store, and feeds them as context to the text-generation pipeline.
    Returns the generated text, or a fixed fallback message when nothing
    relevant was retrieved.
    """
    query_embedding = embedding_model.encode(user_query)
    relevant_docs = retrieve_relevant_documents_in_memory(query_embedding, stored_embeddings, contents)

    if relevant_docs:
        context = "\n".join(relevant_docs)
        augmented_prompt = f"Based on the following information:\n\n{context}\n\nAnswer the question related to: {user_query}"
        output = llm_pipeline(augmented_prompt, max_length=250, num_return_sequences=1)
        return output[0]['generated_text']
    else:
        return "No relevant information found to answer your query."
# ---------------------------------------------------------------------------------
# Build the country list automatically
# ---------------------------------------------------------------------------------
available_countries_labor = labor_data['geo'].unique().tolist() if 'geo' in labor_data.columns else []
available_countries_fertility = fertility_data['geo'].unique().tolist() if 'geo' in fertility_data.columns else []
all_countries = list(set(available_countries_labor + available_countries_fertility))

# ---------------------------------------------------------------------------------
# Streamlit app configuration
# ---------------------------------------------------------------------------------
st.set_page_config(page_title="GraphGen", page_icon="🇪🇺")

st.title("_Europe GraphGen_  :blue[Graph generator] :flag-eu:")
st.caption("Mapping Europe's data with insights")

if "messages" not in st.session_state:
    st.session_state.messages = []
    st.session_state.messages.append({"role": "assistant", "content": "What graphic and insights do you need?"})

# Replay the conversation so far.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

prompt = st.chat_input("Type your message here...", key="chat_input_bottom")

if prompt:
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.spinner('Generating answer...'):
        try:
            # BUGFIX: the previous code computed a `latest_year` with the
            # never-imported `datetime` module (NameError) and passed it as a
            # third positional argument to extract_country_and_dates(), which
            # accepts only two (TypeError). The value was not used by the
            # extractor, so the call is reduced to its real signature and the
            # dead computation removed.
            country, start_date, end_date = extract_country_and_dates(prompt, all_countries)

            graph_displayed = False

            def _show_graph(fig, label):
                """Record the chat message and render *fig*; True when drawn."""
                if fig is None:
                    return False
                st.session_state.messages.append(
                    {"role": "assistant", "content": f"Here is the {label} data graphic:"})
                with st.chat_message("assistant"):
                    st.plotly_chart(fig)
                return True

            # Route the prompt by keyword intent.
            if re.search(r'\b(labor|employment|job|workforce)\b', prompt, re.IGNORECASE):
                # Labor-data chart requested.
                graph_displayed = _show_graph(
                    generate_plotly_graph(labor_data, prompt, country, start_date, end_date), "labor")
            elif re.search(r'\b(fertility|birth|population growth)\b', prompt, re.IGNORECASE):
                # Fertility-data chart requested.
                graph_displayed = _show_graph(
                    generate_plotly_graph(fertility_data, prompt, country, start_date, end_date), "fertility")
            else:
                # No clear intent: try the labor chart first, then fertility.
                graph_displayed = _show_graph(
                    generate_plotly_graph(labor_data, prompt, country, start_date, end_date), "labor")
                if not graph_displayed:
                    graph_displayed = _show_graph(
                        generate_plotly_graph(fertility_data, prompt, country, start_date, end_date), "fertility")

            # Generate the accompanying explanation with RAG.
            explanation = generate_rag_explanation(prompt, embeddings_array, contents_list)
            st.session_state.messages.append({"role": "assistant", "content": f"Explanation: {explanation}"})
            with st.chat_message("assistant"):
                st.markdown(f"**Explanation:** {explanation}")
        except Exception as e:
            # Surface the failure both in the transcript and as an error box.
            st.session_state.messages.append({"role": "assistant", "content": f"Error generating answer: {e}"})
            with st.chat_message("assistant"):
                st.error(f"Error generating answer: {e}")

if st.button("Clear chat"):
    st.session_state.messages = []
    st.session_state.messages.append(
        {"role": "assistant", "content": "Chat has been cleared. What graphic and insights do you need now?"})
    st.rerun()