armanddemasson committed · Commit 8b93891 · 1 Parent(s): 69f7a91

retrieve by ToC

climateqa/engine/chains/prompts.py CHANGED
@@ -197,4 +197,58 @@ Graphs and their HTML embedding:
 {format_instructions}
 
 Output the result as json with a key "graphs" containing a list of dictionaries of the relevant graphs with keys 'embedding', 'category', and 'source'. Do not modify the graph HTML embedding, the category or the source. Do not put any message or text before or after the JSON output.
+"""
+
+retrieve_chapter_prompt_template = """A partir de la question de l'utilisateur et d'une liste de documents avec leurs tables des matières, retrouve les 5 chapitres de niveau 0 les plus pertinents qui pourraient aider à répondre à la question tout en prenant bien en compte leurs sous-chapitres.
+
+
+La table des matières est structurée de cette façon :
+{{
+"level": 0,
+"Chapitre 1": {{}},
+"Chapitre 2" : {{
+"level": 1,
+"Chapitre 2.1": {{
+...
+}}
+}},
+}}
+
+Ici level correspond au niveau du chapitre. Ici par exemple Chapitre 1 et Chapitre 2 sont au niveau 0, et Chapitre 2.1 est au niveau 1
+
+-----------------------
+Suis impérativement cette guideline ci-dessous.
+
+### Guidelines ###
+- Retiens bien la liste complète des documents.
+- Chaque chapitre doit conserver **EXACTEMENT** le niveau qui lui est attribué dans la table des matières. **NE MODIFIE PAS LES NIVEAUX.**
+- Vérifie systématiquement le niveau d’un chapitre avant de l’inclure dans la réponse.
+- Retourne un résultat **JSON valide**.
+
+--------------------
+Question de l'utilisateur :
+{query}
+
+Liste des documents avec leurs tables des matières :
+{doc_list}
+
+--------------------
+
+Retourne le résultat en JSON contenant une liste des chapitres pertinents avec les clés suivantes sans l'indicateur markdown du json:
+- "document" : le document contenant le chapitre
+- "chapter" : le titre du chapitre
+
+**IMPORTANT : Assure-toi que les niveaux dans la réponse sont exactement ceux de la table des matières.**
+
+Exemple de réponse JSON :
+[
+{{
+"document": "Document A",
+"chapter": "Chapitre 1",
+}},
+{{
+"document": "Document B",
+"chapter": "Chapitre 1.1",
+}}
+]
 """
climateqa/engine/chains/retrieve_documents.py CHANGED
@@ -15,6 +15,14 @@ from ..utils import log_event
 from langchain_core.vectorstores import VectorStore
 from typing import List
 from langchain_core.documents.base import Document
+from ..llm import get_llm
+from .prompts import retrieve_chapter_prompt_template
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+from ..vectorstore import get_pinecone_vectorstore
+from ..embeddings import get_embeddings_function
+
+
 import asyncio
 
 from typing import Any, Dict, List, Tuple
@@ -280,6 +288,7 @@ async def retrieve_documents(
     source_type: str,
     vectorstore: VectorStore,
     reranker: Any,
+    version: str = "",
     search_figures: bool = False,
     search_only: bool = False,
     reports: list = [],
@@ -287,7 +296,9 @@
     k_images_by_question: int = 5,
     k_before_reranking: int = 100,
     k_by_question: int = 5,
-    k_summary_by_question: int = 3
+    k_summary_by_question: int = 3,
+    tocs: list = [],
+    by_toc=False
 ) -> Tuple[List[Document], List[Document]]:
     """
     Unpack the first question of the remaining questions, and retrieve and rerank corresponding documents, based on the question and selected_sources
@@ -317,6 +328,7 @@
 
     print(f"""---- Retrieve documents from {current_question["source_type"]}----""")
 
+
     if source_type == "IPx":
         docs_question_dict = await get_IPCC_relevant_documents(
             query = question,
@@ -332,19 +344,39 @@
             reports = reports,
         )
 
-    if source_type == "POC":
-        docs_question_dict = await get_POC_relevant_documents(
-            query = question,
-            vectorstore=vectorstore,
-            search_figures = search_figures,
-            sources = sources,
-            threshold = 0.5,
-            search_only = search_only,
-            reports = reports,
-            min_size= 200,
-            k_documents= k_before_reranking,
-            k_images= k_by_question
-        )
+    # if source_type == "POC":
+    #
+
+    if source_type == 'POC':
+        if by_toc == True:
+            print("---- Retrieve documents by ToC----")
+            docs_question_dict = await get_POC_by_ToC_relevant_documents(
+                query=question,
+                tocs = tocs,
+                vectorstore=vectorstore,
+                version=version,
+                search_figures = search_figures,
+                sources = sources,
+                threshold = 0.5,
+                search_only = search_only,
+                reports = reports,
+                min_size= 200,
+                k_documents= k_before_reranking,
+                k_images= k_by_question
+            )
+        else :
+            docs_question_dict = await get_POC_relevant_documents(
+                query = question,
+                vectorstore=vectorstore,
+                search_figures = search_figures,
+                sources = sources,
+                threshold = 0.5,
+                search_only = search_only,
+                reports = reports,
+                min_size= 200,
+                k_documents= k_before_reranking,
+                k_images= k_by_question
+            )
 
     # Rerank
     if reranker is not None and rerank_by_question:
@@ -370,7 +402,20 @@
     return docs_question, images_question
 
 
-async def retrieve_documents_for_all_questions(state, config, source_type, to_handle_questions_index, vectorstore, reranker, rerank_by_question=True, k_final=15, k_before_reranking=100):
+async def retrieve_documents_for_all_questions(
+    state,
+    config,
+    source_type,
+    to_handle_questions_index,
+    vectorstore: VectorStore,
+    reranker,
+    version: str = "",
+    rerank_by_question: bool=True,
+    k_final: int=15,
+    k_before_reranking:int =100,
+    tocs=[],
+    by_toc=False
+):
     """
     Retrieve documents in parallel for all questions.
     """
@@ -388,6 +433,7 @@ async def retrieve_documents_for_all_questions(state, config, source_type, to_ha
     k_images_by_question = _get_k_images_by_question(state["n_questions"]["total"])
     k_before_reranking=100
 
+    print(f"Source type here is {source_type}")
     tasks = [
         retrieve_documents(
             current_question=question,
@@ -402,7 +448,10 @@ async def retrieve_documents_for_all_questions(state, config, source_type, to_ha
             k_images_by_question=k_images_by_question,
             k_before_reranking=k_before_reranking,
             k_by_question=k_by_question,
-            k_summary_by_question=k_summary_by_question
+            k_summary_by_question=k_summary_by_question,
+            tocs=tocs,
+            version=version,
+            by_toc=by_toc
         )
         for i, question in enumerate(state["questions_list"]) if i in to_handle_questions_index
     ]
@@ -463,4 +512,164 @@ def make_POC_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_
     return retrieve_POC_docs_node
 
 
+# ToC Retriever
+async def get_relevant_toc_level_for_query(
+    query: str,
+    tocs: list[Document],
+) -> list[dict] :
+
+    doc_list = []
+    for doc in tocs:
+        doc_name = doc[0].metadata['name']
+        toc = doc[0].page_content
+        doc_list.append({'document': doc_name, 'toc': toc})
+
+    llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
+
+    prompt = ChatPromptTemplate.from_template(retrieve_chapter_prompt_template)
+    chain = prompt | llm | StrOutputParser()
+    response = chain.invoke({"query": query, "doc_list": doc_list})
+
+    try:
+        relevant_tocs = eval(response)
+    except Exception as e:
+        print(f" Failed to parse the result because of : {e}")
+
+    return relevant_tocs
+
+
+async def get_POC_by_ToC_relevant_documents(
+    query: str,
+    tocs: list[str],
+    vectorstore:VectorStore,
+    version: str = "",
+    sources:list = ["Acclimaterra","PCAET","Plan Biodiversite"],
+    search_figures:bool = False,
+    search_only:bool = False,
+    k_documents:int = 10,
+    threshold:float = 0.6,
+    k_images: int = 5,
+    reports:list = [],
+    min_size:int = 200,
+    proportion: float = 0.5,
+) :
+    # Prepare base search kwargs
+    filters = {}
+    docs_question = []
+    docs_images = []
+
+    # TODO add source selection
+    # if len(reports) > 0:
+    #     filters["short_name"] = {"$in":reports}
+    # else:
+    #     filters["source"] = { "$in": sources}
+
+    k_documents_toc = round(k_documents * proportion)
+
+    relevant_tocs = await get_relevant_toc_level_for_query(query, tocs)
+
+    print(f"Relevant ToCs : {relevant_tocs}")
+    # Transform the ToC dict {"document": str, "chapter": str} into a list of string
+    toc_filters = [toc['chapter'] for toc in relevant_tocs]
+
+    filters_text_toc = {
+        **filters,
+        "chunk_type":"text",
+        "toc_level0": {"$in": toc_filters},
+        "version": version
+        # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
+    }
+
+    docs_question = vectorstore.similarity_search_with_score(query=query,filter = filters_text_toc,k = k_documents_toc)
+
+    filters_text = {
+        **filters,
+        "chunk_type":"text",
+        "version": version
+        # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
+    }
+
+    docs_question += vectorstore.similarity_search_with_score(query=query,filter = filters_text,k = k_documents - k_documents_toc)
+
+    # remove duplicates or almost duplicates
+    docs_question = remove_duplicates_chunks(docs_question)
+    docs_question = [x for x in docs_question if x[1] > threshold]
+
+    if search_figures:
+        # Images
+        filters_image = {
+            **filters,
+            "chunk_type":"image"
+        }
+        docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
+
+    docs_question, docs_images = _add_metadata_and_score(docs_question), _add_metadata_and_score(docs_images)
+
+    docs_question = [x for x in docs_question if len(x.page_content) > min_size]
+
+    return {
+        "docs_question" : docs_question,
+        "docs_images" : docs_images
+    }
+
+
+def get_ToCs(version: str) :
+
+    filters_text = {
+        "chunk_type":"toc",
+        "version": version
+    }
+    embeddings_function = get_embeddings_function()
+    vectorstore = get_pinecone_vectorstore(embeddings_function, index_name="climateqa-v2")
+    tocs = vectorstore.similarity_search_with_score(query="",filter = filters_text)
+
+    # remove duplicates or almost duplicates
+    tocs = remove_duplicates_chunks(tocs)
+
+    return tocs
+
+
+
+def make_POC_by_ToC_retriever_node(
+    vectorstore: VectorStore,
+    reranker,
+    llm,
+    version: str = "",
+    rerank_by_question=True,
+    k_final=15,
+    k_before_reranking=100,
+    k_summary=5,
+):
+
+    async def retrieve_POC_docs_node(state, config):
+        if "POC region" not in state["relevant_content_sources_selection"] :
+            return {}
+
+
+        tocs = get_ToCs(version=version)
+
+        source_type = "POC"
+        POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
+
+        state = await retrieve_documents_for_all_questions(
+            state=state,
+            config=config,
+            source_type=source_type,
+            to_handle_questions_index=POC_questions_index,
+            vectorstore=vectorstore,
+            reranker=reranker,
+            rerank_by_question=rerank_by_question,
+            k_final=k_final,
+            k_before_reranking=k_before_reranking,
+            tocs=tocs,
+            version=version,
+            by_toc=True
+        )
+        return state
+
+    return retrieve_POC_docs_node
+
+
+
 
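
A usage sketch for the new retriever factory. Everything below reuses calls that already appear in this diff (get_embeddings_function, get_pinecone_vectorstore, get_llm, make_POC_by_ToC_retriever_node); the version tag "v1" and the reranker-less setup are assumptions for illustration only, not how the node is necessarily wired into the ClimateQA graph:

from climateqa.engine.embeddings import get_embeddings_function
from climateqa.engine.vectorstore import get_pinecone_vectorstore
from climateqa.engine.llm import get_llm
from climateqa.engine.chains.retrieve_documents import make_POC_by_ToC_retriever_node

# Same vectorstore/LLM setup as used inside the new helpers; "v1" is a placeholder version tag.
vectorstore = get_pinecone_vectorstore(get_embeddings_function(), index_name="climateqa-v2")
llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)

# reranker=None only for illustration: retrieve_documents skips the rerank step
# when reranker is None.
poc_by_toc_retriever_node = make_POC_by_ToC_retriever_node(
    vectorstore=vectorstore,
    reranker=None,
    llm=llm,
    version="v1",
)

With the defaults above, each POC question first asks the LLM which level-0 chapters look relevant (get_relevant_toc_level_for_query), then fetches round(k_before_reranking * proportion) = 50 chunks restricted to those toc_level0 values and the remaining 50 from an unfiltered search, before deduplication, the score threshold, and the optional reranking step.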