Clean code and add readme
- LISA_mini.ipynb +23 -25
- README.md +31 -0
- app.py +25 -20
- documents.py +51 -130
- embeddings.py +26 -15
- llms.py +16 -34
- preprocess_documents.py +9 -4
- ragchain.py +18 -5
- requirements.txt +1 -1
- rerank.py +3 -2
- retrievers.py +12 -7
- vectorestores.py +8 -3
LISA_mini.ipynb
CHANGED
@@ -1,8 +1,16 @@
 {
 "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "9267529d",
+ "metadata": {},
+ "source": [
+ "A mini version of LISA in a Jupyter notebook for easier testing and playing around."
+ ]
+ },
 {
 "cell_type": "code",
- "execution_count":
+ "execution_count": 2,
 "id": "adcfdba2",
 "metadata": {},
 "outputs": [],
@@ -18,14 +26,13 @@
 "from langchain.chains import ConversationalRetrievalChain\n",
 "from langchain.llms import HuggingFaceTextGenInference\n",
 "from langchain.chains.conversation.memory import (\n",
- " ConversationBufferMemory,\n",
 " ConversationBufferWindowMemory,\n",
 ")"
 ]
 },
 {
 "cell_type": "code",
- "execution_count":
+ "execution_count": 3,
 "id": "2d85c6d9",
 "metadata": {},
 "outputs": [],
@@ -68,7 +75,7 @@
 },
 {
 "cell_type": "code",
- "execution_count":
+ "execution_count": 5,
 "id": "2d5bacd5",
 "metadata": {},
 "outputs": [],
@@ -107,7 +114,7 @@
 },
 {
 "cell_type": "code",
- "execution_count":
+ "execution_count": 6,
 "id": "8cd31248",
 "metadata": {},
 "outputs": [],
@@ -140,21 +147,12 @@
 },
 {
 "cell_type": "code",
- "execution_count":
- "id": "73d560de",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Create retrievers"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
+ "execution_count": 8,
 "id": "e5796990",
 "metadata": {},
 "outputs": [],
 "source": [
+ "# Create retrievers\n",
 "# Some advanced RAG, with parent document retriever, hybrid-search and rerank\n",
 "\n",
 "# 1. ParentDocumentRetriever. Note: this will take a long time (~several minutes)\n",
@@ -178,7 +176,7 @@
 },
 {
 "cell_type": "code",
- "execution_count":
+ "execution_count": 9,
 "id": "bc299740",
 "metadata": {},
 "outputs": [],
@@ -191,7 +189,7 @@
 },
 {
 "cell_type": "code",
- "execution_count":
+ "execution_count": 10,
 "id": "2eb8bc8f",
 "metadata": {},
 "outputs": [],
@@ -214,7 +212,7 @@
 "\n",
 "from sentence_transformers import CrossEncoder\n",
 "\n",
- "model_name = \"BAAI/bge-reranker-large\"
+ "model_name = \"BAAI/bge-reranker-large\"\n",
 "\n",
 "class BgeRerank(BaseDocumentCompressor):\n",
 " model_name:str = model_name\n",
@@ -273,7 +271,7 @@
 },
 {
 "cell_type": "code",
- "execution_count":
+ "execution_count": 11,
 "id": "af780912",
 "metadata": {},
 "outputs": [],
@@ -283,7 +281,7 @@
 "# Ensemble all above\n",
 "ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, parent_doc_retriver], weights=[0.5, 0.5])\n",
 "\n",
- "#
+ "# Rerank\n",
 "compressor = BgeRerank()\n",
 "rerank_retriever = ContextualCompressionRetriever(\n",
 " base_compressor=compressor, base_retriever=ensemble_retriever\n",
@@ -292,7 +290,7 @@
 },
 {
 "cell_type": "code",
- "execution_count":
+ "execution_count": 12,
 "id": "beb9ab21",
 "metadata": {},
 "outputs": [],
@@ -307,7 +305,7 @@
 " self.return_messages = return_messages\n",
 "\n",
 " def create(self, retriver, llm):\n",
- " memory = ConversationBufferWindowMemory(
+ " memory = ConversationBufferWindowMemory(\n",
 " memory_key=self.memory_key,\n",
 " return_messages=self.return_messages,\n",
 " output_key=self.output_key,\n",
@@ -634,7 +632,7 @@
 ],
 "metadata": {
 "kernelspec": {
- "display_name": "
+ "display_name": "lisa",
 "language": "python",
 "name": "python3"
 },
@@ -648,7 +646,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
- "version": "3.10
+ "version": "3.11.10"
 }
 },
 "nbformat": 4,
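Note: the retriever cells above wire up hybrid search (sparse BM25 plus the dense parent-document retriever) followed by cross-encoder reranking. A minimal sketch of that wiring, assuming `document_chunks`, `parent_doc_retriver` and the `BgeRerank` compressor are already defined as in the notebook:

```python
# Sketch only: `document_chunks`, `parent_doc_retriver` and `BgeRerank` come from the notebook cells above.
from langchain.retrievers import (
    BM25Retriever,
    ContextualCompressionRetriever,
    EnsembleRetriever,
)

# Sparse (keyword) retriever over the same chunks used by the dense retriever
bm25_retriever = BM25Retriever.from_documents(document_chunks, k=5)

# Hybrid search: merge sparse and dense candidates with equal weights
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, parent_doc_retriver], weights=[0.5, 0.5]
)

# Rerank the merged candidates with the BGE cross-encoder
compressor = BgeRerank()
rerank_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=ensemble_retriever
)

docs = rerank_retriever.get_relevant_documents("What is a solid-state electrolyte?")
```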
README.md
CHANGED
@@ -11,3 +11,34 @@ startup_duration_timeout: 2h
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+LISA (Lithium Ion Solid-state Assistant) is a question-and-answer (Q&A) research assistant designed for efficient knowledge management with a primary focus on battery science, yet versatile enough to support broader scientific domains. Built on a Retrieval-Augmented Generation (RAG) architecture, LISA uses advanced Large Language Models (LLMs) to provide reliable, detailed answers to research questions.
+
+DEMO: https://huggingface.co/spaces/Kadi-IAM/LISA
+
+### Installation
+1. Clone the Repository:
+```bash
+git clone "link of this repo"
+cd LISA
+```
+
+2. Install Dependencies:
+```bash
+pip install -r requirements.txt
+```
+
+3. Set Up the Knowledge Base
+Populate the knowledge base with relevant documents or research papers. Ensure that documents are in a format (pdf or xml) compatible with the RAG pipeline. By default, documents should be located at `data/documents`. After running the following command, some cache files are saved into `data/db`. ATTENTION: pickle is used to save these caches, be careful with potential security risks.
+```bash
+python preprocess_documents.py
+```
+
+4. Running LISA
+Once setup is complete, run the following command to launch LISA:
+```bash
+python app.py
+```
+
+### About
+For more information on our work in intelligent research data management systems, please visit [KadiAI](https://kadi.iam.kit.edu/kadi-ai).
app.py
CHANGED
@@ -1,12 +1,15 @@
+"""
+Main app for LISA RAG chatbot based on langchain.
+"""
+
 import os
 import time
 import re
-
-from dotenv import load_dotenv
+import gradio as gr
 import pickle
 
-
-
+from pathlib import Path
+from dotenv import load_dotenv
 
 from huggingface_hub import login
 from langchain.vectorstores import FAISS
@@ -15,24 +18,21 @@ from llms import get_groq_chat
 from documents import load_pdf_as_docs, load_xml_as_docs
 from vectorestores import get_faiss_vectorestore
 
-
 # For debug
 # from langchain.globals import set_debug
 # set_debug(True)
 
-
 # Load and set env variables
 load_dotenv()
 
+# Set API keys
 HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]
 login(HUGGINGFACEHUB_API_TOKEN)
 TAVILY_API_KEY = os.environ["TAVILY_API_KEY"] # Search engine
 
-# Other settings
-os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
 
 # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
-
+# Set database path
 database_root = "./data/db"
 document_path = "./data/documents"
 
@@ -80,12 +80,13 @@ from langchain.retrievers import BM25Retriever, EnsembleRetriever
 
 bm25_retriever = BM25Retriever.from_documents(
 document_chunks, k=5
-) # 1/2 of dense retriever, experimental value
+) # k = 1/2 of dense retriever, experimental value
 
-# Ensemble all above
+# Ensemble all above retrievers
 ensemble_retriever = EnsembleRetriever(
 retrievers=[bm25_retriever, parent_doc_retriver], weights=[0.5, 0.5]
 )
+
 # Reranker
 from rerank import BgeRerank
 
@@ -98,7 +99,7 @@ print("rerank loaded")
 llm = get_groq_chat(model_name="llama-3.1-70b-versatile")
 
 
-#
+# Create conversation qa chain (Note: conversation is not supported yet)
 from ragchain import RAGChain
 
 rag_chain = RAGChain()
@@ -108,13 +109,11 @@ lisa_qa_conversation = rag_chain.create(rerank_retriever, llm, add_citation=True)
 from langchain_community.retrievers import TavilySearchAPIRetriever
 from langchain.chains import RetrievalQAWithSourcesChain
 
-web_search_retriever = TavilySearchAPIRetriever(
- k=4
-) # , include_raw_content=True)#, include_raw_content=True)
+web_search_retriever = TavilySearchAPIRetriever(k=4) # , include_raw_content=True)
 web_qa_chain = RetrievalQAWithSourcesChain.from_chain_type(
 llm, retriever=web_search_retriever, return_source_documents=True
 )
-print("
+print("chains loaded")
 
 
 # Gradio utils
@@ -136,7 +135,7 @@ def add_text(history, text):
 
 
 def postprocess_remove_cite_misinfo(text, allowed_max_cite_num=6):
- """
+ """Heuristic removal of misinfo. of citations."""
 
 # Remove trailing references at end of text
 if "References:\n[" in text:
@@ -480,7 +479,7 @@ def main():
 # flag_web_search = gr.Checkbox(label="Search web", info="Search information from Internet")
 gr.Markdown("More in DEV...")
 
-#
+# Action functions
 user_txt.submit(check_input_text, user_txt, None).success(
 add_text, [chatbot, user_txt], [chatbot, user_txt]
 ).then(bot_lisa, [chatbot, flag_web_search], [chatbot, doc_citation])
@@ -575,6 +574,7 @@ def main():
 with gr.Tab("Setting"):
 gr.Markdown("More in DEV...")
 
+# Actions
 load_document.click(
 document_changes,
 inputs=[uploaded_doc], # , repo_id],
@@ -606,8 +606,9 @@ def main():
 )
 
 ##########################
-# Preview
+# Preview tabs
 with gr.Tab("Preview feature 🔬"):
+# VLM model
 with gr.Tab("Vision LM 🖼"):
 vision_tmp_link = (
 "https://kadi-iam-lisa-vlm.hf.space/" # vision model link
@@ -620,6 +621,7 @@ def main():
 )
 # gr.Markdown("placeholder")
 
+# OAuth2 linkage to Kadi-demo
 with gr.Tab("KadiChat 💬"):
 kadichat_tmp_link = (
 "https://kadi-iam-kadichat.hf.space/" # vision model link
@@ -631,9 +633,12 @@ def main():
 )
 )
 
+# Knowledge graph-enhanced RAG
 with gr.Tab("RAG enhanced with Knowledge Graph (dev) 🔎"):
 kg_tmp_link = "https://kadi-iam-kadikgraph.static.hf.space/index.html"
-gr.Markdown(
+gr.Markdown(
+"[If rendering fails, look at the graph here](https://kadi-iam-kadikgraph.static.hf.space)"
+)
 with gr.Blocks(css="""footer {visibility: hidden};""") as preview_tab:
 gr.HTML(
 """<iframe
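Note: the Tavily-backed `web_qa_chain` above is the web-search fallback. A hedged usage sketch (the query string and printed keys follow the usual `RetrievalQAWithSourcesChain` conventions, not something pinned down in this diff):

```python
# Sketch: query the web-search chain built in app.py above.
result = web_qa_chain({"question": "What is a lithium-ion solid-state battery?"})

print(result["answer"])   # generated answer
print(result["sources"])  # source summary string
for doc in result["source_documents"]:  # present because return_source_documents=True
    print(doc.metadata.get("source"))
```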
documents.py
CHANGED
@@ -1,25 +1,30 @@
+"""
+Parse documents, currently pdf and xml are supported.
+"""
+
 import os
-import shutil
 
 from langchain.document_loaders import (
 PyMuPDFLoader,
 )
 from langchain.docstore.document import Document
-
-from langchain.vectorstores import Chroma
-
 from langchain.text_splitter import (
-RecursiveCharacterTextSplitter,
+# RecursiveCharacterTextSplitter,
 SpacyTextSplitter,
 )
 
+
 def load_pdf_as_docs(pdf_path, loader_module=None, load_kwargs=None):
 """Load and parse pdf file(s)."""
-
-if pdf_path.endswith(
+
+if pdf_path.endswith(".pdf"): # single file
 pdf_docs = [pdf_path]
 else: # a directory
-pdf_docs = [
+pdf_docs = [
+os.path.join(pdf_path, f)
+for f in os.listdir(pdf_path)
+if f.endswith(".pdf")
+]
 
 if load_kwargs is None:
 load_kwargs = {}
@@ -31,180 +36,96 @@ def load_pdf_as_docs(pdf_path, loader_module=None, load_kwargs=None):
 loader = loader_module(pdf, **load_kwargs)
 doc = loader.load()
 docs.extend(doc)
-
+
 return docs
 
+
 def load_xml_as_docs(xml_path, loader_module=None, load_kwargs=None):
 """Load and parse xml file(s)."""
-
+
 from bs4 import BeautifulSoup
 from unstructured.cleaners.core import group_broken_paragraphs
-
-if xml_path.endswith(
+
+if xml_path.endswith(".xml"): # single file
 xml_docs = [xml_path]
 else: # a directory
-xml_docs = [
-
+xml_docs = [
+os.path.join(xml_path, f)
+for f in os.listdir(xml_path)
+if f.endswith(".xml")
+]
+
 if load_kwargs is None:
 load_kwargs = {}
 
 docs = []
 for xml_file in xml_docs:
-# print("now reading file...")
 with open(xml_file) as fp:
-soup = BeautifulSoup(
+soup = BeautifulSoup(
+fp, features="xml"
+) # txt is simply the a string with your XML file
 pageText = soup.findAll(string=True)
-parsed_text =
-#
+parsed_text = "\n".join(pageText) # or " ".join, seems similar
+# Clean text
 parsed_text_grouped = group_broken_paragraphs(parsed_text)
-
+
 # get metadata
 try:
 from lxml import etree as ET
+
 tree = ET.parse(xml_file)
 
 # Define namespace
 ns = {"tei": "http://www.tei-c.org/ns/1.0"}
 # Read Author personal names as an example
-pers_name_elements = tree.xpath(
+pers_name_elements = tree.xpath(
+"tei:teiHeader/tei:fileDesc/tei:titleStmt/tei:author/tei:persName",
+namespaces=ns,
+)
 first_per = pers_name_elements[0].text
 author_info = first_per + " et al"
 
-title_elements = tree.xpath(
+title_elements = tree.xpath(
+"tei:teiHeader/tei:fileDesc/tei:titleStmt/tei:title", namespaces=ns
+)
 title = title_elements[0].text
 
 # Combine source info
 source_info = "_".join([author_info, title])
 except:
 source_info = "unknown"
-
-# maybe even better TODO: discuss with
+
+# maybe even better parsing method. TODO: discuss with TUD
 # first_author = soup.find("author")
 # publication_year = soup.find("date", attrs={'type': 'published'})
 # title = soup.find("title")
 # source_info = [first_author, publication_year, title]
 # source_info_str = "_".join([info.text.strip() if info is not None else "unknown" for info in source_info])
-
-doc =
+
+doc = [
+Document(
+page_content=parsed_text_grouped, metadata={"source": source_info}
+)
+]
 
 docs.extend(doc)
-
+
 return docs
 
 
 def get_doc_chunks(docs, splitter=None):
 """Split docs into chunks."""
-
+
 if splitter is None:
-# splitter = RecursiveCharacterTextSplitter(
+# splitter = RecursiveCharacterTextSplitter( # original default
 # # separators=["\n\n", "\n"], chunk_size=1024, chunk_overlap=256
 # separators=["\n\n", "\n"], chunk_size=256, chunk_overlap=128
 # )
+# Spacy seems better
 splitter = SpacyTextSplitter.from_tiktoken_encoder(
 chunk_size=512,
 chunk_overlap=128,
 )
 chunks = splitter.split_documents(docs)
-
-return chunks
-
 
-
-# embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
-# vectorstore = FAISS.from_texts(texts=text_chunks, embedding=embeddings)
-if overwrite:
-shutil.rmtree(persist_directory) # Empty and reset db
-db = Chroma.from_documents(documents=document_chunks, embedding=embeddings, persist_directory=persist_directory)
-# db.delete_collection()
-db.persist()
-# db = None
-# db = Chroma(persist_directory="db", embedding_function = embeddings, client_settings=CHROMA_SETTINGS)
-# vectorstore = FAISS.from_documents(documents=document_chunks, embedding=embeddings)
-return db
-
-
-class VectorstoreManager:
-
-def __init__(self):
-self.vectorstore_class = Chroma
-
-def create_db(self, embeddings):
-db = self.vectorstore_class(embedding_function=embeddings)
-
-self.db = db
-return db
-
-
-def load_db(self, persist_directory, embeddings):
-"""Load local vectorestore."""
-
-db = self.vectorstore_class(persist_directory=persist_directory, embedding_function=embeddings)
-self.db = db
-
-return db
-
-def create_db_from_documents(self, document_chunks, embeddings, persist_directory="db", overwrite=False):
-"""Create db from documents."""
-
-if overwrite:
-shutil.rmtree(persist_directory) # Empty and reset db
-db = self.vectorstore_class.from_documents(documents=document_chunks, embedding=embeddings, persist_directory=persist_directory)
-self.db = db
-
-return db
-
-def persist_db(self, persist_directory="db"):
-"""Persist db."""
-
-assert self.db
-self.db.persist() # Chroma
-
-class RetrieverManager:
-# some other retrievers Using Advanced Retrievers in LangChain https://www.comet.com/site/blog/using-advanced-retrievers-in-langchain/
-
-def __init__(self, vectorstore, k=10):
-
-self.vectorstore = vectorstore
-self.retriever = vectorstore.as_retriever(search_kwargs={"k": k}) #search_kwargs={"k": 8}),
-
-def get_rerank_retriver(self, base_retriever=None):
-
-if base_retriever is None:
-base_retriever = self.retriever
-# with rerank
-from rerank import BgeRerank
-from langchain.retrievers import ContextualCompressionRetriever
-
-compressor = BgeRerank()
-compression_retriever = ContextualCompressionRetriever(
-base_compressor=compressor, base_retriever=base_retriever
-)
-
-return compression_retriever
-
-def get_parent_doc_retriver(self, documents, store_file="./store_location"):
-# TODO need better design
-# Ref: explain how it works: https://clusteredbytes.pages.dev/posts/2023/langchain-parent-document-retriever/
-from langchain.storage.file_system import LocalFileStore
-from langchain.storage import InMemoryStore
-from langchain.storage._lc_store import create_kv_docstore
-from langchain.retrievers import ParentDocumentRetriever
-# Ref: https://stackoverflow.com/questions/77385587/persist-parentdocumentretriever-of-langchain
-# fs = LocalFileStore("./store_location")
-# store = create_kv_docstore(fs)
-docstore = InMemoryStore()
-
-# TODO: how to better set this?
-parent_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=1024, chunk_overlap=256)
-child_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=256, chunk_overlap=128)
-
-retriever = ParentDocumentRetriever(
-vectorstore=self.vectorstore,
-docstore=docstore,
-child_splitter=child_splitter,
-parent_splitter=parent_splitter,
-search_kwargs={"k":10} # Better settings?
-)
-retriever.add_documents(documents)#, ids=None)
-
-return retriever
+
+return chunks
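Note: a short usage sketch for the two loaders and the chunker (the `data/documents` path matches the default used in app.py and preprocess_documents.py):

```python
# Sketch: parse everything under data/documents and split it into chunks.
from documents import get_doc_chunks, load_pdf_as_docs, load_xml_as_docs

pdf_docs = load_pdf_as_docs("./data/documents")  # a single .pdf file or a directory
xml_docs = load_xml_as_docs("./data/documents")  # a single .xml file or a directory

document_chunks = get_doc_chunks(pdf_docs + xml_docs)  # SpacyTextSplitter by default
print(len(document_chunks), "chunks")
```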
embeddings.py
CHANGED
@@ -1,39 +1,50 @@
+"""
+Load embedding models from huggingface.
+"""
 
 import torch
 from langchain.embeddings import HuggingFaceEmbeddings
 
 
 def get_hf_embeddings(model_name=None):
-"""Get huggingface embedding."""
-
+"""Get huggingface embedding by name."""
+
 if model_name is None:
-# Some candiates
+# Some candiates
 # "BAAI/bge-m3" (good, though large and slow)
-# "BAAI/bge-base-en-v1.5" ->
-# "sentence-transformers/all-mpnet-base-v2"
-#
-
-
+# "BAAI/bge-base-en-v1.5" -> also good
+# "sentence-transformers/all-mpnet-base-v2"
+# "maidalun1020/bce-embedding-base_v1"
+# "intfloat/multilingual-e5-large"
+# Ref: https://huggingface.co/spaces/mteb/leaderboard
+# https://huggingface.co/maidalun1020/bce-embedding-base_v1
+model_name = "BAAI/bge-large-en-v1.5"
+
 embeddings = HuggingFaceEmbeddings(model_name=model_name)
-
+
 return embeddings
 
 
-def get_jinaai_embeddings(
+def get_jinaai_embeddings(
+model_name="jinaai/jina-embeddings-v2-base-en", device="auto"
+):
 """Get jinaai embedding."""
-
+
 # device: cpu or cuda
 if device == "auto":
 device = "cuda" if torch.cuda.is_available() else "cpu"
 # For jinaai. Ref: https://github.com/langchain-ai/langchain/issues/6080
 from transformers import AutoModel
-
+
+model = AutoModel.from_pretrained(
+model_name, trust_remote_code=True
+) # -> will yield error, need bug fixing
 
 model_name = model_name
-model_kwargs = {
+model_kwargs = {"device": device, "trust_remote_code": True}
 embeddings = HuggingFaceEmbeddings(
 model_name=model_name,
 model_kwargs=model_kwargs,
 )
-
-return embeddings
+
+return embeddings
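Note: both helpers return a LangChain `HuggingFaceEmbeddings` instance; a quick usage sketch (the default model names are the ones visible in this diff):

```python
# Sketch: load an embedding model and embed a query.
from embeddings import get_hf_embeddings, get_jinaai_embeddings

embeddings = get_hf_embeddings()  # defaults to "BAAI/bge-large-en-v1.5"
# embeddings = get_jinaai_embeddings(device="auto")  # "jinaai/jina-embeddings-v2-base-en"

vector = embeddings.embed_query("solid-state electrolyte")
print(len(vector))  # embedding dimension
```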
llms.py
CHANGED
@@ -1,22 +1,22 @@
-
-from
+"""
+Load LLMs from huggingface, Groq, etc.
+"""
+
 from transformers import (
-AutoModelForCausalLM,
+# AutoModelForCausalLM,
 AutoTokenizer,
 pipeline,
 )
-from
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.llms import HuggingFacePipeline
 from langchain_groq import ChatGroq
-
-
-from langchain.chat_models import ChatOpenAI
 from langchain.llms import HuggingFaceTextGenInference
 
+# from langchain.chat_models import ChatOpenAI # oai model
+
 
 def get_llm_hf_online(inference_api_url=""):
 """Get LLM using huggingface inference."""
-
+
 if not inference_api_url: # default api url
 inference_api_url = (
 "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
@@ -35,20 +35,16 @@ def get_llm_hf_online(inference_api_url=""):
 
 
 def get_llm_hf_local(model_path):
-"""Get local LLM."""
-
-model = LlamaForCausalLM.from_pretrained(
-model_path, device_map="auto"
-)
+"""Get local LLM from huggingface."""
+
+model = LlamaForCausalLM.from_pretrained(model_path, device_map="auto")
 tokenizer = AutoTokenizer.from_pretrained(model_path)
 
-# print('making a pipeline...')
-# max_length has typically been deprecated for max_new_tokens
 pipe = pipeline(
 "text-generation",
 model=model,
 tokenizer=tokenizer,
-max_new_tokens=
+max_new_tokens=2048, # better setting?
 model_kwargs={"temperature": 0.1}, # better setting?
 )
 llm = HuggingFacePipeline(pipeline=pipe)
@@ -56,22 +52,8 @@ def get_llm_hf_local(model_path):
 return llm
 
 
-
-
-"""Get openai-like LLM."""
-
-llm = ChatOpenAI(
-model=model_name,
-openai_api_key="EMPTY",
-openai_api_base=inference_server_url,
-max_tokens=1024, # better setting?
-temperature=0,
-)
-
-return llm
-
-
-def get_groq_chat(model_name="llama-3.1-70b-versatile"):
+def get_groq_chat(model_name="llama-3.1-70b-versatile"):
+"""Get LLM from Groq."""
 
 llm = ChatGroq(temperature=0, model_name=model_name)
-return llm
+return llm
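Note: a usage sketch for the LLM helpers; `get_groq_chat` assumes the Groq client finds a `GROQ_API_KEY` in the environment (the key is not set anywhere in this diff):

```python
# Sketch: pick one of the LLM backends defined in llms.py.
from llms import get_groq_chat, get_llm_hf_online

llm = get_groq_chat(model_name="llama-3.1-70b-versatile")  # needs GROQ_API_KEY in the env (assumed)
# llm = get_llm_hf_online()  # HF Inference API, zephyr-7b-beta by default

print(llm.invoke("One sentence on solid-state batteries."))
```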
preprocess_documents.py
CHANGED
@@ -1,15 +1,17 @@
 """
-Load and parse files (pdf) in the data/documents and save cached pkl files.
+Load and parse files (pdf) in the "data/documents" and save cached pkl files.
+It will load and parse files and save 4 caches:
+1. "docs.pkl" for loaded text documents
+2. "docs_chunks.pkl" for chunked text
+3. "docstore.pkl" for small-to-big retriever
+4. faiss_index for FAISS vectore store
 """
 
 import os
 import pickle
 
 from dotenv import load_dotenv
-
-
 from huggingface_hub import login
-
 from documents import load_pdf_as_docs, get_doc_chunks
 from embeddings import get_jinaai_embeddings
 
@@ -23,11 +25,14 @@ login(HUGGINGFACEHUB_API_TOKEN)
 
 
 def save_to_pickle(obj, filename):
+"""Save obj to disk using pickle."""
+
 with open(filename, "wb") as file:
 pickle.dump(obj, file, pickle.HIGHEST_PROTOCOL)
 
 
 # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
+# Set database path, should be same as defined in "app.py"
 database_root = "./data/db"
 document_path = "./data/documents"
 
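Note: the app-side counterpart of `save_to_pickle` is a plain pickle load; a sketch of reading the caches back (file names as listed in the module docstring; the same pickle security caveat from the README applies):

```python
# Sketch: reload the caches written by preprocess_documents.py.
import os
import pickle

database_root = "./data/db"


def load_from_pickle(filename):
    """Load a pickled object from disk (only for files you created yourself)."""
    with open(filename, "rb") as file:
        return pickle.load(file)


docs = load_from_pickle(os.path.join(database_root, "docs.pkl"))
document_chunks = load_from_pickle(os.path.join(database_root, "docs_chunks.pkl"))
docstore = load_from_pickle(os.path.join(database_root, "docstore.pkl"))
```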
ragchain.py
CHANGED
@@ -1,3 +1,7 @@
+"""
+Main RAG chain based on langchain.
+"""
+
 from langchain.chains import LLMChain
 
 from langchain.prompts import (
@@ -11,17 +15,17 @@ from langchain.chains import ConversationalRetrievalChain
 from langchain.chains.conversation.memory import (
 ConversationBufferWindowMemory,
 )
-
-
 from langchain.chains import StuffDocumentsChain
 
 
 def get_cite_combine_docs_chain(llm):
+"""Get doc chain which adds metadata to text chunks."""
 
 # Ref: https://github.com/langchain-ai/langchain/issues/7239
 # Function to format each document with an index, source, and content.
 def format_document(doc, index, prompt):
 """Format a document into a string based on a prompt template."""
+
 # Create a dictionary with document content and metadata.
 base_info = {
 "page_content": doc.page_content,
@@ -40,7 +44,11 @@ def get_cite_combine_docs_chain(llm):
 
 # Custom chain class to handle document combination with source indices.
 class StuffDocumentsWithIndexChain(StuffDocumentsChain):
+"""Custom chain class to handle document combination with source indices."""
+
 def _get_inputs(self, docs, **kwargs):
+"""Overwrite _get_inputs to add metadata for text chunks."""
+
 # Format each document and combine them.
 doc_strings = [
 format_document(doc, i, self.document_prompt)
@@ -58,6 +66,7 @@ def get_cite_combine_docs_chain(llm):
 )
 return inputs
 
+# Main prompt for RAG chain with citation
 # Ref: https://huggingface.co/spaces/Ekimetrics/climate-question-answering/blob/main/climateqa/engine/prompts.py
 # Define a chat prompt with instructions for citing documents.
 combine_doc_prompt = PromptTemplate(
@@ -103,6 +112,8 @@ def get_cite_combine_docs_chain(llm):
 
 
 class RAGChain:
+"""Main RAG chain."""
+
 def __init__(
 self, memory_key="chat_history", output_key="answer", return_messages=True
 ):
@@ -111,14 +122,17 @@
 self.return_messages = return_messages
 
 def create(self, retriever, llm, add_citation=False):
-
+"""Create a rag chain instance."""
+
+# Memory is kept for later support of conversational chat
+memory = ConversationBufferWindowMemory( # Or ConversationBufferMemory
 k=2,
 memory_key=self.memory_key,
 return_messages=self.return_messages,
 output_key=self.output_key,
 )
 
-# https://github.com/langchain-ai/langchain/issues/4608
+# Ref: https://github.com/langchain-ai/langchain/issues/4608
 conversation_chain = ConversationalRetrievalChain.from_llm(
 llm=llm,
 retriever=retriever,
@@ -127,7 +141,6 @@
 rephrase_question=False, # disable rephrase, for test purpose
 get_chat_history=lambda x: x,
 # return_generated_question=True, # for debug
-# verbose=True,
 # combine_docs_chain_kwargs={"prompt": PROMPT}, # additional prompt control
 # condense_question_prompt=CONDENSE_QUESTION_PROMPT, # additional prompt control
 )
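Note: a sketch of driving the chain once created; input and output keys follow `ConversationalRetrievalChain`, and `rerank_retriever`/`llm` are assumed to exist as in app.py:

```python
# Sketch: build and query the citation-aware RAG chain.
from ragchain import RAGChain

rag_chain = RAGChain()
qa_chain = rag_chain.create(rerank_retriever, llm, add_citation=True)

result = qa_chain({"question": "Which solid electrolytes are discussed?", "chat_history": []})
print(result["answer"])
# result["source_documents"] is available if the chain was created with return_source_documents=True
```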
requirements.txt
CHANGED
@@ -5,7 +5,7 @@ langchain-community==0.2.4
 text-generation
 pypdf
 pymupdf
-gradio
+gradio==4.44.1
 faiss-cpu
 chromadb
 rank-bm25
rerank.py
CHANGED
@@ -1,5 +1,6 @@
 """
-
+Rerank with cross encoder.
+Ref:
 https://medium.aiplanet.com/advanced-rag-cohere-re-ranker-99acc941601c
 https://github.com/langchain-ai/langchain/issues/13076
 """
@@ -7,7 +8,7 @@ https://github.com/langchain-ai/langchain/issues/13076
 from __future__ import annotations
 from typing import Optional, Sequence
 from langchain.schema import Document
-from langchain.pydantic_v1 import Extra
+from langchain.pydantic_v1 import Extra
 
 from langchain.callbacks.manager import Callbacks
 from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
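Note: the compressor wraps a sentence-transformers cross encoder; the underlying scoring step looks roughly like this (standalone sketch, independent of the LangChain wrapper):

```python
# Sketch: score (query, passage) pairs with the same cross-encoder used by BgeRerank.
from sentence_transformers import CrossEncoder

model = CrossEncoder("BAAI/bge-reranker-large")
pairs = [
    ["what is a solid-state electrolyte?", "Solid-state electrolytes conduct lithium ions between electrodes."],
    ["what is a solid-state electrolyte?", "The weather in Karlsruhe is mild in spring."],
]
scores = model.predict(pairs)  # higher score = more relevant; used to reorder retrieved documents
print(scores)
```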
retrievers.py
CHANGED
@@ -1,7 +1,10 @@
+"""
+Retrievers for text chunks.
+"""
+
 import os
 
 from langchain.text_splitter import (
-CharacterTextSplitter,
 RecursiveCharacterTextSplitter,
 SpacyTextSplitter,
 )
@@ -9,6 +12,7 @@ from langchain.text_splitter import (
 from rerank import BgeRerank
 from langchain.retrievers import ContextualCompressionRetriever
 
+
 def get_parent_doc_retriever(
 documents,
 vectorstore,
@@ -40,12 +44,14 @@ def get_parent_doc_retriever(
 from langchain_rag.storage import SQLStore
 
 # Instantiate the SQLStore with the root path
-docstore = SQLStore(
+docstore = SQLStore(
+namespace="test", db_url="sqlite:///parent_retrieval_db.db"
+) # TODO: WIP
 else:
 docstore = docstore # TODO: add check
-# raise # TODO implement
+# raise # TODO implement other docstores
 
-# TODO: how to better set
+# TODO: how to better set these values?
 # parent_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=1024, chunk_overlap=256)
 # child_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=256, chunk_overlap=64)
 parent_splitter = SpacyTextSplitter.from_tiktoken_encoder(
@@ -62,11 +68,11 @@ def get_parent_doc_retriever(
 docstore=docstore,
 child_splitter=child_splitter,
 parent_splitter=parent_splitter,
-search_kwargs={"k": k},
+search_kwargs={"k": k},
 )
 
 if add_documents:
-retriever.add_documents(documents)
+retriever.add_documents(documents)
 
 if save_vectorstore:
 vectorstore.save_local(os.path.join(save_path_root, "faiss_index"))
@@ -80,7 +86,6 @@ def get_parent_doc_retriever(
 
 save_to_pickle(docstore, os.path.join(save_path_root, "docstore.pkl"))
 
-
 return retriever
 
 
vectorestores.py
CHANGED
@@ -1,8 +1,13 @@
-
+"""
+Vector stores.
+"""
+
+from langchain.vectorstores import FAISS
+
 
 def get_faiss_vectorestore(embeddings):
 # Add extra text to init
 texts = ["LISA - Lithium Ion Solid-state Assistant"]
 vectorstore = FAISS.from_texts(texts, embeddings)
-
-return vectorstore
+
+return vectorstore
|