Sumkh committed
Commit 72b019e · verified · 1 Parent(s): fa454d9
Files changed (3)
  1. Dockerfile +31 -0
  2. app.py +1002 -0
  3. requirements.txt +30 -0
Dockerfile ADDED
@@ -0,0 +1,31 @@
+ FROM python:3.9-slim
+
+ # Install the wget CLI (the pip "wget" package only provides a Python module)
+ RUN apt-get update && apt-get install -y --no-install-recommends wget && \
+     rm -rf /var/lib/apt/lists/*
+
+ # Install vLLM dependencies
+ RUN pip install vllm gradio bitsandbytes transformers accelerate
+
+ WORKDIR /app
+
+ # Copy your Gradio app files
+ COPY app.py .
+ COPY requirements.txt .
+ RUN pip install -r requirements.txt
+
+ # Download chat template and model
+ RUN wget -O /tmp/tool_chat_template_llama3.1_json.jinja \
+     https://github.com/vllm-project/vllm/raw/refs/heads/main/examples/tool_chat_template_llama3.1_json.jinja && \
+     huggingface-cli download --resume-download unsloth/llama-3-8b-Instruct-bnb-4bit --local-dir /app/models
+
+ # Expose Gradio port
+ EXPOSE 7860
+
+ # Start the vLLM OpenAI-compatible server in the background, then Gradio
+ CMD python -m vllm.entrypoints.openai.api_server \
+     --model /app/models \
+     --enable-auto-tool-choice \
+     --tool-call-parser llama3_json \
+     --chat-template /tmp/tool_chat_template_llama3.1_json.jinja \
+     --quantization bitsandbytes \
+     --load-format bitsandbytes \
+     --dtype half \
+     --max-model-len 8192 \
+     --download-dir models/vllm \
+     --host 0.0.0.0 \
+     --port 8000 & python app.py
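+
+ # Smoke test (sketch, not part of the build): once the container is running,
+ # both servers should respond:
+ #   curl http://localhost:8000/v1/models   # vLLM OpenAI-compatible API
+ #   curl http://localhost:7860             # Gradio UI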
app.py ADDED
@@ -0,0 +1,1002 @@
+ from io import StringIO
+ import sys
+
+ import os
+ #from huggingface_hub import login
+ import gradio as gr
+ import json
+ import csv
+ import hashlib
+ import uuid
+ import logging
+ from typing import Annotated, List, Dict, Sequence, TypedDict
+
+ # LangChain & related imports
+ from langchain_core.runnables import RunnableConfig
+ from langchain_core.tools import tool, StructuredTool
+ from pydantic import BaseModel, Field
+
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_chroma import Chroma
+ from langchain_core.documents import Document
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.retrievers import EnsembleRetriever
+
+ # Extraction for Documents
+ from langchain_docling.loader import ExportType
+ from langchain_docling import DoclingLoader
+ from docling.chunking import HybridChunker
+
+ # Extraction for HTML
+ from langchain_community.document_loaders import WebBaseLoader
+ from urllib.parse import urlparse
+
+ #from langchain_groq import ChatGroq
+ from langchain_openai import ChatOpenAI
+ from langgraph.prebuilt import InjectedStore
+ from langgraph.store.base import BaseStore
+ from langgraph.store.memory import InMemoryStore
+ from langgraph.checkpoint.memory import MemorySaver
+ from langchain.embeddings import init_embeddings
+ from langgraph.graph import StateGraph
+ from langgraph.graph.message import add_messages
+ from langgraph.prebuilt import ToolNode, tools_condition
+ from langchain_core.messages import (
+     SystemMessage,
+     AIMessage,
+     HumanMessage,
+     BaseMessage,
+     ToolMessage,
+ )
+
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+ logger = logging.getLogger(__name__)
+
+ # Suppress all logs at or below WARNING to keep the UI output clean.
+ # Note: this also silences the INFO-level logger configured above.
+ logging.disable(logging.WARNING)
+
+
+ EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
+
+ # =============================================================================
+ # Document Extraction Functions
+ # =============================================================================
+
+ def extract_documents(doc_path: str) -> List[str]:
+     """
+     Recursively collects all file paths from folder 'doc_path'.
+     Used by load_files_from_folder() to find documents to parse.
+     """
+     extracted_docs = []
+
+     for root, _, files in os.walk(doc_path):
+         for file in files:
+             file_path = os.path.join(root, file)
+             extracted_docs.append(file_path)
+     return extracted_docs
+
+
+ def _generate_uuid(page_content: str) -> str:
+     """Generate a UUID for a chunk of text using MD5 hashing."""
+     md5_hash = hashlib.md5(page_content.encode()).hexdigest()
+     return str(uuid.UUID(md5_hash[0:32]))
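+ # MD5 is deterministic, so re-ingesting an unchanged chunk yields the same
+ # doc_id, e.g. _generate_uuid("hello") always returns the same UUID string.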
+
+
+ def load_file(file_path: str) -> List[Document]:
+     """
+     Load a file from the given path and return a list of Document objects.
+     """
+     _documents = []
+
+     # Load the file and extract the text chunks
+     try:
+         loader = DoclingLoader(
+             file_path=file_path,
+             export_type=ExportType.DOC_CHUNKS,
+             chunker=HybridChunker(tokenizer=EMBED_MODEL_ID),
+         )
+         docs = loader.load()
+         logger.info(f"Total parsed doc-chunks: {len(docs)} from Source: {file_path}")
+
+         for d in docs:
+             # Tag each chunk with its source file and a unique ID
+             doc = Document(
+                 page_content=d.page_content,
+                 metadata={
+                     "source": file_path,
+                     "doc_id": _generate_uuid(d.page_content),
+                     "source_type": "file",
+                 }
+             )
+             _documents.append(doc)
+         logger.info(f"Total generated LangChain document chunks: {len(_documents)}.")
+
+     except Exception as e:
+         logger.error(f"Error loading file: {file_path}. Exception: {e}")
+
+     return _documents
+
+
+ # Define function to load documents from a folder
+ def load_files_from_folder(doc_path: str) -> List[Document]:
+     """
+     Load documents from the given folder path and return a list of Document objects.
+     """
+     _documents = []
+     # Extract all file paths from the given folder
+     extracted_docs = extract_documents(doc_path)
+
+     # Iterate through each document and extract the text chunks
+     for file_path in extracted_docs:
+         _documents.extend(load_file(file_path))
+
+     return _documents
+
+ # =============================================================================
+ # Load structured data in CSV files into LangChain Document format
+ def load_mcq_csvfiles(file_path: str) -> List[Document]:
+     """
+     Load structured MCQ data from the given CSV file path and return a list of Document objects.
+     Expected format: each CSV row has the columns "mcq_number", "mcq_type", "text_content".
+     """
+     _documents = []
+
+     # Ensure we process only CSV files
+     if not file_path.endswith(".csv"):
+         return _documents  # Skip non-CSV files
+     try:
+         # Open and read the CSV file
+         with open(file_path, mode='r', encoding='utf-8') as file:
+             reader = csv.DictReader(file)
+             for row in reader:
+                 # Ensure required columns exist in the row
+                 if not all(k in row for k in ["mcq_number", "mcq_type", "text_content"]):
+                     logger.error(f"Skipping row due to missing fields: {row}")
+                     continue
+                 # Map each CSV row into a Document
+                 doc = Document(
+                     page_content=row["text_content"],  # text_content segments are separated by "|"
+                     metadata={
+                         "source": f"{file_path}_{row['mcq_number']}",  # file_path + mcq_number
+                         "doc_id": _generate_uuid(f"{file_path}_{row['mcq_number']}"),  # Unique ID
+                         "source_type": row["mcq_type"],  # MCQ type
+                     }
+                 )
+                 _documents.append(doc)
+         logger.info(f"Successfully loaded {len(_documents)} LangChain document chunks from {file_path}.")
+
+     except Exception as e:
+         logger.error(f"Error loading file: {file_path}. Exception: {e}")
+
+     return _documents
+
+ # Define function to load MCQ CSV documents from a folder
+ def load_files_from_folder_mcq(doc_path: str) -> List[Document]:
+     """
+     Load MCQ CSV files from the given folder path and return a list of Document objects.
+     """
+     _documents = []
+     # Collect all CSV file paths from the given folder
+     extracted_docs = [
+         os.path.join(doc_path, file) for file in os.listdir(doc_path)
+         if file.endswith(".csv")  # Process only CSV files
+     ]
+
+     # Iterate through each file and extract the text chunks
+     for file_path in extracted_docs:
+         _documents.extend(load_mcq_csvfiles(file_path))
+
+     return _documents
+
+
+ # =============================================================================
+ # Website Extraction Functions
+ # =============================================================================
+
+ def ensure_scheme(url):
+     """Prefix a URL with http:// if it has no scheme."""
+     parsed_url = urlparse(url)
+     if not parsed_url.scheme:
+         return 'http://' + url  # Default to http, or use 'https://' if preferred
+     return url
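+ # e.g. ensure_scheme("example.com")         -> "http://example.com"
+ #      ensure_scheme("https://example.com") -> "https://example.com"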
+
+ def extract_html(url: List[str]) -> List[Document]:
+     """
+     Extracts text from the HTML content of the given web page URL(s).
+     Returns a list of LangChain 'Document' objects.
+     """
+     if isinstance(url, str):
+         url = [url]
+     # Ensure all URLs have a scheme
+     web_paths = [ensure_scheme(u) for u in url]
+
+     loader = WebBaseLoader(web_paths)
+     loader.requests_per_second = 1
+     docs = loader.load()
+
+     # Iterate through each document, clean the content (collapsing excessive
+     # whitespace and line returns), and store it in a LangChain Document
+     _documents = []
+     for doc in docs:
+         # Clean the content
+         doc.page_content = doc.page_content.strip()
+         doc.page_content = doc.page_content.replace("\n", " ")
+         doc.page_content = doc.page_content.replace("\r", " ")
+         doc.page_content = doc.page_content.replace("\t", " ")
+         doc.page_content = doc.page_content.replace("    ", " ")
+         doc.page_content = doc.page_content.replace("  ", " ")
+
+         # Store it in a LangChain Document
+         web_doc = Document(
+             page_content=doc.page_content,
+             metadata={
+                 "source": doc.metadata.get("source"),
+                 "doc_id": _generate_uuid(doc.page_content),
+                 "source_type": "web"
+             }
+         )
+         _documents.append(web_doc)
+     return _documents
+
+ # =============================================================================
+ # Vector Store Initialisation
+ # =============================================================================
+
+ embedding_model = HuggingFaceEmbeddings(model_name=EMBED_MODEL_ID)
+
+ # Initialise vector stores
+ general_vs = Chroma(
+     collection_name="general_vstore",
+     embedding_function=embedding_model,
+     persist_directory="./general_db"
+ )
+
+ mcq_vs = Chroma(
+     collection_name="mcq_vstore",
+     embedding_function=embedding_model,
+     persist_directory="./mcq_db"
+ )
+
+ in_memory_vs = Chroma(
+     collection_name="in_memory_vstore",
+     embedding_function=embedding_model
+ )
+
+ # Split the documents into smaller chunks for better embedding coverage
+ def split_text_into_chunks(docs: List[Document]) -> List[Document]:
+     """
+     Splits a list of Documents into smaller text chunks using
+     RecursiveCharacterTextSplitter while preserving metadata.
+     Returns a list of Document objects.
+     """
+     if not docs:
+         return []
+     splitter = RecursiveCharacterTextSplitter(
+         chunk_size=1000,    # Split into chunks of 1000 characters
+         chunk_overlap=200,  # Overlap by 200 characters
+         add_start_index=True
+     )
+     chunked_docs = splitter.split_documents(docs)
+     return chunked_docs  # List of Document objects
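+ # Example (sketch): split_text_into_chunks(extract_html(["example.com"]))
+ # returns ~1000-char Documents that keep the parent page's source/doc_id
+ # metadata, plus a start_index added by the splitter.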
+
+
+ # =============================================================================
+ # Retrieval Tools
+ # =============================================================================
+
+ # Define a simple similarity-search retrieval tool on mcq_vs
+ class MCQRetrievalTool(BaseModel):
+     input: str = Field(..., title="input", description="Search topic.")
+     k: int = Field(2, title="Number of Results", description="The number of results to retrieve.")
+
+ def mcq_retriever(input: str, k: int = 2) -> Dict[str, str]:
+     # Retrieve the top k most similar MCQ question documents from the vector store
+     docs_func = mcq_vs.as_retriever(
+         search_type="similarity",
+         search_kwargs={
+             'k': k,
+             'filter': {"source_type": "mcq_question"}
+         },
+     )
+     docs_qns = docs_func.invoke(input)
+
+     # Extract the document IDs from the retrieved documents
+     doc_ids = [d.metadata.get("doc_id") for d in docs_qns if "doc_id" in d.metadata]
+
+     # Retrieve the full documents based on the doc_ids
+     docs = mcq_vs.get(where={'doc_id': {"$in": doc_ids}})
+
+     qns_list = {}
+     for i, d in enumerate(docs['metadatas']):
+         qns_list[d['source'] + " " + d['source_type']] = docs['documents'][i]
+
+     return qns_list
+
+ # Create a StructuredTool from the function
+ mcq_retriever_tool = StructuredTool.from_function(
+     func=mcq_retriever,
+     name="MCQ Retrieval Tool",
+     description=(
+         """
+         Use this tool to retrieve an MCQ question set when the Human asks to generate a quiz on a topic.
+         DO NOT GIVE THE ANSWERS to the Human before the Human has answered all the questions.
+
+         If the Human gives answers for questions you do not know, SAY you do not have the questions for those answers
+         and ASK if the Human wants you to generate a new quiz, then SAVE THE QUIZ with the Summary Tool before ending the conversation.
+
+         Input must be a JSON string with the schema:
+         - input (str): The search topic used to retrieve a related MCQ question set.
+         - k (int): Number of question sets to retrieve.
+         Example usage: input='What is AI?', k=5
+
+         Returns:
+         - A dict of MCQ questions:
+           Key: question metadata, e.g. './Documents/mcq/mcq.csv_Qn31 mcq_question' with suffix ['question', 'answer', 'answer_reason', 'options', 'wrong_options_reason']
+           Value: text content
+         """
+     ),
+     args_schema=MCQRetrievalTool,
+     response_format="content",
+     return_direct=False,  # Return the response to the agent rather than directly to the user
+     verbose=False         # Log the tool's progress
+ )
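+ # Example invocation (sketch; assumes the MCQ store has been populated):
+ #   mcq_retriever_tool.invoke({"input": "machine learning basics", "k": 2})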
+
+ # -----------------------------------------------------------------------------
+
+ # Retrieve more documents with higher diversity using MMR (Maximal Marginal Relevance) from the general vector store
+ # Useful if the dataset has many similar documents
+ class GenRetrievalTool(BaseModel):
+     input: str = Field(..., title="input", description="User query.")
+     k: int = Field(2, title="Number of Results", description="The number of results to retrieve.")
+
+ def gen_retriever(input: str, k: int = 2) -> List[str]:
+     # Use the vector store's retriever to fetch documents
+     docs_func = general_vs.as_retriever(
+         search_type="mmr",
+         search_kwargs={'k': k, 'lambda_mult': 0.25}
+     )
+     docs = docs_func.invoke(input)
+     return [d.page_content for d in docs]
+
+ # Create a StructuredTool from the function
+ general_retriever_tool = StructuredTool.from_function(
+     func=gen_retriever,
+     name="Assistant References Retrieval Tool",
+     description=(
+         """
+         Use this tool to retrieve reference information from the Assistant's reference database for Human queries related to a topic,
+         and when the Human asks to generate guides to learn or study a topic.
+
+         Input must be a JSON string with the schema:
+         - input (str): The user query.
+         - k (int): Number of results to retrieve.
+         Example usage: input='What is AI?', k=5
+         Returns:
+         - A list of retrieved documents' content strings.
+         """
+     ),
+     args_schema=GenRetrievalTool,
+     response_format="content",
+     return_direct=False,  # Return the content of the documents to the agent
+     verbose=False         # Log the tool's progress
+ )
+
+ # -----------------------------------------------------------------------------
+
+ # Retrieve more documents with higher diversity using MMR (Maximal Marginal Relevance) from the in-memory vector store
+ # Queries the in-memory vector store only
+ class InMemoryRetrievalTool(BaseModel):
+     input: str = Field(..., title="input", description="User query.")
+     k: int = Field(2, title="Number of Results", description="The number of results to retrieve.")
+
+ def in_memory_retriever(input: str, k: int = 2) -> List[str]:
+     # Use the vector store's retriever to fetch documents
+     docs_func = in_memory_vs.as_retriever(
+         search_type="mmr",
+         search_kwargs={'k': k, 'lambda_mult': 0.25}
+     )
+     docs = docs_func.invoke(input)
+     return [d.page_content for d in docs]
+
+ # Create a StructuredTool from the function
+ in_memory_retriever_tool = StructuredTool.from_function(
+     func=in_memory_retriever,
+     name="In-Memory Retrieval Tool",
+     description=(
+         """
+         Use this tool when the Human asks the Assistant to retrieve information from documents that the Human has uploaded.
+
+         Input must be a JSON string with the schema:
+         - input (str): The user query.
+         - k (int): Number of results to retrieve.
+         """
+     ),
+     args_schema=InMemoryRetrievalTool,
+     response_format="content",
+     return_direct=False,  # Whether to return the tool's output directly
+     verbose=False         # Log the tool's progress
+ )
+
+ # -----------------------------------------------------------------------------
+
+ # Web Extraction Tool
+ class WebExtractionRequest(BaseModel):
+     input: str = Field(..., title="input", description="Search text.")
+     url: str = Field(
+         ...,
+         title="url",
+         description="Web URL(s) to extract content from. If multiple URLs, separate them with a comma."
+     )
+     k: int = Field(5, title="Number of Results", description="The number of results to retrieve.")
+
+ # Extract content from web URLs and load it into in_memory_vs
+ def extract_web_path_tool(input: str, url: str, k: int = 5) -> List[str]:
+     """
+     Extract content from the web URLs based on the user's input.
+     Args:
+     - input: The input text to search for.
+     - url: URLs to extract content from.
+     - k: Number of results to retrieve.
+     Returns:
+     - A list of retrieved documents' content strings.
+     """
+     if isinstance(url, str):
+         url = [url]
+     # Extract content from the web
+     html_docs = extract_html(url)
+     if not html_docs:
+         return f"No content extracted from {url}."
+
+     # Split the documents into smaller chunks for better embedding coverage
+     chunked_texts = split_text_into_chunks(html_docs)
+     in_memory_vs.add_documents(chunked_texts)  # Add the chunked texts to the in-memory vector store
+
+     # Retrieve content back from the in-memory vector store, restricted to these URLs
+     docs_func = in_memory_vs.as_retriever(
+         search_type="mmr",
+         search_kwargs={
+             'k': k,
+             'lambda_mult': 0.25,
+             'filter': {"source": {"$in": url}}
+         },
+     )
+     docs = docs_func.invoke(input)
+     return [d.page_content for d in docs]
+
+ # Create a StructuredTool from the function
+ web_extraction_tool = StructuredTool.from_function(
+     func=extract_web_path_tool,
+     name="Web Extraction Tool",
+     description=(
+         "Assistant should use this tool to extract content from web URLs based on the user's input. "
+         "The extracted web content is first loaded into the database, and then the top k results are returned."
+     ),
+     args_schema=WebExtractionRequest,
+     return_direct=False,  # Whether to return the tool's output directly
+     verbose=False         # Log the tool's progress
+ )
+
+ # -----------------------------------------------------------------------------
+
+ # Ensemble Retrieval from the General and In-Memory Vector Stores
+ class EnsembleRetrievalTool(BaseModel):
+     input: str = Field(..., title="input", description="User query.")
+     k: int = Field(5, title="Number of Results", description="Number of results.")
+
+ def ensemble_retriever(input: str, k: int = 5) -> List[str]:
+     # Build a retriever over each vector store
+     general_retrieval = general_vs.as_retriever(
+         search_type="mmr",
+         search_kwargs={'k': k, 'lambda_mult': 0.25}
+     )
+     in_memory_retrieval = in_memory_vs.as_retriever(
+         search_type="mmr",
+         search_kwargs={'k': k, 'lambda_mult': 0.25}
+     )
+
+     # Combine both retrievers with equal weighting
+     ensemble = EnsembleRetriever(
+         retrievers=[general_retrieval, in_memory_retrieval],
+         weights=[0.5, 0.5]
+     )
+     docs = ensemble.invoke(input)
+     return [d.page_content for d in docs]
+
+ # Create a StructuredTool from the function
+ ensemble_retriever_tool = StructuredTool.from_function(
+     func=ensemble_retriever,
+     name="Ensemble Retriever Tool",
+     description=(
+         """
+         Use this tool to retrieve information from both the reference database and
+         the documents that the Human has uploaded.
+
+         Input must be a JSON string with the schema:
+         - input (str): The user query.
+         - k (int): Number of results to retrieve.
+         """
+     ),
+     args_schema=EnsembleRetrievalTool,
+     response_format="content",
+     return_direct=False
+ )
+
+
+ ###############################################################################
+ # LLM Model Setup
+ ###############################################################################
+
+ TEMPERATURE = 0.5
+ model = ChatOpenAI(
+     model="unsloth/llama-3-8b-Instruct-bnb-4bit",
+     temperature=TEMPERATURE,
+     timeout=None,
+     max_retries=2,
+     api_key="not_required",
+     base_url="http://localhost:8000/v1",  # OpenAI-compatible route of the local vLLM server
+     verbose=True
+ )
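+ # Quick sanity check (sketch), once the vLLM server from the Dockerfile is up:
+ #   model.invoke("Say hi")  # should return an AIMessage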
+
+ # model = ChatGroq(
+ #     model_name="deepseek-r1-distill-llama-70b",
+ #     temperature=TEMPERATURE,
+ #     api_key=GROQ_API_KEY,
+ #     verbose=True
+ # )
+
+ ###############################################################################
+ # 1. Initialize memory + config
+ ###############################################################################
+ in_memory_store = InMemoryStore(
+     index={
+         "embed": init_embeddings("huggingface:sentence-transformers/all-MiniLM-L6-v2"),
+         "dims": 384,  # Embedding dimensions
+     }
+ )
+
+ # A memory saver to checkpoint conversation states
+ checkpointer = MemorySaver()
+
+ # Initialize config with user & thread info
+ config = {}
+ config["configurable"] = {
+     "user_id": "user_1",
+     "thread_id": 0,
+ }
+
+ ###############################################################################
+ # 2. Define MessagesState
+ ###############################################################################
+ class MessagesState(TypedDict):
+     """The state of the agent.
+
+     The key 'messages' uses add_messages as a reducer,
+     so each time this state is updated, new messages are appended.
+     # See https://langchain-ai.github.io/langgraph/concepts/low_level/#reducers
+     """
+     messages: Annotated[Sequence[BaseMessage], add_messages]
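+ # Because add_messages is the reducer, a node that returns
+ # {"messages": [AIMessage(content="...")]} appends to the history rather
+ # than replacing it.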
+
+
+ ###############################################################################
+ # 3. Memory Tools
+ ###############################################################################
+ def save_memory(summary_text: str, *, config: RunnableConfig, store: BaseStore) -> str:
+     """Save the given memory for the current user and return the key."""
+     user_id = config.get("configurable", {}).get("user_id")
+     thread_id = config.get("configurable", {}).get("thread_id")
+     namespace = (user_id, "memories")
+     memory_id = thread_id
+     store.put(namespace, memory_id, {"memory": summary_text})
+     return f"Saved to memory key: {memory_id}"
+
+ def update_memory(state: MessagesState, config: RunnableConfig, *, store: BaseStore):
+     """Persist the full message history of the current thread to the store."""
+     # Extract the messages list from the state
+     messages = state["messages"]
+     # Convert LangChain messages to dictionaries before storing
+     messages_dict = [{"role": msg.type, "content": msg.content} for msg in messages]
+
+     # Get the user and thread ids from the config
+     user_id = config.get("configurable", {}).get("user_id")
+     thread_id = config.get("configurable", {}).get("thread_id")
+     # Namespace the memory
+     namespace = (user_id, "memories")
+     # Create a new memory ID
+     memory_id = f"{thread_id}"
+     store.put(namespace, memory_id, {"memory": messages_dict})
+     return f"Saved to memory key: {memory_id}"
+
+
+ # Define a Pydantic schema for the recall_memory tool
+ # https://langchain-ai.github.io/langgraphjs/reference/classes/checkpoint.InMemoryStore.html
+ class RecallMemory(BaseModel):
+     query_text: str = Field(..., title="Search Text", description="The text to search memories for similar records.")
+     k: int = Field(5, title="Number of Results", description="Number of results to retrieve.")
+
+ def recall_memory(query_text: str, k: int = 5) -> str:
+     """Retrieve user memories from in_memory_store."""
+     user_id = config.get("configurable", {}).get("user_id")
+     memories = [
+         m.value["memory"] for m in in_memory_store.search((user_id, "memories"), query=query_text, limit=k)
+         if "memory" in m.value
+     ]
+     return f"User memories: {memories}"
+
+ # Create a StructuredTool from the function
+ recall_memory_tool = StructuredTool.from_function(
+     func=recall_memory,
+     name="Recall Memory Tool",
+     description="""
+     Retrieve memories relevant to the user's query.
+     """,
+     args_schema=RecallMemory,
+     response_format="content",
+     return_direct=False,
+     verbose=False
+ )
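+ # Example invocation (sketch):
+ #   recall_memory_tool.invoke({"query_text": "last quiz topics", "k": 3})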
+
+ ###############################################################################
+ # 4. Summarize Node (using StructuredTool)
+ ###############################################################################
+ # Define a Pydantic schema for the Summary tool
+ class SummariseConversation(BaseModel):
+     summary_text: str = Field(..., title="text", description="Write a summary of the entire conversation here.")
+
+ def summarise_node(summary_text: str):
+     """
+     Final node that summarizes the entire conversation for the current thread,
+     saves it in memory, increments the thread_id, and ends the conversation.
+     Returns a confirmation string.
+     """
+     user_id = config["configurable"]["user_id"]
+     current_thread_id = config["configurable"]["thread_id"]
+     new_thread_id = str(int(current_thread_id) + 1)
+
+     # Prepare configuration for saving memory with the updated thread id
+     config_for_saving = {
+         "configurable": {
+             "user_id": user_id,
+             "thread_id": new_thread_id
+         }
+     }
+     key = save_memory(summary_text, config=config_for_saving, store=in_memory_store)
+     return f"Summary saved under key: {key}"
+
+ # Create a StructuredTool from the function (this wraps summarise_node)
+ summarise_tool = StructuredTool.from_function(
+     func=summarise_node,
+     name="Summary Tool",
+     description="""
+     Summarize the current conversation in no more than
+     1000 words. Also retain any unanswered quiz questions along with
+     your internal answers so the next conversation thread can continue.
+     Do not reveal solutions to the user yet. Use this tool to save
+     the current conversation to memory and then end the conversation.
+     """,
+     args_schema=SummariseConversation,
+     response_format="content",
+     return_direct=False,
+     verbose=True
+ )
+
+ def call_summary(state: MessagesState, config: RunnableConfig):
+     """Summarize the conversation with the LLM, then save it via the Summary Tool."""
+     system_message = """
+     Summarize the current conversation in no more than
+     1000 words. Also retain any unanswered quiz questions along with
+     your internal answers.
+     """
+     # Convert message dicts to message instances if needed.
+     messages = []
+     for m in state["messages"]:
+         if isinstance(m, dict):
+             # Replace dict messages with the summarization instruction
+             messages.append(AIMessage(content=system_message, role=m.get("role", "assistant")))
+         else:
+             messages.append(m)
+
+     summaries = llm_with_tools.invoke(messages)
+     summary_content = summaries.content
+
+     # Call the Summary Tool manually via a synthetic tool call
+     message_with_single_tool_call = AIMessage(
+         content="",
+         tool_calls=[
+             {
+                 "name": "Summary Tool",
+                 "args": {"summary_text": summary_content},
+                 "id": "tool_call_id",
+                 "type": "tool_call",
+             }
+         ],
+     )
+
+     tool_node.invoke({"messages": [message_with_single_tool_call]})
+
+
+ ###############################################################################
+ # 5. Build the Graph
+ ###############################################################################
+ graph_builder = StateGraph(MessagesState)
+
+ # Use the built-in ToolNode from langgraph that calls any declared tools.
+ tools = [
+     mcq_retriever_tool,
+     web_extraction_tool,
+     ensemble_retriever_tool,
+     general_retriever_tool,
+     in_memory_retriever_tool,
+     recall_memory_tool,
+     summarise_tool,
+ ]
+
+ tool_node = ToolNode(tools=tools)
+ #end_node = ToolNode(tools=[summarise_tool])
+
+ # Wrap the model with tools
+ llm_with_tools = model.bind_tools(tools)
+
+ ###############################################################################
+ # 6. The agent's main node: call_model
+ ###############################################################################
+ def call_model(state: MessagesState, config: RunnableConfig):
+     """
+     The main agent node that calls the LLM with the user + system messages.
+     Since our vLLM chat wrapper expects a list of BaseMessage objects,
+     we convert any dict messages to HumanMessage objects.
+     If the LLM requests a tool call, we'll route to the 'tools' node next
+     (depending on the condition).
+     """
+     # Convert message dicts to HumanMessage instances if needed.
+     messages = []
+     for m in state["messages"]:
+         if isinstance(m, dict):
+             # Use role from dict (defaulting to 'user' if missing)
+             messages.append(HumanMessage(content=m.get("content", ""), role=m.get("role", "user")))
+         else:
+             messages.append(m)
+
+     # Invoke the LLM (with tools) using the converted messages.
+     response = llm_with_tools.invoke(messages)
+
+     return {"messages": [response]}
+
+
+
+ ###############################################################################
+ # 7. Add Nodes & Edges, Then Compile
+ ###############################################################################
+ graph_builder.add_node("agent", call_model)
+ graph_builder.add_node("tools", tool_node)
+ #graph_builder.add_node("summary", call_summary)
+
+ # Entry point
+ graph_builder.set_entry_point("agent")
+
+ # def custom_tools_condition(llm_output: dict) -> str:
+ #     """Return which node to go to next based on the LLM output."""
+ #
+ #     # The LLM's JSON might have a field like {"name": "Recall Memory Tool", "arguments": {...}}.
+ #     tool_name = llm_output.get("name", None)
+ #
+ #     # If the LLM calls "Summary Tool", jump directly to the 'summary' node
+ #     if tool_name == "Summary Tool":
+ #         return "summary"
+ #
+ #     # If the LLM calls any other recognized tool, go to 'tools'
+ #     valid_tool_names = [t.name for t in tools]  # all tools in the main tool_node
+ #     if tool_name in valid_tool_names:
+ #         return "tools"
+ #
+ #     # If there's no recognized tool name, assume we're done => go to summary
+ #     return "__end__"
+
+ # graph_builder.add_conditional_edges(
+ #     "agent",
+ #     custom_tools_condition,
+ #     {
+ #         "tools": "tools",
+ #         "summary": "summary",
+ #         "__end__": "summary",
+ #     }
+ # )
+
+ # If the LLM requests a tool, go to "tools"; otherwise end the turn
+ graph_builder.add_conditional_edges("agent", tools_condition)
+ #graph_builder.add_conditional_edges("agent", tools_condition, {"tools": "tools", "__end__": "summary"})
+ #graph_builder.add_conditional_edges("agent", lambda llm_output: "tools" if llm_output.get("name", None) in [t.name for t in tools] else "summary", {"tools": "tools", "__end__": "summary"})
+
+ # After a tool runs, return to the agent for the final answer or more tool calls
+ graph_builder.add_edge("tools", "agent")
+ #graph_builder.add_edge("agent", "summary")
+ #graph_builder.set_finish_point("summary")
+
+ # Compile the graph with checkpointing and persistent store
+ graph = graph_builder.compile(checkpointer=checkpointer, store=in_memory_store)
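+ # Example (sketch): stream one user turn through the compiled graph:
+ #   for event in graph.stream(
+ #           {"messages": [{"role": "user", "content": "hi"}]},
+ #           config, stream_mode="values"):
+ #       print(event["messages"][-1].content)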
+
+ #from langgraph.prebuilt import create_react_agent
+ #graph = create_react_agent(llm_with_tools, tools=tool_node, checkpointer=checkpointer, store=in_memory_store)
+
+ #from IPython.display import Image, display
+ #display(Image(graph.get_graph().draw_mermaid_png()))
+
+
+ ########################################
+ # Gradio Chatbot Application
+ ########################################
+
+ from gradio import ChatMessage  # gradio itself is already imported as gr above
+
+ system_prompt = "You are a helpful Assistant. Always use the tools {tools}."
+
+ ########################################
+ # Upload documents
+ ########################################
+
+ def upload_documents(file_list: List):
+     """
+     Load documents into the in-memory vector store.
+     """
+     _documents = []
+
+     for doc_path in file_list:
+         _documents.extend(load_file(doc_path))
+
+     # Split the documents into smaller chunks for better embedding coverage
+     splitter = RecursiveCharacterTextSplitter(
+         chunk_size=300,   # Split into chunks of 300 characters
+         chunk_overlap=50, # Overlap by 50 characters
+         add_start_index=True
+     )
+     chunked_texts = splitter.split_documents(_documents)
+     in_memory_vs.add_documents(chunked_texts)
+     return f"Uploaded {len(file_list)} documents into in-memory vector store."
+
+
+ ########################################
+ # Submit queries (ChatInterface function)
+ ########################################
+ def submit_queries(message, _messages):
+     """
+     - message: dict with {"text": ..., "files": [...]}
+     - _messages: chat history as a list of ChatMessage (rebuilt below for this turn)
+     """
+     _messages = []  # Rebuild the displayed messages for the current turn
+     user_text = message.get("text", "")
+     user_files = message.get("files", [])
+
+     # Process user-uploaded files
+     if user_files:
+         for file_obj in user_files:
+             _messages.append(ChatMessage(role="user", content=f"Uploaded file: {file_obj}"))
+         upload_response = upload_documents(user_files)
+         _messages.append(ChatMessage(role="assistant", content=upload_response))
+         yield _messages
+         return  # Exit early since we don't need to process text or call the LLM
+
+     # Stream the graph if the user typed a query
+     if user_text:
+         events = graph.stream(
+             {
+                 "messages": [
+                     {"role": "system", "content": system_prompt},
+                     {"role": "user", "content": user_text},
+                 ]
+             },
+             config,
+             stream_mode="values"
+         )
+
+         for event in events:
+             response = event["messages"][-1]
+             if isinstance(response, AIMessage):
+                 if "tool_calls" in response.additional_kwargs:
+                     _messages.append(
+                         ChatMessage(
+                             role="assistant",
+                             content=str(response.tool_calls[0]["args"]),
+                             metadata={
+                                 "title": str(response.tool_calls[0]["name"]),
+                                 "id": config["configurable"]["thread_id"]
+                             }
+                         ))
+                     yield _messages
+                 else:
+                     _messages.append(
+                         ChatMessage(
+                             role="assistant",
+                             content=response.content,
+                             metadata={"id": config["configurable"]["thread_id"]}
+                         ))
+                     yield _messages
+     return _messages
+
+
+ ########################################
+ # Save Chat History
+ ########################################
+ CHAT_HISTORY_FILE = "chat_history.json"
+
+ def save_chat_history(history):
+     """
+     Saves the chat history into a JSON file.
+     """
+     session_history = [
+         {
+             "role": msg.role,
+             "content": msg.content
+         }
+         for msg in history
+     ]
+     with open(CHAT_HISTORY_FILE, "w", encoding="utf-8") as f:
+         json.dump(session_history, f, ensure_ascii=False, indent=4)
+
+
+ ########################################
+ # Main Gradio Interface
+ ########################################
+ with gr.Blocks(theme="ocean") as AI_Tutor:
+     gr.Markdown("# AI Tutor Chatbot (Gradio App)")
+
+     # Primary Chat Interface
+     chat_interface = gr.ChatInterface(
+         fn=submit_queries,
+         type="messages",
+         chatbot=gr.Chatbot(
+             label="Chat Window",
+             height=500
+         ),
+         textbox=gr.MultimodalTextbox(
+             file_count="multiple",
+             file_types=None,
+             sources="upload",
+             label="Type your query here:",
+             placeholder="Enter your question...",
+         ),
+         title="AI Tutor Chatbot",
+         description="Ask me anything about Artificial Intelligence!",
+         multimodal=True,
+         save_history=True,
+     )
+
+
+ if __name__ == "__main__":
+     # Bind to 0.0.0.0 so the app is reachable from outside the container
+     AI_Tutor.launch(server_name="0.0.0.0", server_port=7860)
requirements.txt ADDED
@@ -0,0 +1,30 @@
+ gradio
+ requests
+ #langchain-groq
+ langchain-openai
+ torch
+ vllm
+ accelerate
+ bitsandbytes
+
+ # LangChain and related dependencies
+ langchain
+ langchain-core
+ langchain-text-splitters
+ langchain-community
+ langgraph
+ chromadb
+ langchain-chroma
+ #langsmith
+
+ # Document processing
+ docling
+ langchain-docling
+
+ # Local LLM and other utilities
+ #llama-cpp-python
+ langchain_huggingface
+ transformers
+ sentence_transformers
+ huggingface_hub
+