import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.docstore.document import Document
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
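# Hub repo providing both the base weights and the PEFT adapter applied below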
model_name = "Meldashti/chatbot"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the fine-tuned model using PEFT
model = AutoModelForCausalLM.from_pretrained(model_name)
model = PeftModel.from_pretrained(model, model_name)
# Example knowledge base: replace these placeholder documents with real content
documents = [
    Document(page_content="Document 1 content goes here..."),
    Document(page_content="Document 2 content goes here..."),
    # Add more documents as needed
]
# Initialize the Hugging Face embeddings model
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
# Split documents into smaller chunks for better retrieval
text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=100)
split_documents = text_splitter.split_documents(documents)
# Create the FAISS vector store
vector_store = FAISS.from_documents(split_documents, embeddings)
# Define a retriever backed by the FAISS vector store
retriever = vector_store.as_retriever()
# Define the prompt template for the RAG pipeline
prompt_template = """
You are a helpful assistant. Answer the user's question using only the
retrieved context provided below.
Here is the context: {context}
Question: {question}
Answer:
"""
# Wrap the template so LangChain can fill in the retrieved context and the question
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
# Wrap the fine-tuned model in a text-generation pipeline so LangChain can call it
generation_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256)
llm = HuggingFacePipeline(pipeline=generation_pipeline)
# Set up the RAG chain: retrieve relevant chunks, then answer using the prompt above
rag_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=True,
    chain_type_kwargs={"prompt": prompt},
)
def chat(message, history):
    # RetrievalQA returns a dict when return_source_documents=True; the answer is under "result"
    result = rag_chain({"query": message})
    return str(result["result"])
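# Gradio chat UI: `chat` receives the latest user message plus the running history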
demo = gr.ChatInterface(chat, type="messages", autofocus=False)
if __name__ == "__main__":
demo.launch()