# RAG-SA / rag_hf.py
# Hugging Face Space by javiervzpucp (commit 842db18, "Update rag_hf.py").
import datetime
import html
import os
import pickle
import time

import numpy as np
import rdflib
import requests
import streamlit as st
import torch
from dotenv import load_dotenv
from rdflib import Graph as RDFGraph, Namespace
from sentence_transformers import SentenceTransformer
# === CONFIGURATION ===
# Populate the process environment from a local .env file first, so the
# os.getenv lookup below can see HF_API_TOKEN.
load_dotenv()

# Hosted text-generation model that writes the final answer.
ENDPOINT_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
HF_API_TOKEN = os.getenv("HF_API_TOKEN")  # None when the variable is unset

# Sentence-embedding model used to encode user questions for retrieval.
EMBEDDING_MODEL = "intfloat/multilingual-e5-base"
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Namespace of the language nodes in the RDF knowledge graphs
# (same IRI as the `ex:` prefix used for language lookups).
EX = Namespace("http://example.org/lang/")
# === STREAMLIT UI CONFIG ===
# Page-level settings; this call runs before any other Streamlit element is
# rendered by the script.
st.set_page_config(
    page_title="Language Atlas: South American Indigenous Languages",
    page_icon="🌍",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        # Markdown shown in the app menu's "About" entry. The department name
        # was previously mojibake ("AcadΓ©mico") from a bad UTF-8 decode.
        "About": "## AI-powered analysis of endangered indigenous languages\n"
        "Developed by Departamento Académico de Humanidades",
    },
)
# === CUSTOM CSS ===
# Stylesheet for the raw-HTML snippets the app renders with
# unsafe_allow_html (cards, badges, header). Injected once at import time.
_CUSTOM_CSS = """
<style>
.header {
color: #2c3e50;
border-bottom: 2px solid #4f46e5;
padding-bottom: 0.5rem;
margin-bottom: 1.5rem;
}
.feature-card {
background-color: #f8fafc;
border-radius: 8px;
padding: 1rem;
margin: 0.5rem 0;
border-left: 3px solid #4f46e5;
}
.response-card {
background-color: #fdfdfd;
color: #1f2937;
border-radius: 8px;
padding: 1.5rem;
box-shadow: 0 2px 6px rgba(0,0,0,0.08);
margin: 1rem 0;
font-size: 1rem;
line-height: 1.5;
}
.language-card {
background-color: #f9fafb;
border-radius: 8px;
padding: 1rem;
margin: 0.5rem 0;
border: 1px solid #e5e7eb;
}
.sidebar-section {
margin-bottom: 1.5rem;
}
.sidebar-title {
font-weight: 600;
color: #4f46e5;
}
.suggested-question {
padding: 0.5rem;
margin: 0.25rem 0;
border-radius: 4px;
cursor: pointer;
transition: all 0.2s;
}
.suggested-question:hover {
background-color: #f1f5f9;
}
.metric-badge {
display: inline-block;
background-color: #e8f4fc;
padding: 0.25rem 0.5rem;
border-radius: 4px;
font-size: 0.85rem;
margin-right: 0.5rem;
}
.tech-badge {
background-color: #ecfdf5;
color: #065f46;
padding: 0.25rem 0.5rem;
border-radius: 4px;
font-size: 0.75rem;
font-weight: 500;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# === CORE FUNCTIONS ===
# One entry per retrieval method:
# (UI label, artifact file suffix, Turtle graph file, embedding matrix file).
_METHOD_SOURCES = (
    ("InfoMatch", "_hybrid", "grafo_ttl_hibrido.ttl", "embed_matrix_hybrid.npy"),
    ("LinkGraph", "_hybrid_graphsage", "grafo_ttl_hibrido_graphsage.ttl", "embed_matrix_hybrid_graphsage.npy"),
)


def _load_pickle(path):
    """Unpickle and return the object stored at *path*.

    SECURITY: pickle can execute arbitrary code on load; this is only
    acceptable because these files are artifacts shipped with the app.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


def _load_method(suffix, ttl_path, matrix_path):
    """Load one method's artifacts as ``(matrix, id_map, G, rdf)``."""
    id_map = _load_pickle(f"id_map{suffix}.pkl")
    graph = _load_pickle(f"grafo_embed{suffix}.pickle")
    matrix = np.load(matrix_path)
    rdf = RDFGraph()
    rdf.parse(ttl_path, format="ttl")
    return matrix, id_map, graph, rdf


@st.cache_resource(show_spinner="Loading AI models and knowledge graphs...")
def load_all_components():
    """Load the query embedder and every retrieval method's artifacts.

    Cached by Streamlit as a shared resource, so the work happens once per
    server process rather than on every rerun.

    Returns:
        ``(methods, embedder)`` where *methods* maps each method label to a
        ``(matrix, id_map, G, rdf)`` tuple, in ``_METHOD_SOURCES`` order.
    """
    embedder = SentenceTransformer(EMBEDDING_MODEL, device=DEVICE)
    methods = {
        label: _load_method(suffix, ttl_path, matrix_path)
        for label, suffix, ttl_path, matrix_path in _METHOD_SOURCES
    }
    return methods, embedder
def get_top_k(matrix, id_map, query, k, embedder):
    """Return the ids of the *k* rows of *matrix* most cosine-similar to *query*.

    Args:
        matrix: 2-D array of node embeddings, one row per node.
        id_map: indexable by row number; ``id_map[i]`` is the id of row ``i``.
        query: natural-language question. It is encoded with a ``"query: "``
            prefix (the convention of the E5 embedding model configured above).
        k: number of ids to return. ``k <= 0`` yields an empty list.
        embedder: object exposing a SentenceTransformer-style ``encode``.

    Returns:
        Up to *k* ids, most similar first.
    """
    if k <= 0:
        # np.argsort(...)[-0:] is the *whole* array (and a negative k drops
        # only the first |k| rows), so without this guard a non-positive k
        # would return (nearly) every node instead of none.
        return []
    vec = embedder.encode(f"query: {query}", convert_to_tensor=True, device=DEVICE)
    vec = vec.cpu().numpy().astype("float32")
    # Cosine similarity of every row against the query; the epsilon keeps an
    # all-zero row (or query vector) from dividing by zero.
    sims = np.dot(matrix, vec) / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(vec) + 1e-10)
    top_k_idx = np.argsort(sims)[-k:][::-1]
    return [id_map[i] for i in top_k_idx]
def get_context(G, lang_id):
    """Render a language node's text attributes as a Markdown snippet.

    The first section is always the node's ``label`` (falling back to
    *lang_id* when the node or label is missing). The Wikipedia summary,
    Wikidata description and countries follow, each only when the node has a
    truthy value for it. Sections are separated by a blank line.
    """
    attrs = G.nodes.get(lang_id, {})
    # (node attribute, heading) pairs, in display order.
    optional_fields = (
        ("wikipedia_summary", "Wikipedia"),
        ("wikidata_description", "Wikidata"),
        ("wikidata_countries", "Countries"),
    )
    sections = [f"**Language:** {attrs.get('label', lang_id)}"]
    sections += [
        f"**{heading}:** {attrs[key]}"
        for key, heading in optional_fields
        if attrs.get(key)
    ]
    return "\n\n".join(sections)
def query_rdf(rdf, lang_id):
    """Return the outgoing RDF edges of a language node as string pairs.

    The subject is ``<http://example.org/lang/{lang_id}>`` (built from the
    module-level ``EX`` namespace). Each property is shortened to its local
    name, the text after the last ``/`` or ``#``. Each value is stringified.

    Args:
        rdf: an rdflib ``Graph`` to read from.
        lang_id: local identifier of the language node.

    Returns:
        A list of ``(property_local_name, value)`` pairs. If the graph lookup
        raises, returns ``[("error", message)]`` instead (best-effort, as
        callers display and prompt with whatever comes back).
    """
    # Look the subject up directly instead of formatting SPARQL text: an id
    # with characters that are illegal in a prefixed name (spaces, parentheses,
    # a trailing dot, ...) made the old `ex:{lang_id}` query unparseable.
    # str() matters: rdflib's Namespace ignores non-str keys.
    subject = EX[str(lang_id)]
    try:
        return [
            (str(prop).rsplit("/", 1)[-1].rsplit("#", 1)[-1], str(value))
            for prop, value in rdf.predicate_objects(subject)
        ]
    except Exception as exc:  # deliberate best-effort boundary
        return [("error", str(exc))]
# Generation options sent with every request. return_full_text=False asks the
# endpoint to return only the completion rather than echoing the prompt back.
# NOTE(review): max_new_tokens is sized for the prompt's 100-word limit;
# confirm the deployed endpoint accepts this value.
_LLM_PARAMETERS = {"return_full_text": False, "max_new_tokens": 250}


def _build_prompt(context, rdf_facts, user_question):
    """Assemble the Mistral-instruct prompt from retrieved context and facts."""
    context_text = "\n".join(context)
    facts_text = "\n".join(rdf_facts)
    return f"""<s>[INST]
You are an expert in South American indigenous languages.
Use strictly and only the information below to answer the user question in **English**.
- Do not infer or assume facts that are not explicitly stated.
- If the answer is unknown or insufficient, say "I cannot answer with the available data."
- Limit your answer to 100 words.
### CONTEXT:
{context_text}
### RDF RELATIONS:
{facts_text}
### QUESTION:
{user_question}
Answer:
[/INST]"""


def _query_llm(prompt):
    """POST *prompt* to ENDPOINT_URL and return the generated text.

    Any response body that is not a non-empty list whose first item is a dict
    with ``generated_text`` (e.g. an error object) is returned stringified.

    Raises:
        requests.RequestException: on transport errors.
        ValueError: when the response body is not valid JSON.
    """
    headers = {"Content-Type": "application/json"}
    # Only authenticate when a token is configured; otherwise the old code
    # sent the literal header "Bearer None".
    if HF_API_TOKEN:
        headers["Authorization"] = f"Bearer {HF_API_TOKEN}"
    res = requests.post(
        ENDPOINT_URL,
        headers=headers,
        json={"inputs": prompt, "parameters": _LLM_PARAMETERS},
        timeout=60,
    )
    out = res.json()
    first = out[0] if isinstance(out, list) and out else None
    if isinstance(first, dict) and "generated_text" in first:
        # Stripping the prompt is a fallback in case the endpoint ignores
        # return_full_text and echoes it anyway.
        return str(first["generated_text"]).replace(prompt.strip(), "").strip()
    return str(out)


def generate_response(matrix, id_map, G, rdf, user_question, k, embedder):
    """Answer *user_question* with retrieval-augmented generation.

    Retrieves the *k* nearest language nodes, collects their text context and
    RDF facts, and asks the hosted LLM to answer using only that material.

    Returns:
        ``(answer, ids, context, rdf_facts)``. *answer* is the model text, or
        an error message / stringified response body when generation fails.
    """
    ids = get_top_k(matrix, id_map, user_question, k, embedder)
    context = [get_context(G, i) for i in ids]
    rdf_facts = [f"{p}: {v}" for i in ids for p, v in query_rdf(rdf, i)]
    prompt = _build_prompt(context, rdf_facts, user_question)
    try:
        answer = _query_llm(prompt)
    except (requests.RequestException, ValueError) as exc:
        # Network failures and non-JSON bodies (e.g. an HTML 503 page) are
        # reported in the UI instead of crashing the app.
        answer = str(exc)
    return answer, ids, context, rdf_facts
# === MAIN APP ===
# Caption shown under each retrieval method's column header.
_METHOD_CAPTIONS = {
    "InfoMatch": "Node2Vec embeddings combining text and graph structure",
    "LinkGraph": "GraphSAGE embeddings capturing network patterns",
}

# Example questions offered as one-click buttons in the sidebar.
_EXAMPLE_QUESTIONS = [
    "What languages are endangered in Brazil?",
    "What languages are spoken in Perú?",
    "Which languages are related to Quechua?",
    "Where is Mapudungun spoken?",
]


def _set_query(question):
    """Button callback: put *question* into the keyed query text box.

    Runs as an on_click callback, i.e. before the next script run creates the
    text_input, which is when Streamlit allows its session state to be set.
    """
    st.session_state.query = question


def _render_sidebar():
    """Draw the sidebar and return the user's analysis settings.

    Returns:
        ``(k, show_ctx, show_rdf)``: how many languages to retrieve, and the
        two "show details" toggles.
    """
    with st.sidebar:
        st.markdown("### 📚 Pontificia Universidad Católica del Perú")
        # NOTE(review): "[email protected]" looks like a web-scraper email
        # obfuscation artifact; restore the real contact address.
        st.markdown("""
- <span class="tech-badge">Departamento de Humanidades</span>
- <span class="tech-badge">[email protected]</span>
- <span class="tech-badge">Suggestions? Contact us</span>
""", unsafe_allow_html=True)
        st.markdown("---")
        st.markdown("### 🚀 Quick Start")
        st.markdown("""
1. **Type a question** in the input box
2. **Click 'Analyze'** to compare methods
3. **Explore results** with expandable details
""")
        st.markdown("---")
        st.markdown("### 🔍 Example Queries")
        # Real buttons: the previous `if st.markdown(...)` was always truthy,
        # so every rerun overwrote the query with the last example.
        for idx, question in enumerate(_EXAMPLE_QUESTIONS):
            st.button(
                question,
                key=f"example_question_{idx}",
                on_click=_set_query,
                args=(question,),
                use_container_width=True,
            )
        st.markdown("---")
        st.markdown("### ⚙️ Technical Details")
        st.markdown("""
- <span class="tech-badge">Embeddings</span> Node2Vec vs. GraphSAGE
- <span class="tech-badge">Language Model</span> Mistral-7B-Instruct
- <span class="tech-badge">Knowledge Graph</span> RDF-based integration
""", unsafe_allow_html=True)
        st.markdown("---")
        st.markdown("### 📂 Data Sources")
        st.markdown("""
- **Glottolog** (Language classification)
- **Wikipedia** (Textual summaries)
- **Wikidata** (Structured facts)
""")
        st.markdown("---")
        st.markdown("### 📊 Analysis Parameters")
        k = st.slider("Number of languages to analyze", 1, 10, 3)
        st.markdown("---")
        st.markdown("### 🔧 Advanced Options")
        show_ctx = st.checkbox("Show context information", False)
        show_rdf = st.checkbox("Show structured facts", False)
    return k, show_ctx, show_rdf


def _render_method_result(label, method, query, k, embedder, show_ctx, show_rdf):
    """Run one retrieval method on *query* and render its answer and details.

    *method* is the ``(matrix, id_map, G, rdf)`` tuple from load_all_components.
    """
    st.markdown(f"#### {label} Method")
    caption = _METHOD_CAPTIONS.get(label)
    if caption:
        st.caption(caption)
    start = time.perf_counter()  # monotonic, unlike datetime.now()
    response, lang_ids, context, rdf_data = generate_response(*method, query, k, embedder)
    duration = time.perf_counter() - start
    # The answer comes from an LLM prompted with user text and scraped web
    # content, so escape it before embedding it in raw HTML.
    st.markdown(f"""
<div class="response-card">
{html.escape(response)}
<div style="margin-top: 1rem;">
<span class="metric-badge">⏱️ {duration:.2f}s</span>
<span class="metric-badge">🌐 {len(lang_ids)} languages</span>
</div>
</div>
""", unsafe_allow_html=True)
    if show_ctx:
        with st.expander(f"📖 Context from {len(lang_ids)} languages"):
            for ctx in context:
                st.markdown(f"<div class='language-card'>{html.escape(ctx)}</div>", unsafe_allow_html=True)
    if show_rdf:
        with st.expander("🔗 Structured facts (RDF)"):
            st.code("\n".join(rdf_data))


def main():
    """Render the Language Atlas page.

    The page has a header and overview, a sidebar with settings, and a
    question box. On "Analyze", each retrieval method's answer is shown side
    by side. The footer is always rendered.
    """
    methods, embedder = load_all_components()
    st.markdown("""
<div class="header">
<h1>🌍 Language Atlas: South American Indigenous Languages</h1>
</div>
""", unsafe_allow_html=True)
    with st.expander("📌 **Overview**", expanded=True):
        st.markdown("""
This app provides **AI-powered analysis** of endangered indigenous languages in South America,
integrating knowledge graphs from **Glottolog, Wikipedia, and Wikidata**.
\n\n*This is version 1 and currently English-only. Spanish version coming soon!*
""")
    k, show_ctx, show_rdf = _render_sidebar()
    st.markdown("### 📝 Ask About Indigenous Languages")
    # Keyed so the example-question buttons can fill it through session state.
    query = st.text_input(
        "Enter your question:",
        key="query",
        label_visibility="collapsed",
        placeholder="e.g. What languages are spoken in Peru?",
    )
    if st.button("Analyze", type="primary", use_container_width=True):
        question = query.strip()
        if not question:
            # No early return: the footer below must still be rendered.
            st.warning("Please enter a question")
        else:
            columns = st.columns(len(methods))
            for col, (label, method) in zip(columns, methods.items()):
                with col:
                    _render_method_result(label, method, question, k, embedder, show_ctx, show_rdf)
    st.markdown("---")
    st.markdown("""
<div style="font-size: 0.8rem; color: #64748b; text-align: center;">
<b>📌 Note:</b> This tool is designed for researchers, linguists, and cultural preservationists.
For best results, use specific questions about languages, families, or regions.
</div>
""", unsafe_allow_html=True)


if __name__ == "__main__":
    main()