- Dockerfile +31 -0
- app.py +1002 -0
- requirements.txt +30 -0
Dockerfile
ADDED
@@ -0,0 +1,31 @@
FROM python:3.9-slim

# Install system wget (python:3.9-slim does not ship it; the pip "wget"
# package does not provide the CLI used below)
RUN apt-get update && apt-get install -y --no-install-recommends wget && rm -rf /var/lib/apt/lists/*

# Install vLLM dependencies
RUN pip install vllm gradio bitsandbytes transformers accelerate

# Copy your Gradio app files
WORKDIR /app
COPY app.py .
COPY requirements.txt .
RUN pip install -r requirements.txt

# Download chat template and model
RUN wget -O /tmp/tool_chat_template_llama3.1_json.jinja \
    https://github.com/vllm-project/vllm/raw/refs/heads/main/examples/tool_chat_template_llama3.1_json.jinja && \
    huggingface-cli download --resume-download unsloth/llama-3-8b-Instruct-bnb-4bit --local-dir /app/models

# Expose Gradio port
EXPOSE 7860

# Start the vLLM OpenAI-compatible server in the background, then launch Gradio.
# --served-model-name keeps the API's model name in sync with the name that
# app.py passes to ChatOpenAI.
CMD python -m vllm.entrypoints.openai.api_server \
    --model /app/models \
    --served-model-name unsloth/llama-3-8b-Instruct-bnb-4bit \
    --enable-auto-tool-choice \
    --tool-call-parser llama3_json \
    --chat-template /tmp/tool_chat_template_llama3.1_json.jinja \
    --quantization bitsandbytes \
    --load-format bitsandbytes \
    --dtype half \
    --max-model-len 8192 \
    --download-dir models/vllm \
    --host 0.0.0.0 \
    --port 8000 & python app.py
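
Once the container is up, the vLLM server can be sanity-checked independently of the Gradio app. A minimal sketch, assuming port 8000 is reachable from the host and the model name matches the --served-model-name flag above (vLLM exposes an OpenAI-compatible /v1/chat/completions endpoint):

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "unsloth/llama-3-8b-Instruct-bnb-4bit",
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 32,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])
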
app.py
ADDED
@@ -0,0 +1,1002 @@
from io import StringIO
import sys

import os
#from huggingface_hub import login
import gradio as gr
import json
import csv
import hashlib
import uuid
import logging
from typing import Annotated, List, Dict, Sequence, TypedDict

# LangChain & related imports
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool, StructuredTool
from pydantic import BaseModel, Field

from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.retrievers import EnsembleRetriever

# Extraction for Documents
from langchain_docling.loader import ExportType
from langchain_docling import DoclingLoader
from docling.chunking import HybridChunker

# Extraction for HTML
from langchain_community.document_loaders import WebBaseLoader
from urllib.parse import urlparse

#from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import InjectedStore
from langgraph.store.base import BaseStore
from langgraph.store.memory import InMemoryStore
from langgraph.checkpoint.memory import MemorySaver
from langchain.embeddings import init_embeddings
from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.messages import (
    SystemMessage,
    AIMessage,
    HumanMessage,
    BaseMessage,
    ToolMessage,
)

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Suppress all log records at or below WARNING for a cleaner user experience
# (note: this also silences this module's own logger.info calls):
logging.disable(logging.WARNING)


EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"

# =============================================================================
# Document Extraction Functions
# =============================================================================

def extract_documents(doc_path: str) -> List[str]:
    """
    Recursively collects all file paths from folder 'doc_path'.
    Used by load_files_from_folder() to find documents to parse.
    """
    extracted_docs = []

    for root, _, files in os.walk(doc_path):
        for file in files:
            file_path = os.path.join(root, file)
            extracted_docs.append(file_path)
    return extracted_docs


def _generate_uuid(page_content: str) -> str:
    """Generate a deterministic UUID for a chunk of text using MD5 hashing."""
    md5_hash = hashlib.md5(page_content.encode()).hexdigest()
    return str(uuid.UUID(md5_hash[0:32]))


def load_file(file_path: str) -> List[Document]:
    """
    Load a file from the given path and return a list of Document objects.
    """
    _documents = []

    # Load the file and extract the text chunks
    try:
        loader = DoclingLoader(
            file_path=file_path,
            export_type=ExportType.DOC_CHUNKS,
            chunker=HybridChunker(tokenizer=EMBED_MODEL_ID),
        )
        docs = loader.load()
        logger.info(f"Total parsed doc-chunks: {len(docs)} from Source: {file_path}")

        for d in docs:
            # Tag each document's chunk with the source file and a unique ID
            doc = Document(
                page_content=d.page_content,
                metadata={
                    "source": file_path,
                    "doc_id": _generate_uuid(d.page_content),
                    "source_type": "file",
                }
            )
            _documents.append(doc)
        logger.info(f"Total generated LangChain document chunks: {len(_documents)}\n.")

    except Exception as e:
        logger.error(f"Error loading file: {file_path}. Exception: {e}\n.")

    return _documents


# Define function to load documents from a folder
def load_files_from_folder(doc_path: str) -> List[Document]:
    """
    Load documents from the given folder path and return a list of Document objects.
    """
    _documents = []
    # Extract all file paths from the given folder
    extracted_docs = extract_documents(doc_path)

    # Iterate through each document and extract the text chunks
    for file_path in extracted_docs:
        _documents.extend(load_file(file_path))

    return _documents
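
# A minimal usage sketch (hypothetical helper, never called by the app):
# demonstrates the shape of load_files_from_folder()'s output.
def _demo_folder_load(folder: str = "./Documents") -> None:
    docs = load_files_from_folder(folder)
    for d in docs[:3]:
        print(d.metadata["source"], d.metadata["doc_id"], d.page_content[:80])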

# =============================================================================
# Load structured data in csv file to LangChain Document format
def load_mcq_csvfiles(file_path: str) -> List[Document]:
    """
    Load structured MCQ data from the given CSV file path and return a list of Document objects.
    Expected format: each CSV row has the columns "mcq_number", "mcq_type", "text_content".
    """
    _documents = []

    # Process only CSV files; load each row into one Document
    if not file_path.endswith(".csv"):
        return _documents  # Skip non-CSV files
    try:
        # Open and read the CSV file
        with open(file_path, mode='r', encoding='utf-8') as file:
            reader = csv.DictReader(file)
            for row in reader:
                # Ensure required columns exist in the row
                if not all(k in row for k in ["mcq_number", "mcq_type", "text_content"]):
                    logger.error(f"Skipping row due to missing fields: {row}")
                    continue
                # Tag each row with its source, a unique ID, and the MCQ type
                doc = Document(
                    page_content=row["text_content"],  # text_content segments are separated by "|"
                    metadata={
                        "source": f"{file_path}_{row['mcq_number']}",  # file_path + mcq_number
                        "doc_id": _generate_uuid(f"{file_path}_{row['mcq_number']}"),  # Unique ID
                        "source_type": row["mcq_type"],  # MCQ type
                    }
                )
                _documents.append(doc)
        logger.info(f"Successfully loaded {len(_documents)} LangChain document chunks from {file_path}.")

    except Exception as e:
        logger.error(f"Error loading file: {file_path}. Exception: {e}\n.")

    return _documents

# Define function to load MCQ CSV files from a folder
def load_files_from_folder_mcq(doc_path: str) -> List[Document]:
    """
    Load MCQ CSV files from the given folder path and return a list of Document objects.
    """
    _documents = []
    # Collect all CSV file paths from the given folder
    extracted_docs = [
        os.path.join(doc_path, file) for file in os.listdir(doc_path)
        if file.endswith(".csv")  # Process only CSV files
    ]

    # Iterate through each file and extract the rows
    for file_path in extracted_docs:
        _documents.extend(load_mcq_csvfiles(file_path))

    return _documents
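
# Expected CSV layout (hypothetical sample rows; each row becomes one
# Document, and text_content segments are separated by "|"):
#
#   mcq_number,mcq_type,text_content
#   Qn31,mcq_question,"What does MMR stand for? | A) ... | B) ... | C) ..."
#   Qn31,mcq_answer,"B"
#
def _demo_mcq_load(path: str = "./Documents/mcq/mcq.csv") -> None:
    for d in load_mcq_csvfiles(path)[:2]:
        print(d.metadata["source"], d.metadata["source_type"], "->", d.page_content[:60])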


# =============================================================================
# Website Extraction Functions
# =============================================================================
def ensure_scheme(url):
    parsed_url = urlparse(url)
    if not parsed_url.scheme:
        return 'http://' + url  # Default to http, or use 'https://' if preferred
    return url

def extract_html(url: List[str]) -> List[Document]:
    """
    Extracts text from the HTML content of the web pages listed in 'url'.
    Returns a list of LangChain 'Document' objects.
    """
    if isinstance(url, str):
        url = [url]
    # Ensure all URLs have a scheme
    web_paths = [ensure_scheme(u) for u in url]

    loader = WebBaseLoader(web_paths)
    loader.requests_per_second = 1
    docs = loader.load()

    # Iterate through each document, clean the content (removing excessive
    # line returns), and store it in a LangChain Document
    _documents = []
    for doc in docs:
        # Clean the content
        doc.page_content = doc.page_content.strip()
        doc.page_content = doc.page_content.replace("\n", " ")
        doc.page_content = doc.page_content.replace("\r", " ")
        doc.page_content = doc.page_content.replace("\t", " ")
        doc.page_content = doc.page_content.replace("  ", " ")
        doc.page_content = doc.page_content.replace("  ", " ")

        # Store it in a LangChain Document
        web_doc = Document(
            page_content=doc.page_content,
            metadata={
                "source": doc.metadata.get("source"),
                "doc_id": _generate_uuid(doc.page_content),
                "source_type": "web"
            }
        )
        _documents.append(web_doc)
    return _documents

# =============================================================================
# Vector Store Initialisation
# =============================================================================

embedding_model = HuggingFaceEmbeddings(model_name=EMBED_MODEL_ID)

# Initialise vector stores
general_vs = Chroma(
    collection_name="general_vstore",
    embedding_function=embedding_model,
    persist_directory="./general_db"
)

mcq_vs = Chroma(
    collection_name="mcq_vstore",
    embedding_function=embedding_model,
    persist_directory="./mcq_db"
)

in_memory_vs = Chroma(
    collection_name="in_memory_vstore",
    embedding_function=embedding_model
)

# Split the documents into smaller chunks for better embedding coverage
def split_text_into_chunks(docs: List[Document]) -> List[Document]:
    """
    Splits a list of Documents into smaller text chunks using
    RecursiveCharacterTextSplitter while preserving metadata.
    Returns a list of Document objects.
    """
    if not docs:
        return []
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,    # Split into chunks of 1000 characters
        chunk_overlap=200,  # Overlap by 200 characters
        add_start_index=True
    )
    chunked_docs = splitter.split_documents(docs)
    return chunked_docs  # List of Document objects

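
# Hypothetical indexing helper (not called by the app): shows how the loaders
# and splitter above feed the persistent Chroma store.
def _demo_index_folder(folder: str = "./Documents") -> None:
    chunks = split_text_into_chunks(load_files_from_folder(folder))
    if chunks:
        general_vs.add_documents(chunks)  # persisted under ./general_db
        print(f"Indexed {len(chunks)} chunks")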

# =============================================================================
# Retrieval Tools
# =============================================================================

# Define a simple similarity-search retrieval tool on mcq_vs
class MCQRetrievalTool(BaseModel):
    input: str = Field(..., title="input", description="Search topic.")
    k: int = Field(2, title="Number of Results", description="The number of results to retrieve.")

def mcq_retriever(input: str, k: int = 2) -> Dict[str, str]:
    # Retrieve the top k most similar MCQ question documents from the vector store
    docs_func = mcq_vs.as_retriever(
        search_type="similarity",
        search_kwargs={
            'k': k,
            'filter': {"source_type": "mcq_question"}
        },
    )
    docs_qns = docs_func.invoke(input, k=k)

    # Extract the document IDs from the retrieved documents
    doc_ids = [d.metadata.get("doc_id") for d in docs_qns if "doc_id" in d.metadata]

    # Retrieve full documents based on the doc_ids
    docs = mcq_vs.get(where={'doc_id': {"$in": doc_ids}})

    qns_list = {}
    for i, d in enumerate(docs['metadatas']):
        qns_list[d['source'] + " " + d['source_type']] = docs['documents'][i]

    return qns_list

# Create a StructuredTool from the function
mcq_retriever_tool = StructuredTool.from_function(
    func=mcq_retriever,
    name="MCQ Retrieval Tool",
    description=(
        """
        Use this tool to retrieve an MCQ question set when the Human asks to generate a quiz on a topic.
        DO NOT GIVE THE ANSWERS to the Human before the Human has answered all the questions.

        If the Human gives answers for questions you do not know, SAY you do not have the questions for those answers
        and ASK if the Human wants you to generate a new quiz, then SAVE THE QUIZ with the Summary Tool before ending the conversation.

        Input must be a JSON string with the schema:
        - input (str): The search topic used to retrieve a related MCQ question set.
        - k (int): Number of question sets to retrieve.
        Example usage: input='What is AI?', k=5

        Returns:
        - A dict of MCQ questions:
          Key: metadata of the question, e.g. './Documents/mcq/mcq.csv_Qn31 mcq_question', with suffix ['question', 'answer', 'answer_reason', 'options', 'wrong_options_reason']
          Value: text content
        """
    ),
    args_schema=MCQRetrievalTool,
    response_format="content",
    return_direct=False,  # Return the response to the agent rather than directly to the user
    verbose=False         # To log the tool's progress
)

# -----------------------------------------------------------------------------

# Retrieve more documents with higher diversity using MMR (Maximal Marginal Relevance) from the general vector store.
# Useful if the dataset has many similar documents.
class GenRetrievalTool(BaseModel):
    input: str = Field(..., title="input", description="User query.")
    k: int = Field(2, title="Number of Results", description="The number of results to retrieve.")

def gen_retriever(input: str, k: int = 2) -> List[str]:
    # Use the vector store's retriever to fetch documents
    docs_func = general_vs.as_retriever(
        search_type="mmr",
        search_kwargs={'k': k, 'lambda_mult': 0.25}
    )
    docs = docs_func.invoke(input, k=k)
    return [d.page_content for d in docs]

# Create a StructuredTool from the function
general_retriever_tool = StructuredTool.from_function(
    func=gen_retriever,
    name="Assistant References Retrieval Tool",
    description=(
        """
        Use this tool to retrieve reference information from the Assistant's reference database for Human queries related to a topic,
        and when the Human asks to generate guides to learn or study a topic.

        Input must be a JSON string with the schema:
        - input (str): The user query.
        - k (int): Number of results to retrieve.
        Example usage: input='What is AI?', k=5
        Returns:
        - A list of the retrieved documents' content strings.
        """
    ),
    args_schema=GenRetrievalTool,
    response_format="content",
    return_direct=False,  # Return the content of the documents
    verbose=False         # To log the tool's progress
)

# -----------------------------------------------------------------------------

# Retrieve documents with higher diversity using MMR from the in-memory vector store.
# Queries the in-memory vector store only.
class InMemoryRetrievalTool(BaseModel):
    input: str = Field(..., title="input", description="User query.")
    k: int = Field(2, title="Number of Results", description="The number of results to retrieve.")

def in_memory_retriever(input: str, k: int = 2) -> List[str]:
    # Use the vector store's retriever to fetch documents
    docs_func = in_memory_vs.as_retriever(
        search_type="mmr",
        search_kwargs={'k': k, 'lambda_mult': 0.25}
    )
    docs = docs_func.invoke(input, k=k)
    return [d.page_content for d in docs]

# Create a StructuredTool from the function
in_memory_retriever_tool = StructuredTool.from_function(
    func=in_memory_retriever,
    name="In-Memory Retrieval Tool",
    description=(
        """
        Use this tool when the Human asks the Assistant to retrieve information from documents the Human has uploaded.

        Input must be a JSON string with the schema:
        - input (str): The user query.
        - k (int): Number of results to retrieve.
        """
    ),
    args_schema=InMemoryRetrievalTool,
    response_format="content",
    return_direct=False,  # Whether to return the tool's output directly
    verbose=False         # To log the tool's progress
)

# -----------------------------------------------------------------------------

# Web Extraction Tool
class WebExtractionRequest(BaseModel):
    input: str = Field(..., title="input", description="Search text.")
    url: str = Field(
        ...,
        title="url",
        description="Web URL(s) to extract content from. If multiple URLs, separate them with a comma."
    )
    k: int = Field(5, title="Number of Results", description="The number of results to retrieve.")

# Extract content from a web URL and load it into in_memory_vstore
def extract_web_path_tool(input: str, url: str, k: int = 5) -> List[str]:
    """
    Extract content from the web URLs based on the user's input.
    Args:
    - input: The input text to search for.
    - url: URLs to extract content from.
    - k: Number of results to retrieve.
    Returns:
    - A list of the retrieved documents' content strings.
    """
    if isinstance(url, str):
        url = [url]
    # Extract content from the web
    html_docs = extract_html(url)
    if not html_docs:
        return f"No content extracted from {url}."

    # Split the documents into smaller chunks for better embedding coverage
    chunked_texts = split_text_into_chunks(html_docs)
    in_memory_vs.add_documents(chunked_texts)  # Add the chunked texts to the in-memory vector store

    # Retrieve the relevant chunks back from the in-memory vector store
    docs_func = in_memory_vs.as_retriever(
        search_type="mmr",
        search_kwargs={
            'k': k,
            'lambda_mult': 0.25,
            'filter': {"source": {"$in": url}}
        },
    )
    docs = docs_func.invoke(input, k=k)
    return [d.page_content for d in docs]

# Create a StructuredTool from the function
web_extraction_tool = StructuredTool.from_function(
    func=extract_web_path_tool,
    name="Web Extraction Tool",
    description=(
        "Assistant should use this tool to extract content from web URLs based on the user's input. "
        "The extracted web content is first loaded into the database, and then k results are retrieved and returned."
    ),
    args_schema=WebExtractionRequest,
    return_direct=False,  # Whether to return the tool's output directly
    verbose=False         # To log the tool's progress
)

# -----------------------------------------------------------------------------

# Ensemble Retrieval from the General and In-Memory Vector Stores
class EnsembleRetrievalTool(BaseModel):
    input: str = Field(..., title="input", description="User query.")
    k: int = Field(5, title="Number of Results", description="Number of results.")

def ensemble_retriever(input: str, k: int = 5) -> List[str]:
    # Use each vector store's retriever to fetch documents
    general_retrieval = general_vs.as_retriever(
        search_type="mmr",
        search_kwargs={'k': k, 'lambda_mult': 0.25}
    )
    in_memory_retrieval = in_memory_vs.as_retriever(
        search_type="mmr",
        search_kwargs={'k': k, 'lambda_mult': 0.25}
    )

    ensemble = EnsembleRetriever(
        retrievers=[general_retrieval, in_memory_retrieval],
        weights=[0.5, 0.5]
    )
    docs = ensemble.invoke(input)
    return [d.page_content for d in docs]

# Create a StructuredTool from the function
ensemble_retriever_tool = StructuredTool.from_function(
    func=ensemble_retriever,
    name="Ensemble Retriever Tool",
    description=(
        """
        Use this tool to retrieve information from both the reference database and
        the documents that the Human has uploaded.

        Input must be a JSON string with the schema:
        - input (str): The user query.
        - k (int): Number of results to retrieve.
        """
    ),
    args_schema=EnsembleRetrievalTool,
    response_format="content",
    return_direct=False
)

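
# StructuredTool instances can also be invoked directly with a dict matching
# their args_schema, which is useful for testing outside the graph
# (hypothetical helper, never called by the app):
def _demo_tool_call() -> None:
    print(ensemble_retriever_tool.invoke({"input": "What is AI?", "k": 3}))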

###############################################################################
# LLM Model Setup
###############################################################################

TEMPERATURE = 0.5
model = ChatOpenAI(
    model="unsloth/llama-3-8b-Instruct-bnb-4bit",
    temperature=TEMPERATURE,
    timeout=None,
    max_retries=2,
    api_key="not_required",
    base_url="http://localhost:8000/v1",  # vLLM's OpenAI-compatible API is served under /v1
    verbose=True
)

# model = ChatGroq(
#     model_name="deepseek-r1-distill-llama-70b",
#     temperature=TEMPERATURE,
#     api_key=GROQ_API_KEY,
#     verbose=True
# )

###############################################################################
# 1. Initialize memory + config
###############################################################################
in_memory_store = InMemoryStore(
    index={
        "embed": init_embeddings("huggingface:sentence-transformers/all-MiniLM-L6-v2"),
        "dims": 384,  # Embedding dimensions
    }
)

# A memory saver to checkpoint conversation states
checkpointer = MemorySaver()

# Initialize config with user & thread info
config = {}
config["configurable"] = {
    "user_id": "user_1",
    "thread_id": 0,
}
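
# Quick smoke test of the local vLLM endpoint, independent of the graph
# (hypothetical helper; assumes the server from the Dockerfile CMD is up):
def _demo_llm_ping() -> None:
    print(model.invoke("Reply with the single word: pong").content)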

###############################################################################
# 2. Define MessagesState
###############################################################################
class MessagesState(TypedDict):
    """The state of the agent.

    The key 'messages' uses add_messages as a reducer,
    so each time this state is updated, new messages are appended.
    # See https://langchain-ai.github.io/langgraph/concepts/low_level/#reducers
    """
    messages: Annotated[Sequence[BaseMessage], add_messages]


###############################################################################
# 3. Memory Tools
###############################################################################
def save_memory(summary_text: str, *, config: RunnableConfig, store: BaseStore) -> str:
    """Save the given memory for the current user and return the key."""
    user_id = config.get("configurable", {}).get("user_id")
    thread_id = config.get("configurable", {}).get("thread_id")
    namespace = (user_id, "memories")
    memory_id = thread_id
    store.put(namespace, memory_id, {"memory": summary_text})
    return f"Saved to memory key: {memory_id}"

def update_memory(state: MessagesState, config: RunnableConfig, *, store: BaseStore):
    # Extract the messages list from the state
    messages = state["messages"]
    # Convert LangChain messages to dictionaries before storing
    messages_dict = [{"role": msg.type, "content": msg.content} for msg in messages]

    # Get the user id and thread id from the config
    user_id = config.get("configurable", {}).get("user_id")
    thread_id = config.get("configurable", {}).get("thread_id")
    # Namespace the memory
    namespace = (user_id, "memories")
    # Create a new memory ID
    memory_id = f"{thread_id}"
    store.put(namespace, memory_id, {"memory": messages_dict})
    return f"Saved to memory key: {memory_id}"


# Define a Pydantic schema for the recall_memory tool
# https://langchain-ai.github.io/langgraphjs/reference/classes/checkpoint.InMemoryStore.html
class RecallMemory(BaseModel):
    query_text: str = Field(..., title="Search Text", description="The text to search from memories for similar records.")
    k: int = Field(5, title="Number of Results", description="Number of results to retrieve.")

def recall_memory(query_text: str, k: int = 5) -> str:
    """Retrieve user memories from in_memory_store."""
    user_id = config.get("configurable", {}).get("user_id")
    memories = [
        m.value["memory"] for m in in_memory_store.search((user_id, "memories"), query=query_text, limit=k)
        if "memory" in m.value
    ]
    return f"User memories: {memories}"

# Create a StructuredTool from the function
recall_memory_tool = StructuredTool.from_function(
    func=recall_memory,
    name="Recall Memory Tool",
    description="""
    Retrieve memories relevant to the user's query.
    """,
    args_schema=RecallMemory,
    response_format="content",
    return_direct=False,
    verbose=False
)

###############################################################################
# 4. Summarize Node (using StructuredTool)
###############################################################################
# Define a Pydantic schema for the Summary tool
class SummariseConversation(BaseModel):
    summary_text: str = Field(..., title="text", description="Write a summary of the entire conversation here")

def summarise_node(summary_text: str):
    """
    Final node that summarizes the entire conversation for the current thread,
    saves it in memory, increments the thread_id, and ends the conversation.
    Returns a confirmation string.
    """
    user_id = config["configurable"]["user_id"]
    current_thread_id = config["configurable"]["thread_id"]
    new_thread_id = str(int(current_thread_id) + 1)

    # Prepare configuration for saving memory with the updated thread id
    config_for_saving = {
        "configurable": {
            "user_id": user_id,
            "thread_id": new_thread_id
        }
    }
    key = save_memory(summary_text, config=config_for_saving, store=in_memory_store)
    return f"Summary saved under key: {key}"

# Create a StructuredTool from the function (this wraps summarise_node)
summarise_tool = StructuredTool.from_function(
    func=summarise_node,
    name="Summary Tool",
    description="""
    Summarize the current conversation in no more than
    1000 words. Also retain any unanswered quiz questions along with
    your internal answers so the next conversation thread can continue.
    Do not reveal solutions to the user yet. Use this tool to save
    the current conversation to memory and then end the conversation.
    """,
    args_schema=SummariseConversation,
    response_format="content",
    return_direct=False,
    verbose=True
)

def call_summary(state: MessagesState, config: RunnableConfig):
    # Convert message dicts to message instances if needed.
    system_message = """
    Summarize the current conversation in no more than
    1000 words. Also retain any unanswered quiz questions along with
    your internal answers.
    """
    messages = []
    for m in state["messages"]:
        if isinstance(m, dict):
            # Use role from dict (defaulting to 'assistant' if missing)
            messages.append(AIMessage(content=system_message, role=m.get("role", "assistant")))
        else:
            messages.append(m)

    summaries = llm_with_tools.invoke(messages)

    summary_content = summaries.content

    # Call the Summary Tool manually
    message_with_single_tool_call = AIMessage(
        content="",
        tool_calls=[
            {
                "name": "Summary Tool",
                "args": {"summary_text": summary_content},
                "id": "tool_call_id",
                "type": "tool_call",
            }
        ],
    )

    tool_node.invoke({"messages": [message_with_single_tool_call]})


###############################################################################
# 5. Build the Graph
###############################################################################
graph_builder = StateGraph(MessagesState)

# Use the built-in ToolNode from langgraph that calls any declared tools.
tools = [
    mcq_retriever_tool,
    web_extraction_tool,
    ensemble_retriever_tool,
    general_retriever_tool,
    in_memory_retriever_tool,
    recall_memory_tool,
    summarise_tool,
]

tool_node = ToolNode(tools=tools)
#end_node = ToolNode(tools=[summarise_tool])

# Wrap the model with tools
llm_with_tools = model.bind_tools(tools)
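
# What save_memory/recall_memory read and write: values live in the
# InMemoryStore under the (user_id, "memories") namespace (hypothetical
# helper, never called by the app):
def _demo_memory_roundtrip() -> None:
    in_memory_store.put(("user_1", "memories"), "demo", {"memory": "prefers short quizzes"})
    print(recall_memory("quizzes", k=1))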

###############################################################################
# 6. The agent's main node: call_model
###############################################################################
def call_model(state: MessagesState, config: RunnableConfig):
    """
    The main agent node that calls the LLM with the user + system messages.
    Since our vLLM chat wrapper expects a list of BaseMessage objects,
    we convert any dict messages to HumanMessage objects.
    If the LLM requests a tool call, we'll route to the 'tools' node next
    (depending on the condition).
    """
    # Convert message dicts to HumanMessage instances if needed.
    messages = []
    for m in state["messages"]:
        if isinstance(m, dict):
            # Use role from dict (defaulting to 'user' if missing)
            messages.append(HumanMessage(content=m.get("content", ""), role=m.get("role", "user")))
        else:
            messages.append(m)

    # Invoke the LLM (with tools) using the converted messages.
    response = llm_with_tools.invoke(messages)

    return {"messages": [response]}


###############################################################################
# 7. Add Nodes & Edges, Then Compile
###############################################################################
graph_builder.add_node("agent", call_model)
graph_builder.add_node("tools", tool_node)
#graph_builder.add_node("summary", call_summary)

# Entry point
graph_builder.set_entry_point("agent")

# def custom_tools_condition(llm_output: dict) -> str:
#     """Return which node to go to next based on the LLM output."""
#
#     # The LLM's JSON might have a field like {"name": "Recall Memory Tool", "arguments": {...}}.
#     tool_name = llm_output.get("name", None)
#
#     # If the LLM calls "Summary Tool", jump directly to the 'summary' node
#     if tool_name == "Summary Tool":
#         return "summary"
#
#     # If the LLM calls any other recognized tool, go to 'tools'
#     valid_tool_names = [t.name for t in tools]  # all tools in the main tool_node
#     if tool_name in valid_tool_names:
#         return "tools"
#
#     # If there's no recognized tool name, assume we're done => go to summary
#     return "__end__"

# graph_builder.add_conditional_edges(
#     "agent",
#     custom_tools_condition,
#     {
#         "tools": "tools",
#         "summary": "summary",
#         "__end__": "summary",
#     }
# )

# If the LLM requests a tool, go to "tools"; otherwise end
graph_builder.add_conditional_edges("agent", tools_condition)
#graph_builder.add_conditional_edges("agent", tools_condition, {"tools": "tools", "__end__": "summary"})
#graph_builder.add_conditional_edges("agent", lambda llm_output: "tools" if llm_output.get("name", None) in [t.name for t in tools] else "summary", {"tools": "tools", "__end__": "summary"})

# After a tool runs, return to the agent for the final answer or more tool calls
graph_builder.add_edge("tools", "agent")
#graph_builder.add_edge("agent", "summary")
#graph_builder.set_finish_point("summary")

# Compile the graph with checkpointing and persistent store
graph = graph_builder.compile(checkpointer=checkpointer, store=in_memory_store)

#from langgraph.prebuilt import create_react_agent
#graph = create_react_agent(llm_with_tools, tools=tool_node, checkpointer=checkpointer, store=in_memory_store)

#from IPython.display import Image, display
#display(Image(graph.get_graph().draw_mermaid_png()))

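
# One conversational turn through the compiled graph, without Gradio
# (hypothetical helper; requires the vLLM server to be running):
def _demo_graph_turn(text: str = "What is AI?") -> None:
    result = graph.invoke({"messages": [HumanMessage(content=text)]}, config)
    print(result["messages"][-1].content)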

########################################
# Gradio Chatbot Application
########################################

from gradio import ChatMessage

system_prompt = "You are a helpful Assistant. Always use the tools {tools}."

########################################
# Upload_documents
########################################

def upload_documents(file_list: List):
    """
    Load documents into the in-memory vector store.
    """
    _documents = []

    for doc_path in file_list:
        _documents.extend(load_file(doc_path))

    # Split the documents into smaller chunks for better embedding coverage
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=300,   # Split into chunks of 300 characters
        chunk_overlap=50, # Overlap by 50 characters
        add_start_index=True
    )
    chunked_texts = splitter.split_documents(_documents)
    in_memory_vs.add_documents(chunked_texts)
    return f"Uploaded {len(file_list)} documents into in-memory vector store."


########################################
# Submit_queries (ChatInterface Function)
########################################
def submit_queries(message, _messages):
    """
    - message: dict with {"text": ..., "files": [...]}
    - _messages: chat history (rebuilt below as a list of ChatMessage)
    """
    _messages = []
    user_text = message.get("text", "")
    user_files = message.get("files", [])

    # Process user-uploaded files
    if user_files:
        for file_obj in user_files:
            _messages.append(ChatMessage(role="user", content=f"Uploaded file: {file_obj}"))
        upload_response = upload_documents(user_files)
        _messages.append(ChatMessage(role="assistant", content=upload_response))
        yield _messages
        return  # Exit early since we don't need to process text or call the LLM

    # Process user text if present
    if user_text:
        events = graph.stream(
            {
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_text},
                ]
            },
            config,
            stream_mode="values"
        )

        for event in events:
            response = event["messages"][-1]
            if isinstance(response, AIMessage):
                if "tool_calls" in response.additional_kwargs:
                    _messages.append(
                        ChatMessage(role="assistant",
                                    content=str(response.tool_calls[0]["args"]),
                                    metadata={"title": str(response.tool_calls[0]["name"]),
                                              "id": config["configurable"]["thread_id"]
                                              }
                                    ))
                    yield _messages
                else:
                    _messages.append(ChatMessage(role="assistant",
                                                 content=response.content,
                                                 metadata={"id": config["configurable"]["thread_id"]
                                                           }
                                                 ))
                    yield _messages
    return


########################################
# 3) Save Chat History
########################################
CHAT_HISTORY_FILE = "chat_history.json"

def save_chat_history(history):
    """
    Saves the chat history into a JSON file.
    """
    session_history = [
        {
            "role": msg.role,
            "content": msg.content
        }
        for msg in history
    ]
    with open(CHAT_HISTORY_FILE, "w", encoding="utf-8") as f:
        json.dump(session_history, f, ensure_ascii=False, indent=4)


########################################
# 6) Main Gradio Interface
########################################
with gr.Blocks(theme="ocean") as AI_Tutor:
    gr.Markdown("# AI Tutor Chatbot (Gradio App)")

    # Primary Chat Interface
    chat_interface = gr.ChatInterface(
        fn=submit_queries,
        type="messages",
        chatbot=gr.Chatbot(
            label="Chat Window",
            height=500
        ),
        textbox=gr.MultimodalTextbox(
            file_count="multiple",
            file_types=None,
            sources="upload",
            label="Type your query here:",
            placeholder="Enter your question...",
        ),
        title="AI Tutor Chatbot",
        description="Ask me anything about Artificial Intelligence!",
        multimodal=True,
        save_history=True,
    )


if __name__ == "__main__":
    AI_Tutor.launch()
requirements.txt
ADDED
@@ -0,0 +1,30 @@
gradio
requests
#langchain-groq
langchain-openai
torch
vllm
accelerate
bitsandbytes

# LangChain and related dependencies
langchain
langchain-core
langchain-text-splitters
langchain-community
langgraph
chromadb
langchain-chroma
#langsmith

# Document processing
docling
langchain-docling

# Local LLM and other utilities
#llama-cpp-python
langchain_huggingface
transformers
sentence_transformers
huggingface_hub