# graph_generator / gemma3test.py
# Origin: Hugging Face Space upload by juancamval (commit ab15ee1, ~13 kB).
import os
import re
import time
from datetime import datetime

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from supabase import create_client, Client
from transformers import pipeline
# ---------------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------------
def extract_country_and_dates(prompt, countries, latest_year=None):
    """Extract a country name and a year range from a free-text prompt.

    Args:
        prompt: The user's natural-language query.
        countries: Iterable of known country names to match against.
        latest_year: Optional upper bound (int or numeric string); extracted
            years greater than it are clamped down to ``latest_year``.

    Returns:
        Tuple ``(country, start_date, end_date)`` where each element is
        ``None`` when not found; years are returned as strings.
    """
    country = None
    start_date = None
    end_date = None

    # Case-insensitive whole-word match against the known country list.
    for c in countries:
        if re.search(r'\b' + re.escape(c) + r'\b', prompt, re.IGNORECASE):
            country = c
            break

    # Year ranges: "2000-2010", "2000 to 2010", "from 2000 to 2010",
    # "2000 until 2010", "between 2000 and 2010".  The "from"/"between"
    # prefixes sit BEFORE the first year, so only the infix word matters
    # (the previous pattern had no bare "and" branch and missed
    # "between ... and ..." entirely).
    date_ranges = re.findall(
        r'(\d{4})\s*(?:-|to|until|and)\s*(\d{4})', prompt, re.IGNORECASE)
    if date_ranges:
        start_date, end_date = date_ranges[0]
    else:
        # Fall back to a single standalone 4-digit year.
        single_years = re.findall(r'\b(\d{4})\b', prompt)
        if single_years:
            start_date = end_date = single_years[0]

    # Clamp future years to the most recent year available in the data.
    if latest_year is not None:
        if start_date and int(start_date) > int(latest_year):
            start_date = str(latest_year)
        if end_date and int(end_date) > int(latest_year):
            end_date = str(latest_year)

    return country, start_date, end_date
def generate_plotly_graph(df, user_query, country=None, start_date=None, end_date=None):
    """Build a Plotly line chart of all numeric columns over 'year'.

    Args:
        df: Source DataFrame; expected (but not required) to have 'geo' and
            'year' columns.
        user_query: Original user prompt (currently unused; kept for API
            compatibility).
        country: Optional country to filter the 'geo' column on
            (case-insensitive).
        start_date / end_date: Optional inclusive year bounds as strings/ints.

    Returns:
        A ``go.Figure``, or ``None`` when there is no 'year' column or no
        numeric data to plot.
    """
    relevant_data = df.copy()

    # Case-insensitive country filter when both a country and 'geo' exist.
    if 'geo' in relevant_data.columns and country:
        relevant_data = relevant_data[relevant_data['geo'].str.lower() == country.lower()]

    if 'year' in relevant_data.columns:
        # Coerce 'year' to numeric and DROP unparseable rows.  (The previous
        # version assigned a dropna()'d Series back into the column, which
        # re-introduced NaN via index alignment and left the column float.)
        relevant_data['year'] = pd.to_numeric(relevant_data['year'], errors='coerce')
        relevant_data = relevant_data[relevant_data['year'].notna()].copy()
        relevant_data['year'] = relevant_data['year'].astype(int)
        # The two bounds are independent; applying them as separate filters
        # covers the range / start-only / end-only cases.
        if start_date:
            relevant_data = relevant_data[relevant_data['year'] >= int(start_date)]
        if end_date:
            relevant_data = relevant_data[relevant_data['year'] <= int(end_date)]

    numeric_cols = relevant_data.select_dtypes(include=['number']).columns.tolist()
    if 'year' not in relevant_data.columns or not numeric_cols:
        return None

    # One trace per numeric column (excluding the x-axis 'year' itself).
    fig = go.Figure()
    for col in numeric_cols:
        if col != 'year':
            fig.add_trace(go.Scatter(x=relevant_data['year'], y=relevant_data[col], mode='lines+markers', name=col))

    title = f"Data for {country if country else 'All Regions'}"
    if start_date and end_date:
        title += f" ({start_date}-{end_date})"
    elif start_date:
        title += f" (from {start_date})"
    elif end_date:
        title += f" (up to {end_date})"

    # Axis labels are generic; column names would be needed for better ones.
    fig.update_layout(
        title=title,
        xaxis_title="Year",
        yaxis_title="Value"
    )
    return fig
# ---------------------------------------------------------------------------------
# Supabase connection configuration
# ---------------------------------------------------------------------------------
# Pull SUPABASE_URL / SUPABASE_KEY from a local .env file (if present).
load_dotenv()
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
# NOTE(review): both values may be None when the environment is not
# configured; create_client is assumed to fail loudly in that case — verify
# deployment provides these variables.
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
# Helper to fetch a whole table from Supabase.
def load_data(table):
    """Fetch every row of *table* from Supabase as a pandas DataFrame.

    Any failure (uninitialized client, error response, unexpected response
    shape, raised exception) is surfaced through Streamlit and yields an
    empty DataFrame instead of propagating.
    """
    try:
        # Guard clause: no client means env vars were missing at startup.
        if not supabase:
            st.error("Supabase client not initialized. Check environment variables.")
            return pd.DataFrame()

        response = supabase.from_(table).select("*").execute()

        if hasattr(response, 'data'):
            return pd.DataFrame(response.data)
        if hasattr(response, '_error'):
            st.error(f"Error fetching data: {response._error}")
        else:
            st.info("Response object does not have 'data' or known error attributes. Check the logs.")
        return pd.DataFrame()
    except Exception as e:
        st.error(f"An error occurred during data loading: {e}")
        return pd.DataFrame()
# ---------------------------------------------------------------------------------
# Initial data load
# ---------------------------------------------------------------------------------
labor_data = load_data("labor")          # labour-force table
fertility_data = load_data("fertility")  # fertility-rate table
# ---------------------------------------------------------------------------------
# Model initialization for RAG
# ---------------------------------------------------------------------------------
# Sentence embedder used to index documents and embed queries for retrieval.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
# Instruction-tuned Gemma LLM for generating explanations; the gated repo
# requires a valid HF_TOKEN in the environment.
llm_pipeline = pipeline("text-generation", model="google/gemma-3-1b-it", token=os.getenv("HF_TOKEN"))
# ---------------------------------------------------------------------------------
# Embedding and metadata generation (in memory)
# ---------------------------------------------------------------------------------
contents_list = []
metadatas_list = []
ids_list = []

def _collect_documents(df, source, value_col, value_label):
    """Append one text document + metadata + id per row of *df*."""
    for index, row in df.iterrows():
        value = row[value_col] if value_col in row else 'N/A'
        contents_list.append(f"Country: {row['geo']}, Year: {row['year']}, {value_label}: {value}")
        metadatas_list.append({'country': row['geo'], 'year': str(row['year']), 'source': source})
        ids_list.append(f"{source}_{index}")

_collect_documents(labor_data, 'labor', 'labour_force', 'Employment Rate')
_collect_documents(fertility_data, 'fertility', 'fertility_rate', 'Fertility Rate')

# Encode all documents in ONE batched call (the previous per-row encode paid
# model-invocation overhead for every document).  Guard the empty case so the
# array stays 2-D and downstream cosine similarity does not crash when no
# data could be loaded.
if contents_list:
    embeddings_array = np.array(embedding_model.encode(contents_list))
else:
    embeddings_array = np.empty((0, 0))
# Kept for backward compatibility with code expecting a list of vectors.
embeddings_list = list(embeddings_array)
# ---------------------------------------------------------------------------------
# Retrieve the documents most relevant to a query (in memory)
# ---------------------------------------------------------------------------------
def retrieve_relevant_documents_in_memory(query_embedding, stored_embeddings, contents, top_k=3):
    """Return the *top_k* documents most cosine-similar to the query.

    Args:
        query_embedding: 1-D embedding vector of the query.
        stored_embeddings: 2-D array-like, one row per document.
        contents: Document texts, parallel to ``stored_embeddings`` rows.
        top_k: Number of documents to return.

    Returns:
        List of up to ``top_k`` document strings, best match first; empty
        when the store is empty (the previous version crashed on that case).
    """
    stored = np.asarray(stored_embeddings, dtype=float)
    if stored.size == 0 or not contents:
        return []
    query = np.asarray(query_embedding, dtype=float)

    # Plain-NumPy cosine similarity; zero-norm vectors get similarity 0,
    # matching sklearn's normalize() behaviour, by patching the denominator.
    denom = np.linalg.norm(stored, axis=1) * np.linalg.norm(query)
    denom[denom == 0] = 1.0
    similarities = stored @ query / denom

    best_indices = np.argsort(similarities, kind='stable')[::-1][:top_k]
    return [contents[i] for i in best_indices]
# ---------------------------------------------------------------------------------
# Explanation generation using RAG
# ---------------------------------------------------------------------------------
def generate_rag_explanation(user_query, stored_embeddings, contents):
    """Answer *user_query* with the LLM, grounded in the retrieved documents.

    Args:
        user_query: The user's natural-language question.
        stored_embeddings: 2-D array of document embeddings.
        contents: Document texts parallel to ``stored_embeddings``.

    Returns:
        The generated answer text, or a fallback message when retrieval
        produced no documents.
    """
    query_embedding = embedding_model.encode(user_query)
    relevant_docs = retrieve_relevant_documents_in_memory(query_embedding, stored_embeddings, contents)
    if not relevant_docs:
        return "No relevant information found to answer your query."

    context = "\n".join(relevant_docs)
    augmented_prompt = f"Based on the following information:\n\n{context}\n\nAnswer the question related to: {user_query}"
    # max_new_tokens bounds only the ANSWER length (max_length counts the
    # prompt too, so a long retrieved context could truncate or error);
    # return_full_text=False keeps the echoed prompt out of the reply shown
    # to the user.
    output = llm_pipeline(
        augmented_prompt,
        max_new_tokens=250,
        num_return_sequences=1,
        return_full_text=False,
    )
    return output[0]['generated_text']
# ---------------------------------------------------------------------------------
# Build the country list automatically from the loaded tables
# ---------------------------------------------------------------------------------
def _unique_geos(df):
    """Distinct 'geo' values of *df*, or [] when the column is absent."""
    return df['geo'].unique().tolist() if 'geo' in df.columns else []

available_countries_labor = _unique_geos(labor_data)
available_countries_fertility = _unique_geos(fertility_data)
# De-duplicate across the two tables.
all_countries = list({*available_countries_labor, *available_countries_fertility})
# ---------------------------------------------------------------------------------
# Streamlit app configuration
# ---------------------------------------------------------------------------------
st.set_page_config(page_title="GraphGen", page_icon="馃嚜馃嚭")
st.title("_Europe GraphGen_ 聽:blue[Graph generator] :flag-eu:")
st.caption("Mapping Europe's data with insights")
# Seed the chat history with a greeting on the first run only; session_state
# persists across Streamlit reruns within the same browser session.
if "messages" not in st.session_state:
    st.session_state.messages = []
    st.session_state.messages.append({"role": "assistant", "content": "What graphic and insights do you need?"})
# Re-render the full conversation on every rerun.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
prompt = st.chat_input("Type your message here...", key="chat_input_bottom")
if prompt:
    # Record and echo the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.spinner('Generating answer...'):
        try:
            # Parse the prompt for a country and a year range.
            # (The previous version computed an unused latest_year via the
            # never-imported `datetime` and passed a third argument the
            # two-parameter function did not accept, so every prompt raised
            # and landed in the except branch below.)
            country, start_date, end_date = extract_country_and_dates(prompt, all_countries)
            graph_displayed = False

            def _render_graph(df, label):
                """Plot *df* for the parsed filters; True when a chart was shown."""
                fig = generate_plotly_graph(df, prompt, country, start_date, end_date)
                if fig:
                    st.session_state.messages.append(
                        {"role": "assistant", "content": f"Here is the {label} data graphic:"})
                    with st.chat_message("assistant"):
                        st.plotly_chart(fig)
                    return True
                return False

            # Route on the user's apparent intent; without a clear intent,
            # try labor data first and fall back to fertility data.
            if re.search(r'\b(labor|employment|job|workforce)\b', prompt, re.IGNORECASE):
                graph_displayed = _render_graph(labor_data, "labor")
            elif re.search(r'\b(fertility|birth|population growth)\b', prompt, re.IGNORECASE):
                graph_displayed = _render_graph(fertility_data, "fertility")
            else:
                graph_displayed = _render_graph(labor_data, "labor") or _render_graph(fertility_data, "fertility")

            # Always append a RAG-generated textual explanation.
            explanation = generate_rag_explanation(prompt, embeddings_array, contents_list)
            st.session_state.messages.append({"role": "assistant", "content": f"Explanation: {explanation}"})
            with st.chat_message("assistant"):
                st.markdown(f"**Explanation:** {explanation}")
        except Exception as e:
            # Surface any failure in the chat instead of crashing the app.
            st.session_state.messages.append({"role": "assistant", "content": f"Error generating answer: {e}"})
            with st.chat_message("assistant"):
                st.error(f"Error generating answer: {e}")
# Reset the conversation on demand and immediately rerun the script so the
# cleared history is what gets rendered.
if st.button("Clear chat"):
    st.session_state.messages = [
        {"role": "assistant",
         "content": "Chat has been cleared. What graphic and insights do you need now?"}
    ]
    st.rerun()