armanddemasson committed
Commit 8b93891 · 1 parent: 69f7a91
retrieve by ToC
climateqa/engine/chains/prompts.py
CHANGED
@@ -197,4 +197,58 @@ Graphs and their HTML embedding:
 {format_instructions}
 
 Output the result as json with a key "graphs" containing a list of dictionaries of the relevant graphs with keys 'embedding', 'category', and 'source'. Do not modify the graph HTML embedding, the category or the source. Do not put any message or text before or after the JSON output.
+"""
+
+retrieve_chapter_prompt_template = """Given the user question and a list of documents with their tables of contents, find the 5 most relevant level-0 chapters that could help answer the question, while taking their sub-chapters properly into account.
+
+
+A table of contents is structured like this:
+{{
+    "level": 0,
+    "Chapter 1": {{}},
+    "Chapter 2" : {{
+        "level": 1,
+        "Chapter 2.1": {{
+            ...
+        }}
+    }},
+}}
+
+Here "level" is the level of the chapter. In this example, Chapter 1 and Chapter 2 are at level 0, and Chapter 2.1 is at level 1.
+
+-----------------------
+Follow the guidelines below strictly.
+
+### Guidelines ###
+- Keep the complete list of documents in mind.
+- Each chapter must keep **EXACTLY** the level assigned to it in the table of contents. **DO NOT CHANGE THE LEVELS.**
+- Always check a chapter's level before including it in the answer.
+- Return a **valid JSON** result.
+
+--------------------
+User question:
+{query}
+
+List of documents with their tables of contents:
+{doc_list}
+
+--------------------
+
+Return the result as JSON containing a list of the relevant chapters with the following keys, without the markdown json fence:
+- "document": the document containing the chapter
+- "chapter": the chapter title
+
+**IMPORTANT: Make sure the levels in the answer are exactly those of the table of contents.**
+
+Example JSON response:
+[
+    {{
+        "document": "Document A",
+        "chapter": "Chapter 1"
+    }},
+    {{
+        "document": "Document B",
+        "chapter": "Chapter 1.1"
+    }}
+]
 """
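The template asks twice for bare JSON without a markdown fence, because the caller parses the reply directly (see `get_relevant_toc_level_for_query` in the second file). If a model ignores that instruction, a tolerant parser is cheap insurance; a minimal sketch, with a hypothetical helper name that is not part of this commit:

```python
import json

def parse_chapter_response(response: str) -> list[dict]:
    # Hypothetical helper: tolerate a stray ```json fence before parsing.
    cleaned = response.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned.strip("`")
        cleaned = cleaned.removeprefix("json").strip()
    return json.loads(cleaned)
```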
climateqa/engine/chains/retrieve_documents.py
CHANGED
@@ -15,6 +15,14 @@ from ..utils import log_event
 from langchain_core.vectorstores import VectorStore
 from typing import List
 from langchain_core.documents.base import Document
+from ..llm import get_llm
+from .prompts import retrieve_chapter_prompt_template
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+from ..vectorstore import get_pinecone_vectorstore
+from ..embeddings import get_embeddings_function
+
+
 import asyncio
 
 from typing import Any, Dict, List, Tuple
@@ -280,6 +288,7 @@ async def retrieve_documents(
     source_type: str,
     vectorstore: VectorStore,
     reranker: Any,
+    version: str = "",
     search_figures: bool = False,
     search_only: bool = False,
     reports: list = [],
@@ -287,7 +296,9 @@ async def retrieve_documents(
     k_images_by_question: int = 5,
     k_before_reranking: int = 100,
     k_by_question: int = 5,
-    k_summary_by_question: int = 3
+    k_summary_by_question: int = 3,
+    tocs: list = [],
+    by_toc=False
 ) -> Tuple[List[Document], List[Document]]:
     """
     Unpack the first question of the remaining questions, and retrieve and rerank corresponding documents, based on the question and selected_sources
@@ -317,6 +328,7 @@ async def retrieve_documents(
 
     print(f"""---- Retrieve documents from {current_question["source_type"]}----""")
 
+
     if source_type == "IPx":
         docs_question_dict = await get_IPCC_relevant_documents(
             query = question,
@@ -332,19 +344,39 @@ async def retrieve_documents(
             reports = reports,
         )
 
-    if source_type == "POC":
-        docs_question_dict = await get_POC_relevant_documents(
-            query = question,
-            vectorstore=vectorstore,
-            search_figures = search_figures,
-            sources = sources,
-            threshold = 0.5,
-            search_only = search_only,
-            reports = reports,
-            min_size= 200,
-            k_documents= k_before_reranking,
-            k_images= k_by_question
-        )
+    # if source_type == "POC":
+    #
+
+    if source_type == 'POC':
+        if by_toc == True:
+            print("---- Retrieve documents by ToC----")
+            docs_question_dict = await get_POC_by_ToC_relevant_documents(
+                query=question,
+                tocs = tocs,
+                vectorstore=vectorstore,
+                version=version,
+                search_figures = search_figures,
+                sources = sources,
+                threshold = 0.5,
+                search_only = search_only,
+                reports = reports,
+                min_size= 200,
+                k_documents= k_before_reranking,
+                k_images= k_by_question
+            )
+        else :
+            docs_question_dict = await get_POC_relevant_documents(
+                query = question,
+                vectorstore=vectorstore,
+                search_figures = search_figures,
+                sources = sources,
+                threshold = 0.5,
+                search_only = search_only,
+                reports = reports,
+                min_size= 200,
+                k_documents= k_before_reranking,
+                k_images= k_by_question
+            )
 
     # Rerank
     if reranker is not None and rerank_by_question:
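Review note on the hunk above: `if by_toc == True:` is more idiomatically `if by_toc:`, and the two calls differ only in `tocs` and `version`. A sketch of the same dispatch with the shared keywords factored out; `retriever_by_toc` and `retriever_plain` are placeholders, not names from the commit:

```python
async def retrieve_poc(question: str, by_toc: bool, tocs: list, version: str,
                       retriever_by_toc, retriever_plain, **common):
    # retriever_by_toc / retriever_plain stand in for
    # get_POC_by_ToC_relevant_documents / get_POC_relevant_documents;
    # the keyword arguments shared by both branches travel in **common.
    if by_toc:  # idiomatic truth test instead of "== True"
        return await retriever_by_toc(query=question, tocs=tocs,
                                      version=version, **common)
    return await retriever_plain(query=question, **common)
```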
@@ -370,7 +402,20 @@ async def retrieve_documents(
     return docs_question, images_question
 
 
-async def retrieve_documents_for_all_questions(
+async def retrieve_documents_for_all_questions(
+    state,
+    config,
+    source_type,
+    to_handle_questions_index,
+    vectorstore: VectorStore,
+    reranker,
+    version: str = "",
+    rerank_by_question: bool=True,
+    k_final: int=15,
+    k_before_reranking:int =100,
+    tocs=[],
+    by_toc=False
+):
     """
     Retrieve documents in parallel for all questions.
     """
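The docstring promises parallel retrieval, and the next hunk shows the `tasks` list being built; the await itself falls outside the diff context. For reference, a self-contained sketch of the fan-out pattern this relies on, with illustrative helper names:

```python
import asyncio

async def fetch_one(question: str) -> list[str]:
    # Stand-in for retrieve_documents(); sleeps to simulate I/O.
    await asyncio.sleep(0.1)
    return [f"doc for {question!r}"]

async def fetch_all(questions: list[str]) -> list[list[str]]:
    tasks = [fetch_one(q) for q in questions]
    # gather() awaits all coroutines concurrently, preserving input order.
    return await asyncio.gather(*tasks)

print(asyncio.run(fetch_all(["q1", "q2", "q3"])))
```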
@@ -388,6 +433,7 @@ async def retrieve_documents_for_all_questions(state, config, source_type, to_handle_questions_index
     k_images_by_question = _get_k_images_by_question(state["n_questions"]["total"])
     k_before_reranking=100
 
+    print(f"Source type here is {source_type}")
     tasks = [
         retrieve_documents(
             current_question=question,
@@ -402,7 +448,10 @@ async def retrieve_documents_for_all_questions(state, config, source_type, to_handle_questions_index
             k_images_by_question=k_images_by_question,
             k_before_reranking=k_before_reranking,
             k_by_question=k_by_question,
-            k_summary_by_question=k_summary_by_question
+            k_summary_by_question=k_summary_by_question,
+            tocs=tocs,
+            version=version,
+            by_toc=by_toc
         )
         for i, question in enumerate(state["questions_list"]) if i in to_handle_questions_index
     ]
@@ -463,4 +512,164 @@ def make_POC_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_
     return retrieve_POC_docs_node
 
 
+# ToC Retriever
+async def get_relevant_toc_level_for_query(
+    query: str,
+    tocs: list[Document],
+) -> list[dict] :
+
+    doc_list = []
+    for doc in tocs:
+        doc_name = doc[0].metadata['name']
+        toc = doc[0].page_content
+        doc_list.append({'document': doc_name, 'toc': toc})
+
+    llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
+
+    prompt = ChatPromptTemplate.from_template(retrieve_chapter_prompt_template)
+    chain = prompt | llm | StrOutputParser()
+    response = chain.invoke({"query": query, "doc_list": doc_list})
+
+    try:
+        relevant_tocs = eval(response)
+    except Exception as e:
+        print(f" Failed to parse the result because of : {e}")
+
+    return relevant_tocs
+
+
+async def get_POC_by_ToC_relevant_documents(
+    query: str,
+    tocs: list[str],
+    vectorstore:VectorStore,
+    version: str = "",
+    sources:list = ["Acclimaterra","PCAET","Plan Biodiversite"],
+    search_figures:bool = False,
+    search_only:bool = False,
+    k_documents:int = 10,
+    threshold:float = 0.6,
+    k_images: int = 5,
+    reports:list = [],
+    min_size:int = 200,
+    proportion: float = 0.5,
+) :
+    # Prepare base search kwargs
+    filters = {}
+    docs_question = []
+    docs_images = []
+
+    # TODO add source selection
+    # if len(reports) > 0:
+    #     filters["short_name"] = {"$in":reports}
+    # else:
+    #     filters["source"] = { "$in": sources}
+
+    k_documents_toc = round(k_documents * proportion)
+
+    relevant_tocs = await get_relevant_toc_level_for_query(query, tocs)
+
+    print(f"Relevant ToCs : {relevant_tocs}")
+    # Transform the ToC dict {"document": str, "chapter": str} into a list of string
+    toc_filters = [toc['chapter'] for toc in relevant_tocs]
+
+    filters_text_toc = {
+        **filters,
+        "chunk_type":"text",
+        "toc_level0": {"$in": toc_filters},
+        "version": version
+        # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
+    }
+
+    docs_question = vectorstore.similarity_search_with_score(query=query,filter = filters_text_toc,k = k_documents_toc)
+
+    filters_text = {
+        **filters,
+        "chunk_type":"text",
+        "version": version
+        # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
+    }
+
+    docs_question += vectorstore.similarity_search_with_score(query=query,filter = filters_text,k = k_documents - k_documents_toc)
+
+    # remove duplicates or almost duplicates
+    docs_question = remove_duplicates_chunks(docs_question)
+    docs_question = [x for x in docs_question if x[1] > threshold]
+
+    if search_figures:
+        # Images
+        filters_image = {
+            **filters,
+            "chunk_type":"image"
+        }
+        docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
+
+    docs_question, docs_images = _add_metadata_and_score(docs_question), _add_metadata_and_score(docs_images)
+
+    docs_question = [x for x in docs_question if len(x.page_content) > min_size]
+
+    return {
+        "docs_question" : docs_question,
+        "docs_images" : docs_images
+    }
+
+
+def get_ToCs(version: str) :
+
+    filters_text = {
+        "chunk_type":"toc",
+        "version": version
+    }
+    embeddings_function = get_embeddings_function()
+    vectorstore = get_pinecone_vectorstore(embeddings_function, index_name="climateqa-v2")
+    tocs = vectorstore.similarity_search_with_score(query="",filter = filters_text)
+
+    # remove duplicates or almost duplicates
+    tocs = remove_duplicates_chunks(tocs)
+
+    return tocs
+
+
+def make_POC_by_ToC_retriever_node(
+    vectorstore: VectorStore,
+    reranker,
+    llm,
+    version: str = "",
+    rerank_by_question=True,
+    k_final=15,
+    k_before_reranking=100,
+    k_summary=5,
+):
+
+    async def retrieve_POC_docs_node(state, config):
+        if "POC region" not in state["relevant_content_sources_selection"] :
+            return {}
+
+        tocs = get_ToCs(version=version)
+
+        source_type = "POC"
+        POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
+
+        state = await retrieve_documents_for_all_questions(
+            state=state,
+            config=config,
+            source_type=source_type,
+            to_handle_questions_index=POC_questions_index,
+            vectorstore=vectorstore,
+            reranker=reranker,
+            rerank_by_question=rerank_by_question,
+            k_final=k_final,
+            k_before_reranking=k_before_reranking,
+            tocs=tocs,
+            version=version,
+            by_toc=True
+        )
+        return state
+
+    return retrieve_POC_docs_node
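Two fragile spots in `get_relevant_toc_level_for_query` above: `eval(response)` executes arbitrary model output, and when parsing fails the function still reaches `return relevant_tocs` with the name unbound, raising `UnboundLocalError`. A safer sketch, assuming the model honors the JSON-only instruction in the prompt:

```python
import json

def parse_relevant_tocs(response: str) -> list[dict]:
    # json.loads never executes code (unlike eval), and the fallback
    # guarantees the function always returns a list, so the result
    # can never be unbound in the caller.
    try:
        parsed = json.loads(response)
    except json.JSONDecodeError as e:
        print(f"Failed to parse the result because of: {e}")
        return []
    return parsed if isinstance(parsed, list) else []
```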
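End to end, the new pieces compose as follows; a usage sketch under the assumption that `make_POC_by_ToC_retriever_node` is wired into the graph the same way the existing `make_POC_retriever_node` is:

```python
# Hypothetical wiring, mirroring how the sibling make_POC_retriever_node
# is presumably used; vectorstore, reranker and llm come from the app's
# setup code, which is not part of this commit.
retrieve_POC_docs_by_toc = make_POC_by_ToC_retriever_node(
    vectorstore=vectorstore,
    reranker=reranker,
    llm=llm,
    version="v1",            # illustrative tag matching the "version" filter on chunks
    rerank_by_question=True,
    k_final=15,
    k_before_reranking=100,
)
# The returned async node can then be registered in the LangGraph workflow,
# e.g. workflow.add_node("retrieve_POC_docs", retrieve_POC_docs_by_toc)
```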