mohammadhakimi committed
Commit ff17315 · verified · 1 Parent(s): ffc55b3

Update app.py

Files changed (1):
  app.py +40 -31
app.py CHANGED
@@ -1,12 +1,12 @@
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from peft import PeftModel
-from langchain.text_splitter import CharacterTextSplitter
-from langchain.docstore.document import Document
 from langchain_community.llms import HuggingFacePipeline
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.vectorstores import FAISS
-from langchain.chains.retrieval_qa.base import RetrievalQA
+from langchain.text_splitter import CharacterTextSplitter
+from langchain.docstore.document import Document
+from langchain.chains import RetrievalQA
 
 # Model and Tokenizer
 model_name = "Meldashti/chatbot"
@@ -15,55 +15,64 @@ tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-3B")
 
 # Merge PEFT weights with base model
 model = PeftModel.from_pretrained(base_model, model_name)
-model = model.merge_and_unload()  # This merges the PEFT weights into the base model
+model = model.merge_and_unload()
 
-# Set up the text-generation pipeline
-generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
+# Simplified pipeline with minimal parameters
+generator = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    max_new_tokens=50,  # Very low to test responsiveness
+    do_sample=False
+)
 
-# Use the HuggingFacePipeline from langchain_community
+# LLM wrapper
 llm = HuggingFacePipeline(pipeline=generator)
 
-# Initialize Hugging Face embeddings model
-embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+# Embeddings
+embeddings = HuggingFaceEmbeddings(model_name="paraphrase-MiniLM-L3-v2")
 
-# Sample documents
+# Sample documents (minimal)
 documents = [
-    Document(page_content="Document 1 content goes here..."),
-    Document(page_content="Document 2 content goes here..."),
-    # Add more documents as needed
+    Document(page_content="Example document about food industry caps"),
+    Document(page_content="Information about manufacturing processes")
 ]
 
-# Split documents into smaller chunks for better retrieval
-text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=100)
+# Text splitting
+text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=20)
 split_documents = text_splitter.split_documents(documents)
 
-# Create FAISS vector store
+# Vector store
 vector_store = FAISS.from_documents(split_documents, embeddings)
+retriever = vector_store.as_retriever(search_kwargs={"k": 2})
 
-# Define a retriever that uses FAISS vector store
-retriever = vector_store.as_retriever()
-
-# Create Retrieval QA Chain
+# Retrieval QA Chain
 rag_chain = RetrievalQA.from_chain_type(
     llm=llm,
     chain_type="stuff",
     retriever=retriever
 )
 
-# Define the chat function
+# Chat function with extensive logging
 def chat(message, history):
     print(f"Received message: {message}")
     try:
-        response = rag_chain.invoke({"query": message})
-        print(f"Response generated: {response}")
-        return str(response['result'])
+        # Add timeout mechanism
+        import timeout_decorator
+
+        @timeout_decorator.timeout(10)  # 10 seconds timeout
+        def get_response():
+            response = rag_chain.invoke({"query": message})
+            return str(response['result'])
+
+        return get_response()
    except Exception as e:
-        print(f"Error generating response: {e}")
-        return "Sorry, I couldn't generate a response."
+        print(f"Error generating response: {type(e)}, {e}")
+        return f"An error occurred: {str(e)}"
 
-# Set up the Gradio interface
+# Gradio interface
 demo = gr.ChatInterface(chat, type="messages", autofocus=False)
 
-# Launch the app
+# Launch
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(debug=True)
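
Two notes on the debugging changes above; both snippets are hedged sketches, not part of the commit.

First, the pipeline now caps output at max_new_tokens=50 with greedy decoding, which reads as a responsiveness test. To isolate the hang further, the transformers pipeline can also be called directly, bypassing LangChain entirely (the prompt string here is illustrative):

# Smoke test: if this call is already slow, the bottleneck is generation
# itself rather than the retrieval chain wrapped around it.
print(generator("What do the sample documents describe?")[0]["generated_text"])

Second, timeout_decorator is a third-party package (timeout-decorator on PyPI), so it presumably also needs to be added to the Space's requirements.txt. Its default implementation relies on SIGALRM, which only works in the main thread, while Gradio typically runs event handlers in worker threads; passing use_signals=False to the decorator avoids that, as would a standard-library sketch like the following (invoke_with_timeout is a hypothetical helper, not part of this commit):

from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeout

_executor = ThreadPoolExecutor(max_workers=1)

def invoke_with_timeout(chain, query, seconds=10):
    # Run the chain in a worker thread and wait at most `seconds` for it.
    future = _executor.submit(chain.invoke, {"query": query})
    try:
        return str(future.result(timeout=seconds)["result"])
    except FutureTimeout:
        # The underlying generation keeps running; we only stop waiting.
        return "Sorry, the model took too long to respond."

Note that neither approach cancels the generation itself; a timeout here only keeps the chat UI responsive while the slow call finishes in the background.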