import os
import gradio as gr
import pinecone
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel
from langchain_community.vectorstores import FAISS, Pinecone
from langchain.text_splitter import CharacterTextSplitter
from langchain.docstore.document import Document
from langchain.chains import RetrievalQA
from langchain_huggingface import HuggingFacePipeline
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
# Initialize Pinecone (classic pinecone-client < 3 API)
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY", "")
PINECONE_INDEX = "arolchatbot"
# Connect to Pinecone and open the index; the classic client may also need an
# environment/region passed to init(), e.g. "us-west1-gcp-free"
pinecone.init(api_key=PINECONE_API_KEY)
index = pinecone.Index(PINECONE_INDEX)
# Query embeddings: thenlper/gte-large served through HuggingFaceEmbeddings
# (sentence-transformers under the hood)
embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-large")
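# Note (assumption): gte-large produces 1024-dimensional vectors, so the
# Pinecone index "arolchatbot" must have been created with that dimension.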
# Model and Tokenizer
model_name = "Meldashti/chatbot"
base_model = AutoModelForCausalLM.from_pretrained("unsloth/Llama-3.2-3B")
tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-3B")
# Merge PEFT weights with base model
model = PeftModel.from_pretrained(base_model, model_name)
model = model.merge_and_unload()
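# merge_and_unload() folds the LoRA adapter weights into the base model, so
# inference below runs on a plain transformers model without the PEFT wrapper.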
# Simplified pipeline with minimal parameters
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=150,
)
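# max_new_tokens=150 caps the length of each generated answer; sampling
# parameters are left at the transformers pipeline defaults.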
# LLM wrapper
llm = HuggingFacePipeline(pipeline=generator)
# Wrap the Pinecone index with LangChain's Pinecone wrapper
vector_store = Pinecone(index, embeddings.embed_query, "text")
# Optional: build a local FAISS index instead of querying the hosted Pinecone
# index. `documents` must be a list of LangChain Document objects; when it is
# empty, the Pinecone store above remains the retriever backend.
documents = []
if documents:
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
    split_documents = text_splitter.split_documents(documents)
    vector_store = FAISS.from_documents(split_documents, embeddings)
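# Hypothetical example of populating `documents` for the local FAISS path
# (names and contents are illustrative only):
# documents = [Document(page_content="...", metadata={"source": "..."})]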
retriever = vector_store.as_retriever(search_kwargs={"k": 5})
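# k=5: fetch the five most similar chunks per query, trading prompt length
# for retrieval recall.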
# Retrieval QA Chain
rag_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
)
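# "stuff" concatenates the retrieved chunks into a single prompt. LangChain's
# default QA prompt ends with "Helpful Answer:", and the text-generation
# pipeline echoes the prompt, so the raw result is prompt + answer; chat()
# below strips everything up to that marker.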
# Chat function with extensive logging
def chat(message, history):
    print(f"Received message: {message}")
    try:
        response = rag_chain.invoke({"query": message})
        print(response)
        # Keep only the text after the "Helpful Answer:" marker; fall back to
        # the full result if the marker is missing.
        result = str(response["result"])
        if "Helpful Answer: " in result:
            result = result.split("Helpful Answer: ", 1)[1]
        return result
    except Exception as e:
        print(f"Error generating response: {type(e)}, {e}")
        return f"An error occurred: {str(e)}"
# Gradio interface
demo = gr.ChatInterface(chat, type="messages", autofocus=False)
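# type="messages" makes Gradio pass the chat history as role/content dicts
# (OpenAI-style) instead of (user, bot) tuples; autofocus=False leaves the
# textbox unfocused on load.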
# Launch
if __name__ == "__main__":
    demo.launch(debug=True)