|
import os |
|
import streamlit as st |
|
import faiss |
|
import pickle |
|
from datasets import load_dataset |
|
from sentence_transformers import SentenceTransformer |
|
from groq import Groq |
|
|
|
|
|
# Dataset identifier handed to `load_dataset` when the index is built.
DATASET_NAME = "neural-bridge/rag-dataset-1200"

# Sentence-transformer checkpoint used to embed both contexts and questions.
MODEL_NAME = "all-MiniLM-L6-v2"

# Files where the pickled FAISS index and the list of contexts are persisted.
INDEX_FILE = "faiss_index.pkl"
DOCS_FILE = "contexts.pkl"

# Groq API client; the key is read from the MY_KEY environment variable
# (None is passed through when the variable is unset).
client = Groq(api_key=os.getenv("MY_KEY"))
|
|
|
|
|
# Configure the browser-tab title and wide layout, then render the page heading.
st.set_page_config(page_title="RAG App", layout="wide")
st.title("🧠 Retrieval-Augmented Generation (RAG) with Groq")
|
|
|
|
|
@st.cache_resource
def setup_database():
    """Build the FAISS index from the dataset and persist it to disk.

    Loads the ``train`` split of ``DATASET_NAME``, embeds every ``context``
    field with the ``MODEL_NAME`` sentence-transformer, adds the vectors to a
    ``faiss.IndexFlatL2`` index, and pickles the index and the context list
    to ``INDEX_FILE`` and ``DOCS_FILE``. Progress is reported through
    Streamlit widgets after each stage.

    Returns:
        A ``(faiss_index, contexts)`` tuple: the populated index and the list
        of context strings, in the same order as the vectors were added.
    """
    st.info("Setting up vector database...")
    bar = st.progress(0)

    # Stage 1: collect the context passage of every dataset record.
    records = load_dataset(DATASET_NAME, split="train")
    contexts = []
    for record in records:
        contexts.append(record["context"])
    bar.progress(25)

    # Stage 2: embed all passages in one batch.
    model = SentenceTransformer(MODEL_NAME)
    vectors = model.encode(contexts, show_progress_bar=True)
    bar.progress(50)

    # Stage 3: size the index from the first vector and add every embedding.
    vector_size = vectors[0].shape[0]
    index = faiss.IndexFlatL2(vector_size)
    index.add(vectors)
    bar.progress(75)

    # Stage 4: persist the index first, then the contexts.
    for path, payload in ((INDEX_FILE, index), (DOCS_FILE, contexts)):
        with open(path, "wb") as out:
            pickle.dump(payload, out)

    bar.progress(100)
    st.success("Database setup complete!")
    return index, contexts
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def _load_saved_database():
    """Unpickle the persisted FAISS index and context list.

    Wrapped in ``st.cache_resource`` so the two files are read and unpickled
    once and the same in-memory objects are reused afterwards, instead of
    being reloaded each time the script body executes.

    Returns:
        An ``(index, contexts)`` tuple read from ``INDEX_FILE`` and
        ``DOCS_FILE`` respectively.
    """
    # NOTE: pickle.load can execute arbitrary code. These files are expected
    # to be the ones written by setup_database(); never point INDEX_FILE or
    # DOCS_FILE at data from an untrusted source.
    with open(INDEX_FILE, "rb") as f:
        index = pickle.load(f)
    with open(DOCS_FILE, "rb") as f:
        contexts = pickle.load(f)
    return index, contexts


# Reuse the persisted database when both files are present; otherwise build it.
if os.path.exists(INDEX_FILE) and os.path.exists(DOCS_FILE):
    faiss_index, all_contexts = _load_saved_database()
    st.info("Loaded existing database.")
else:
    faiss_index, all_contexts = setup_database()
|
|
|
|
|
# Example prompts; the first one pre-fills the question box below.
sample_questions = [
    "What is the purpose of the RAG dataset?",
    "How does Falcon RefinedWeb contribute to this dataset?",
    "What are the benefits of using retrieval-augmented generation?",
    "Explain the structure of the RAG-1200 dataset.",
]


@st.cache_resource(show_spinner=False)
def _load_embedder():
    """Return the ``MODEL_NAME`` sentence-transformer, constructed only once.

    Cached so that pressing "Ask" does not rebuild the model for every
    question.
    """
    return SentenceTransformer(MODEL_NAME)


st.subheader("Ask a question based on the dataset:")
question = st.text_input("Enter your question:", value=sample_questions[0])

if st.button("Ask"):
    if not question.strip():
        st.warning("Please enter a question.")
    else:
        with st.spinner("Retrieving and generating answer..."):
            # Embed the question and look up its single nearest stored context.
            query_embedding = _load_embedder().encode([question])
            _distances, neighbours = faiss_index.search(query_embedding, k=1)
            best_match = int(neighbours[0][0])

            if best_match < 0:
                # FAISS uses -1 as the label when it has no neighbour to
                # return; indexing the list with it would silently pick the
                # last context (or raise on an empty list).
                st.error("No matching context found in the vector database.")
            else:
                context = all_contexts[best_match]
                prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"

                # NOTE(review): confirm this model id is still served by Groq.
                response = client.chat.completions.create(
                    messages=[{"role": "user", "content": prompt}],
                    model="llama3-70b-8192",
                )
                answer = response.choices[0].message.content

                st.success("Answer:")
                st.markdown(answer)

                # NOTE(review): the label's first character looks like a
                # mis-decoded emoji; confirm the intended glyph.
                with st.expander("π Retrieved Context"):
                    st.markdown(context)
|
|