timeki commited on
Commit
2619e14
·
1 Parent(s): 4e4a906

Merge remote-tracking branch 'origin/feature/improve_parsing_and_retrieval' into pr/20

Browse files
.gitignore CHANGED
@@ -12,7 +12,9 @@ notebooks/
12
  data/
13
  sandbox/
14
 
 
15
  *.db
16
- .vscode/
 
 
17
  *old/
18
- data_ingestion/
 
12
  data/
13
  sandbox/
14
 
15
+ climateqa/talk_to_data/database/
16
  *.db
17
+
18
+ data_ingestion/
19
+ .vscode
20
  *old/
 
app.py CHANGED
@@ -64,7 +64,7 @@ user_id = create_user_id()
64
  embeddings_function = get_embeddings_function()
65
  vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX"))
66
  vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
67
- vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_REGION"))
68
 
69
  llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
70
  if os.getenv("ENV")=="GRADIO_ENV":
@@ -73,7 +73,7 @@ else:
73
  reranker = get_reranker("large")
74
 
75
  agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
76
- agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0)#TODO put back default 0.2
77
 
78
 
79
  async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
@@ -268,7 +268,6 @@ def event_handling(
268
  for component in [textbox, examples_hidden]:
269
  component.submit(find_papers, [component, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
270
 
271
-
272
 
273
 
274
  def main_ui():
 
64
  embeddings_function = get_embeddings_function()
65
  vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX"))
66
  vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
67
+ vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2"))
68
 
69
  llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
70
  if os.getenv("ENV")=="GRADIO_ENV":
 
73
  reranker = get_reranker("large")
74
 
75
  agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
76
+ agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0, version="v4")#TODO put back default 0.2
77
 
78
 
79
  async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
 
268
  for component in [textbox, examples_hidden]:
269
  component.submit(find_papers, [component, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
270
 
 
271
 
272
 
273
  def main_ui():
climateqa/engine/chains/prompts.py CHANGED
@@ -66,10 +66,11 @@ You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a quest
66
  Guidelines:
67
  - If the passages have useful facts or numbers, use them in your answer.
68
  - When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
69
- - You will receive passages from different reports, eg IPCC and PPCP, make separate paragraphs and specify the source of the information in your answer, eg "According to IPCC, ...".
70
- - The different sources are IPCC, IPBES, PPCP (for Plan Climat Air Energie Territorial de Paris), PBDP (for Plan Biodiversité de Paris), Acclimaterra.
 
71
  - Do not mention that you are using specific extract documents, but mention only the source information. "According to IPCC, ..." rather than "According to the provided document from IPCC ..."
72
- - Make a clear distinction between information from IPCC, IPBES, Acclimaterra that are scientific reports and PPCP, PBDP that are strategic reports. Strategic reports should not be taken has verified facts, but as political or strategic decisions.
73
  - If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
74
  - Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
75
  - If it makes sense, use bullet points and lists to make your answers easier to understand.
@@ -197,4 +198,54 @@ Graphs and their HTML embedding:
197
  {format_instructions}
198
 
199
  Output the result as json with a key "graphs" containing a list of dictionaries of the relevant graphs with keys 'embedding', 'category', and 'source'. Do not modify the graph HTML embedding, the category or the source. Do not put any message or text before or after the JSON output.
200
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  Guidelines:
67
  - If the passages have useful facts or numbers, use them in your answer.
68
  - When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
69
+ - You will receive passages from different reports, e.g., IPCC and PPCP. Make separate paragraphs and specify the source of the information in your answer, e.g., "According to IPCC, ...".
70
+ - The different sources are IPCC, IPBES, PPCP (for Plan Climat Air Energie Territorial de Paris), PBDP (for Plan Biodiversité de Paris), Acclimaterra (Rapport scientifique de la région Nouvelle Aquitaine en France).
71
+ - If the reports are local (like PPCP, PBDP, Acclimaterra), consider that the information is specific to the region and not global. If the document is about a nearby region (for example, an extract from Acclimaterra for a question about Britain), explicitly state the concerned region.
72
  - Do not mention that you are using specific extract documents, but mention only the source information. "According to IPCC, ..." rather than "According to the provided document from IPCC ..."
73
+ - Make a clear distinction between information from IPCC, IPBES, Acclimaterra that are scientific reports and PPCP, PBDP that are strategic reports. Strategic reports should not be taken as verified facts, but as political or strategic decisions.
74
  - If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
75
  - Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
76
  - If it makes sense, use bullet points and lists to make your answers easier to understand.
 
198
  {format_instructions}
199
 
200
  Output the result as json with a key "graphs" containing a list of dictionaries of the relevant graphs with keys 'embedding', 'category', and 'source'. Do not modify the graph HTML embedding, the category or the source. Do not put any message or text before or after the JSON output.
201
+ """
202
+
203
+ retrieve_chapter_prompt_template = """Given the user question and a list of documents with their table of contents, retrieve the 5 most relevant level 0 chapters which could help to answer to the question while taking account their sub-chapters.
204
+
205
+ The table of contents is structured like that :
206
+ {{
207
+ "level": 0,
208
+ "Chapter 1": {{}},
209
+ "Chapter 2" : {{
210
+ "level": 1,
211
+ "Chapter 2.1": {{
212
+ ...
213
+ }}
214
+ }},
215
+ }}
216
+
217
+ Here level is the level of the chapter. For example, Chapter 1 and Chapter 2 are at level 0, and Chapter 2.1 is at level 1.
218
+
219
+ ### Guidelines ###
220
+ - Keep all the list of documents that is given to you
221
+ - Each chapter must keep **EXACTLY** its assigned level in the table of contents. **DO NOT MODIFY THE LEVELS. **
222
+ - Check systematically the level of a chapter before including it in the answer.
223
+ - Return **valid JSON** result.
224
+
225
+ --------------------
226
+ User question :
227
+ {query}
228
+
229
+ List of documents with their table of contents :
230
+ {doc_list}
231
+
232
+ --------------------
233
+
234
+ Return a JSON result with a list of relevant chapters with the following keys **WITHOUT** the json markdown indicator ```json at the beginning:
235
+ - "document" : the document in which we can find the chapter
236
+ - "chapter" : the title of the chapter
237
+
238
+ **IMPORTANT : Make sure that the levels of the answer are exactly the same as the ones in the table of contents**
239
+
240
+ Example of a JSON response:
241
+ [
242
+ {{
243
+ "document": "Document A",
244
+ "chapter": "Chapter 1",
245
+ }},
246
+ {{
247
+ "document": "Document B",
248
+ "chapter": "Chapter 5",
249
+ }}
250
+ ]
251
+ """
climateqa/engine/chains/query_transformation.py CHANGED
@@ -293,6 +293,8 @@ def make_query_transform_node(llm,k_final=15):
293
  "n_questions":n_questions,
294
  "handled_questions_index":[],
295
  }
 
 
296
  return new_state
297
 
298
  return transform_query
 
293
  "n_questions":n_questions,
294
  "handled_questions_index":[],
295
  }
296
+ print("New questions")
297
+ print(new_questions)
298
  return new_state
299
 
300
  return transform_query
climateqa/engine/chains/retrieve_documents.py CHANGED
@@ -15,6 +15,14 @@ from ..utils import log_event
15
  from langchain_core.vectorstores import VectorStore
16
  from typing import List
17
  from langchain_core.documents.base import Document
 
 
 
 
 
 
 
 
18
  import asyncio
19
 
20
  from typing import Any, Dict, List, Tuple
@@ -119,6 +127,21 @@ def remove_duplicates_chunks(docs):
119
  result.append(doc)
120
  return result
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  async def get_POC_relevant_documents(
123
  query: str,
124
  vectorstore:VectorStore,
@@ -169,6 +192,86 @@ async def get_POC_relevant_documents(
169
  "docs_question" : docs_question,
170
  "docs_images" : docs_images
171
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
 
174
  async def get_IPCC_relevant_documents(
@@ -271,6 +374,7 @@ def concatenate_documents(index, source_type, docs_question_dict, k_by_question,
271
  return docs_question, images_question
272
 
273
 
 
274
  # The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
275
  # @chain
276
  async def retrieve_documents(
@@ -279,6 +383,7 @@ async def retrieve_documents(
279
  source_type: str,
280
  vectorstore: VectorStore,
281
  reranker: Any,
 
282
  search_figures: bool = False,
283
  search_only: bool = False,
284
  reports: list = [],
@@ -286,7 +391,9 @@ async def retrieve_documents(
286
  k_images_by_question: int = 5,
287
  k_before_reranking: int = 100,
288
  k_by_question: int = 5,
289
- k_summary_by_question: int = 3
 
 
290
  ) -> Tuple[List[Document], List[Document]]:
291
  """
292
  Unpack the first question of the remaining questions, and retrieve and rerank corresponding documents, based on the question and selected_sources
@@ -316,6 +423,7 @@ async def retrieve_documents(
316
 
317
  print(f"""---- Retrieve documents from {current_question["source_type"]}----""")
318
 
 
319
  if source_type == "IPx":
320
  docs_question_dict = await get_IPCC_relevant_documents(
321
  query = question,
@@ -331,19 +439,36 @@ async def retrieve_documents(
331
  reports = reports,
332
  )
333
 
334
- if source_type == "POC":
335
- docs_question_dict = await get_POC_relevant_documents(
336
- query = question,
337
- vectorstore=vectorstore,
338
- search_figures = search_figures,
339
- sources = sources,
340
- threshold = 0.5,
341
- search_only = search_only,
342
- reports = reports,
343
- min_size= 200,
344
- k_documents= k_before_reranking,
345
- k_images= k_by_question
346
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
 
348
  # Rerank
349
  if reranker is not None and rerank_by_question:
@@ -369,24 +494,44 @@ async def retrieve_documents(
369
  return docs_question, images_question
370
 
371
 
372
- async def retrieve_documents_for_all_questions(state, config, source_type, to_handle_questions_index, vectorstore, reranker, rerank_by_question=True, k_final=15, k_before_reranking=100):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  """
374
  Retrieve documents in parallel for all questions.
375
  """
376
  # to_handle_questions_index = [x for x in state["questions_list"] if x["source_type"] == "IPx"]
377
 
378
  # TODO split les questions selon le type de sources dans le state question + conditions sur le nombre de questions traités par type de source
379
- docs = state.get("documents", [])
380
- related_content = state.get("related_content", [])
381
- search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
382
- search_only = state["search_only"]
383
- reports = state["reports"]
384
-
385
- k_by_question = k_final // state["n_questions"]["total"]
386
- k_summary_by_question = _get_k_summary_by_question(state["n_questions"]["total"])
387
- k_images_by_question = _get_k_images_by_question(state["n_questions"]["total"])
 
 
388
  k_before_reranking=100
389
 
 
390
  tasks = [
391
  retrieve_documents(
392
  current_question=question,
@@ -401,9 +546,12 @@ async def retrieve_documents_for_all_questions(state, config, source_type, to_ha
401
  k_images_by_question=k_images_by_question,
402
  k_before_reranking=k_before_reranking,
403
  k_by_question=k_by_question,
404
- k_summary_by_question=k_summary_by_question
 
 
 
405
  )
406
- for i, question in enumerate(state["questions_list"]) if i in to_handle_questions_index
407
  ]
408
  results = await asyncio.gather(*tasks)
409
  # Combine results
@@ -413,16 +561,50 @@ async def retrieve_documents_for_all_questions(state, config, source_type, to_ha
413
  new_state["related_contents"].extend(images_question)
414
  return new_state
415
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
416
  def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
417
 
418
  async def retrieve_IPx_docs(state, config):
419
  source_type = "IPx"
420
  IPx_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
421
 
422
- # return {"documents":[], "related_contents": [], "handled_questions_index": list(range(len(state["questions_list"])))} # TODO Remove
423
-
 
 
 
 
424
  state = await retrieve_documents_for_all_questions(
425
- state=state,
 
 
 
 
426
  config=config,
427
  source_type=source_type,
428
  to_handle_questions_index=IPx_questions_index,
@@ -446,8 +628,18 @@ def make_POC_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_
446
  source_type = "POC"
447
  POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
448
 
 
 
 
 
 
 
449
  state = await retrieve_documents_for_all_questions(
450
- state=state,
 
 
 
 
451
  config=config,
452
  source_type=source_type,
453
  to_handle_questions_index=POC_questions_index,
@@ -462,4 +654,56 @@ def make_POC_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_
462
  return retrieve_POC_docs_node
463
 
464
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465
 
 
15
  from langchain_core.vectorstores import VectorStore
16
  from typing import List
17
  from langchain_core.documents.base import Document
18
+ from ..llm import get_llm
19
+ from .prompts import retrieve_chapter_prompt_template
20
+ from langchain_core.prompts import ChatPromptTemplate
21
+ from langchain_core.output_parsers import StrOutputParser
22
+ from ..vectorstore import get_pinecone_vectorstore
23
+ from ..embeddings import get_embeddings_function
24
+
25
+
26
  import asyncio
27
 
28
  from typing import Any, Dict, List, Tuple
 
127
  result.append(doc)
128
  return result
129
 
130
def get_ToCs(version: str, index_name: str = "climateqa-v2") -> list:
    """Fetch the table-of-contents chunks for a given parsing version.

    Args:
        version: version tag of the parsed documents (e.g. "v4").
        index_name: Pinecone index to query. Defaults to "climateqa-v2",
            which was previously hard-coded; parameterizing keeps callers
            working while allowing other indexes.

    Returns:
        Deduplicated list of ToC chunks as returned by
        ``similarity_search_with_score`` — i.e. (Document, score) tuples.
        NOTE(review): ``remove_duplicates_chunks`` receives these tuples,
        not bare Documents — confirm it handles tuple elements.
    """
    # Filter selects only the ToC chunks of the requested parsing version.
    filters_text = {
        "chunk_type": "toc",
        "version": version,
    }
    embeddings_function = get_embeddings_function()
    vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=index_name)
    # Empty query: retrieval is driven purely by the metadata filter.
    tocs = vectorstore.similarity_search_with_score(query="", filter=filters_text)

    # remove duplicates or almost duplicates
    tocs = remove_duplicates_chunks(tocs)

    return tocs
144
+
145
  async def get_POC_relevant_documents(
146
  query: str,
147
  vectorstore:VectorStore,
 
192
  "docs_question" : docs_question,
193
  "docs_images" : docs_images
194
  }
195
+
196
+ async def get_POC_documents_by_ToC_relevant_documents(
197
+ query: str,
198
+ tocs: list,
199
+ vectorstore:VectorStore,
200
+ version: str,
201
+ sources:list = ["Acclimaterra","PCAET","Plan Biodiversite"],
202
+ search_figures:bool = False,
203
+ search_only:bool = False,
204
+ k_documents:int = 10,
205
+ threshold:float = 0.6,
206
+ k_images: int = 5,
207
+ reports:list = [],
208
+ min_size:int = 200,
209
+ proportion: float = 0.5,
210
+ ) :
211
+ """
212
+ Args:
213
+ - tocs : list with the table of contents of each document
214
+ - version : version of the parsed documents (e.g. "v4")
215
+ - proportion : share of documents retrieved using ToCs
216
+ """
217
+ # Prepare base search kwargs
218
+ filters = {}
219
+ docs_question = []
220
+ docs_images = []
221
+
222
+ # TODO add source selection
223
+ # if len(reports) > 0:
224
+ # filters["short_name"] = {"$in":reports}
225
+ # else:
226
+ # filters["source"] = { "$in": sources}
227
+
228
+ k_documents_toc = round(k_documents * proportion)
229
+
230
+ relevant_tocs = await get_relevant_toc_level_for_query(query, tocs)
231
+
232
+ print(f"Relevant ToCs : {relevant_tocs}")
233
+ # Transform the ToC dict {"document": str, "chapter": str} into a list of string
234
+ toc_filters = [toc['chapter'] for toc in relevant_tocs]
235
+
236
+ filters_text_toc = {
237
+ **filters,
238
+ "chunk_type":"text",
239
+ "toc_level0": {"$in": toc_filters},
240
+ "version": version
241
+ # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
242
+ }
243
+
244
+ docs_question = vectorstore.similarity_search_with_score(query=query,filter = filters_text_toc,k = k_documents_toc)
245
+
246
+ filters_text = {
247
+ **filters,
248
+ "chunk_type":"text",
249
+ "version": version
250
+ # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
251
+ }
252
+
253
+ docs_question += vectorstore.similarity_search_with_score(query=query,filter = filters_text,k = k_documents - k_documents_toc)
254
+
255
+ # remove duplicates or almost duplicates
256
+ docs_question = remove_duplicates_chunks(docs_question)
257
+ docs_question = [x for x in docs_question if x[1] > threshold]
258
+
259
+ if search_figures:
260
+ # Images
261
+ filters_image = {
262
+ **filters,
263
+ "chunk_type":"image"
264
+ }
265
+ docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
266
+
267
+ docs_question, docs_images = _add_metadata_and_score(docs_question), _add_metadata_and_score(docs_images)
268
+
269
+ docs_question = [x for x in docs_question if len(x.page_content) > min_size]
270
+
271
+ return {
272
+ "docs_question" : docs_question,
273
+ "docs_images" : docs_images
274
+ }
275
 
276
 
277
  async def get_IPCC_relevant_documents(
 
374
  return docs_question, images_question
375
 
376
 
377
+
378
  # The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
379
  # @chain
380
  async def retrieve_documents(
 
383
  source_type: str,
384
  vectorstore: VectorStore,
385
  reranker: Any,
386
+ version: str = "",
387
  search_figures: bool = False,
388
  search_only: bool = False,
389
  reports: list = [],
 
391
  k_images_by_question: int = 5,
392
  k_before_reranking: int = 100,
393
  k_by_question: int = 5,
394
+ k_summary_by_question: int = 3,
395
+ tocs: list = [],
396
+ by_toc=False
397
  ) -> Tuple[List[Document], List[Document]]:
398
  """
399
  Unpack the first question of the remaining questions, and retrieve and rerank corresponding documents, based on the question and selected_sources
 
423
 
424
  print(f"""---- Retrieve documents from {current_question["source_type"]}----""")
425
 
426
+
427
  if source_type == "IPx":
428
  docs_question_dict = await get_IPCC_relevant_documents(
429
  query = question,
 
439
  reports = reports,
440
  )
441
 
442
+ if source_type == 'POC':
443
+ if by_toc == True:
444
+ print("---- Retrieve documents by ToC----")
445
+ docs_question_dict = await get_POC_documents_by_ToC_relevant_documents(
446
+ query=question,
447
+ tocs = tocs,
448
+ vectorstore=vectorstore,
449
+ version=version,
450
+ search_figures = search_figures,
451
+ sources = sources,
452
+ threshold = 0.5,
453
+ search_only = search_only,
454
+ reports = reports,
455
+ min_size= 200,
456
+ k_documents= k_before_reranking,
457
+ k_images= k_by_question
458
+ )
459
+ else :
460
+ docs_question_dict = await get_POC_relevant_documents(
461
+ query = question,
462
+ vectorstore=vectorstore,
463
+ search_figures = search_figures,
464
+ sources = sources,
465
+ threshold = 0.5,
466
+ search_only = search_only,
467
+ reports = reports,
468
+ min_size= 200,
469
+ k_documents= k_before_reranking,
470
+ k_images= k_by_question
471
+ )
472
 
473
  # Rerank
474
  if reranker is not None and rerank_by_question:
 
494
  return docs_question, images_question
495
 
496
 
497
+ async def retrieve_documents_for_all_questions(
498
+ search_figures,
499
+ search_only,
500
+ reports,
501
+ questions_list,
502
+ n_questions,
503
+ config,
504
+ source_type,
505
+ to_handle_questions_index,
506
+ vectorstore,
507
+ reranker,
508
+ rerank_by_question=True,
509
+ k_final=15,
510
+ k_before_reranking=100,
511
+ version: str = "",
512
+ tocs: list[dict] = [],
513
+ by_toc: bool = False
514
+ ):
515
  """
516
  Retrieve documents in parallel for all questions.
517
  """
518
  # to_handle_questions_index = [x for x in state["questions_list"] if x["source_type"] == "IPx"]
519
 
520
  # TODO split les questions selon le type de sources dans le state question + conditions sur le nombre de questions traités par type de source
521
+ # search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
522
+ # search_only = state["search_only"]
523
+ # reports = state["reports"]
524
+ # questions_list = state["questions_list"]
525
+
526
+ # k_by_question = k_final // state["n_questions"]["total"]
527
+ # k_summary_by_question = _get_k_summary_by_question(state["n_questions"]["total"])
528
+ # k_images_by_question = _get_k_images_by_question(state["n_questions"]["total"])
529
+ k_by_question = k_final // n_questions
530
+ k_summary_by_question = _get_k_summary_by_question(n_questions)
531
+ k_images_by_question = _get_k_images_by_question(n_questions)
532
  k_before_reranking=100
533
 
534
+ print(f"Source type here is {source_type}")
535
  tasks = [
536
  retrieve_documents(
537
  current_question=question,
 
546
  k_images_by_question=k_images_by_question,
547
  k_before_reranking=k_before_reranking,
548
  k_by_question=k_by_question,
549
+ k_summary_by_question=k_summary_by_question,
550
+ tocs=tocs,
551
+ version=version,
552
+ by_toc=by_toc
553
  )
554
+ for i, question in enumerate(questions_list) if i in to_handle_questions_index
555
  ]
556
  results = await asyncio.gather(*tasks)
557
  # Combine results
 
561
  new_state["related_contents"].extend(images_question)
562
  return new_state
563
 
564
# ToC Retriever
async def get_relevant_toc_level_for_query(
    query: str,
    tocs: list[Document],
    ) -> list[dict] :
    """Ask an LLM which level-0 chapters of the given ToCs are relevant to the query.

    Args:
        query: the user question.
        tocs: ToC chunks as (Document, score) tuples; each Document's
            page_content is a table of contents and its metadata carries
            the source document's ``name``.

    Returns:
        List of dicts with keys "document" and "chapter"; empty list when
        the LLM output cannot be parsed (previously this path raised a
        NameError because ``relevant_tocs`` was never bound).
    """
    import ast

    doc_list = []
    for doc in tocs:
        doc_name = doc[0].metadata['name']
        toc = doc[0].page_content
        doc_list.append({'document': doc_name, 'toc': toc})

    llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)

    prompt = ChatPromptTemplate.from_template(retrieve_chapter_prompt_template)
    chain = prompt | llm | StrOutputParser()
    response = chain.invoke({"query": query, "doc_list": doc_list})

    # Default so a parse failure returns an empty result instead of crashing.
    relevant_tocs: list[dict] = []
    try:
        # ast.literal_eval instead of eval: parses the same literal list/dict
        # structures (including trailing commas, as in the prompt's example)
        # without executing arbitrary code from the LLM response.
        relevant_tocs = ast.literal_eval(response)
    except Exception as e:
        print(f" Failed to parse the result because of : {e}")

    return relevant_tocs
588
+
589
+
590
  def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
591
 
592
  async def retrieve_IPx_docs(state, config):
593
  source_type = "IPx"
594
  IPx_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
595
 
596
+ search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
597
+ search_only = state["search_only"]
598
+ reports = state["reports"]
599
+ questions_list = state["questions_list"]
600
+ n_questions=state["n_questions"]["total"]
601
+
602
  state = await retrieve_documents_for_all_questions(
603
+ search_figures=search_figures,
604
+ search_only=search_only,
605
+ reports=reports,
606
+ questions_list=questions_list,
607
+ n_questions=n_questions,
608
  config=config,
609
  source_type=source_type,
610
  to_handle_questions_index=IPx_questions_index,
 
628
  source_type = "POC"
629
  POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
630
 
631
+ search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
632
+ search_only = state["search_only"]
633
+ reports = state["reports"]
634
+ questions_list = state["questions_list"]
635
+ n_questions=state["n_questions"]["total"]
636
+
637
  state = await retrieve_documents_for_all_questions(
638
+ search_figures=search_figures,
639
+ search_only=search_only,
640
+ reports=reports,
641
+ questions_list=questions_list,
642
+ n_questions=n_questions,
643
  config=config,
644
  source_type=source_type,
645
  to_handle_questions_index=POC_questions_index,
 
654
  return retrieve_POC_docs_node
655
 
656
 
657
def make_POC_by_ToC_retriever_node(
    vectorstore: VectorStore,
    reranker,
    llm,
    version: str = "",
    rerank_by_question=True,
    k_final=15,
    k_before_reranking=100,
    k_summary=5,
    ):
    """Build a graph node retrieving POC documents guided by tables of contents.

    Args:
        vectorstore: vector store holding the parsed POC chunks.
        reranker: reranker passed through to document retrieval.
        llm: kept for signature parity with the other retriever factories
            (not used directly here).
        version: version tag of the parsed documents (e.g. "v4").
        rerank_by_question / k_final / k_before_reranking: forwarded to
            ``retrieve_documents_for_all_questions``.
        k_summary: kept for signature parity (not used directly here).

    Returns:
        An async node function ``retrieve_POC_docs_node(state, config)``.
    """

    async def retrieve_POC_docs_node(state, config):
        # Skip entirely unless the user selected the POC regional sources.
        if "POC region" not in state["relevant_content_sources_selection"] :
            return {}

        search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
        # Fix: the original assigned search_only twice; one assignment suffices.
        search_only = state["search_only"]
        reports = state["reports"]
        questions_list = state["questions_list"]
        n_questions = state["n_questions"]["total"]

        # Fetch the ToC chunks once per node invocation; they are shared
        # across all questions handled below.
        tocs = get_ToCs(version=version)

        source_type = "POC"
        POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]

        state = await retrieve_documents_for_all_questions(
            search_figures=search_figures,
            search_only=search_only,
            config=config,
            reports=reports,
            questions_list=questions_list,
            n_questions=n_questions,
            source_type=source_type,
            to_handle_questions_index=POC_questions_index,
            vectorstore=vectorstore,
            reranker=reranker,
            rerank_by_question=rerank_by_question,
            k_final=k_final,
            k_before_reranking=k_before_reranking,
            tocs=tocs,
            version=version,
            by_toc=True
        )
        return state

    return retrieve_POC_docs_node
705
+
706
+
707
+
708
+
709
 
climateqa/engine/graph.py CHANGED
@@ -11,7 +11,7 @@ from typing import List, Dict
11
 
12
  import operator
13
  from typing import Annotated
14
-
15
  from IPython.display import display, HTML, Image
16
 
17
  from .chains.answer_chitchat import make_chitchat_node
@@ -19,7 +19,7 @@ from .chains.answer_ai_impact import make_ai_impact_node
19
  from .chains.query_transformation import make_query_transform_node
20
  from .chains.translation import make_translation_node
21
  from .chains.intent_categorization import make_intent_categorization_node
22
- from .chains.retrieve_documents import make_IPx_retriever_node, make_POC_retriever_node
23
  from .chains.answer_rag import make_rag_node
24
  from .chains.graph_retriever import make_graph_retriever_node
25
  from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
@@ -39,14 +39,14 @@ class GraphState(TypedDict):
39
  n_questions : int
40
  answer: str
41
  audience: str = "experts"
42
- sources_input: List[str] = ["IPCC","IPBES"]
43
  relevant_content_sources_selection: List[str] = ["Figures (IPCC/IPBES)"]
44
  sources_auto: bool = True
45
  min_year: int = 1960
46
  max_year: int = None
47
  documents: Annotated[List[Document], operator.add]
48
- related_contents : Annotated[List[Document], operator.add]
49
- recommended_content : List[Document]
50
  search_only : bool = False
51
  reports : List[str] = []
52
 
@@ -72,7 +72,7 @@ def route_intent(state):
72
  def chitchat_route_intent(state):
73
  intent = state["search_graphs_chitchat"]
74
  if intent is True:
75
- return "retrieve_graphs_chitchat"
76
  elif intent is False:
77
  return END
78
 
@@ -95,20 +95,10 @@ def route_based_on_relevant_docs(state,threshold_docs=0.2):
95
  def route_continue_retrieve_documents(state):
96
  index_question_ipx = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
97
  questions_ipx_finished = all(elem in state["handled_questions_index"] for elem in index_question_ipx)
98
- # if questions_ipx_finished and state["search_only"]:
99
- # return END
100
  if questions_ipx_finished:
101
  return "end_retrieve_IPx_documents"
102
  else:
103
  return "retrieve_documents"
104
-
105
-
106
- # if state["n_questions"]["IPx"] == len(state["handled_questions_index"]) and state["search_only"] :
107
- # return END
108
- # elif state["n_questions"]["IPx"] == len(state["handled_questions_index"]):
109
- # return "answer_search"
110
- # else :
111
- # return "retrieve_documents"
112
 
113
  def route_continue_retrieve_local_documents(state):
114
  index_question_poc = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
@@ -120,20 +110,6 @@ def route_continue_retrieve_local_documents(state):
120
  else:
121
  return "retrieve_local_data"
122
 
123
- # if state["n_questions"]["POC"] == len(state["handled_questions_index"]) and state["search_only"] :
124
- # return END
125
- # elif state["n_questions"]["POC"] == len(state["handled_questions_index"]):
126
- # return "answer_search"
127
- # else :
128
- # return "retrieve_local_data"
129
-
130
- # if len(state["remaining_questions"]) == 0 and state["search_only"] :
131
- # return END
132
- # elif len(state["remaining_questions"]) > 0:
133
- # return "retrieve_documents"
134
- # else:
135
- # return "answer_search"
136
-
137
  def route_retrieve_documents(state):
138
  sources_to_retrieve = []
139
 
@@ -232,8 +208,23 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
232
  app = workflow.compile()
233
  return app
234
 
235
- def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, threshold_docs=0.2):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
 
237
  workflow = StateGraph(GraphState)
238
 
239
  # Define the node functions
@@ -244,7 +235,8 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
244
  answer_ai_impact = make_ai_impact_node(llm)
245
  retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
246
  retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
247
- retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
 
248
  answer_rag = make_rag_node(llm, with_docs=True)
249
  answer_rag_no_docs = make_rag_node(llm, with_docs=False)
250
  chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
@@ -315,6 +307,10 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
315
  workflow.add_edge("retrieve_local_data", "answer_search")
316
  workflow.add_edge("retrieve_documents", "answer_search")
317
 
 
 
 
 
318
  # Compile
319
  app = workflow.compile()
320
  return app
 
11
 
12
  import operator
13
  from typing import Annotated
14
+ import pandas as pd
15
  from IPython.display import display, HTML, Image
16
 
17
  from .chains.answer_chitchat import make_chitchat_node
 
19
  from .chains.query_transformation import make_query_transform_node
20
  from .chains.translation import make_translation_node
21
  from .chains.intent_categorization import make_intent_categorization_node
22
+ from .chains.retrieve_documents import make_IPx_retriever_node, make_POC_retriever_node, make_POC_by_ToC_retriever_node
23
  from .chains.answer_rag import make_rag_node
24
  from .chains.graph_retriever import make_graph_retriever_node
25
  from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
 
39
  n_questions : int
40
  answer: str
41
  audience: str = "experts"
42
+ sources_input: List[str] = ["IPCC","IPBES"] # Deprecated -> used only graphs that can only be OWID
43
  relevant_content_sources_selection: List[str] = ["Figures (IPCC/IPBES)"]
44
  sources_auto: bool = True
45
  min_year: int = 1960
46
  max_year: int = None
47
  documents: Annotated[List[Document], operator.add]
48
+ related_contents : Annotated[List[Document], operator.add] # Images
49
+ recommended_content : List[Document] # OWID Graphs # TODO merge with related_contents
50
  search_only : bool = False
51
  reports : List[str] = []
52
 
 
72
  def chitchat_route_intent(state):
73
  intent = state["search_graphs_chitchat"]
74
  if intent is True:
75
+ return END #TODO
76
  elif intent is False:
77
  return END
78
 
 
95
  def route_continue_retrieve_documents(state):
96
  index_question_ipx = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
97
  questions_ipx_finished = all(elem in state["handled_questions_index"] for elem in index_question_ipx)
 
 
98
  if questions_ipx_finished:
99
  return "end_retrieve_IPx_documents"
100
  else:
101
  return "retrieve_documents"
 
 
 
 
 
 
 
 
102
 
103
  def route_continue_retrieve_local_documents(state):
104
  index_question_poc = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
 
110
  else:
111
  return "retrieve_local_data"
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  def route_retrieve_documents(state):
114
  sources_to_retrieve = []
115
 
 
208
  app = workflow.compile()
209
  return app
210
 
211
+ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, version:str, threshold_docs=0.2):
212
+ """_summary_
213
+
214
+ Args:
215
+ llm (_type_): _description_
216
+ vectorstore_ipcc (_type_): _description_
217
+ vectorstore_graphs (_type_): _description_
218
+ vectorstore_region (_type_): _description_
219
+ reranker (_type_): _description_
220
+ version (str): version of the parsed documents (e.g "v4")
221
+ threshold_docs (float, optional): _description_. Defaults to 0.2.
222
+
223
+ Returns:
224
+ _type_: _description_
225
+ """
226
 
227
+
228
  workflow = StateGraph(GraphState)
229
 
230
  # Define the node functions
 
235
  answer_ai_impact = make_ai_impact_node(llm)
236
  retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
237
  retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
238
+ # retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
239
+ retrieve_local_data = make_POC_by_ToC_retriever_node(vectorstore_region, reranker, llm, version=version)
240
  answer_rag = make_rag_node(llm, with_docs=True)
241
  answer_rag_no_docs = make_rag_node(llm, with_docs=False)
242
  chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
 
307
  workflow.add_edge("retrieve_local_data", "answer_search")
308
  workflow.add_edge("retrieve_documents", "answer_search")
309
 
310
+ # workflow.add_edge("transform_query", "retrieve_drias_data")
311
+ # workflow.add_edge("retrieve_drias_data", END)
312
+
313
+
314
  # Compile
315
  app = workflow.compile()
316
  return app
climateqa/engine/talk_to_data/main.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from climateqa.engine.talk_to_data.myVanna import MyVanna
2
+ from climateqa.engine.talk_to_data.utils import loc2coords, detect_location_with_openai, detectTable, nearestNeighbourSQL, detect_relevant_tables, replace_coordonates
3
+ import sqlite3
4
+ import os
5
+ import pandas as pd
6
+ from climateqa.engine.llm import get_llm
7
+
8
+ from dotenv import load_dotenv
9
+ import ast
10
+
11
+ load_dotenv()
12
+
13
+
14
+ OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
15
+ PC_API_KEY = os.getenv('VANNA_PINECONE_API_KEY')
16
+ INDEX_NAME = os.getenv('VANNA_INDEX_NAME')
17
+ VANNA_MODEL = os.getenv('VANNA_MODEL')
18
+
19
+
20
+ #Vanna object
21
+ vn = MyVanna(config = {"temperature": 0, "api_key": OPENAI_API_KEY, 'model': VANNA_MODEL, 'pc_api_key': PC_API_KEY, 'index_name': INDEX_NAME, "top_k" : 4})
22
+ db_vanna_path = os.path.join(os.path.dirname(__file__), "database/drias.db")
23
+ vn.connect_to_sqlite(db_vanna_path)
24
+
25
+ llm = get_llm(provider="openai")
26
+
27
+ def ask_llm_to_add_table_names(sql_query, llm):
28
+ sql_with_table_names = llm.invoke(f"Make the following sql query display the source table in the rows {sql_query}. Just answer the query. The answer should not include ```sql\n").content
29
+ return sql_with_table_names
30
+
31
+ def ask_llm_column_names(sql_query, llm):
32
+ columns = llm.invoke(f"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query : {sql_query}").content
33
+ columns_list = ast.literal_eval(columns.strip("```python\n").strip())
34
+ return columns_list
35
+
36
+ def ask_vanna(query):
37
+ try :
38
+ location = detect_location_with_openai(OPENAI_API_KEY, query)
39
+ if location:
40
+
41
+ coords = loc2coords(location)
42
+ user_input = query.lower().replace(location.lower(), f"lat, long : {coords}")
43
+
44
+ relevant_tables = detect_relevant_tables(user_input, llm)
45
+ coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, relevant_tables[i]) for i in range(len(relevant_tables))]
46
+ user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)
47
+
48
+ sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)
49
+
50
+ return sql_query, result_dataframe, figure
51
+
52
+ else :
53
+ empty_df = pd.DataFrame()
54
+ empty_fig = {}
55
+ return "", empty_df, empty_fig
56
+ except Exception as e:
57
+ print(f"Error: {e}")
58
+ empty_df = pd.DataFrame()
59
+ empty_fig = {}
60
+ return "", empty_df, empty_fig
front/tabs/chat_interface.py CHANGED
@@ -20,12 +20,31 @@ Please note that we log your questions for meta-analysis purposes, so avoid shar
20
  What do you want to learn ?
21
  """
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  # UI Layout Components
26
- def create_chat_interface():
 
27
  chatbot = gr.Chatbot(
28
- value=[ChatMessage(role="assistant", content=init_prompt)],
29
  type="messages",
30
  show_copy_button=True,
31
  show_label=False,
 
20
  What do you want to learn ?
21
  """
22
 
23
+ init_prompt_poc = """
24
+ Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports, PCAET of Paris, the Plan Biodiversité 2018-2024, and Acclimaterra reports from la Région Nouvelle-Aquitaine **.
25
+
26
+ ❓ How to use
27
+ - **Language**: You can ask me your questions in any language.
28
+ - **Audience**: You can specify your audience (children, general public, experts) to get a more adapted answer.
29
+ - **Sources**: You can choose to search in the IPCC or IPBES reports, and POC sources for local documents (PCAET, Plan Biodiversité, Acclimaterra).
30
+ - **Relevant content sources**: You can choose to search for figures, papers, or graphs that can be relevant for your question.
31
+
32
+ ⚠️ Limitations
33
+ *Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*
34
+
35
+ 🛈 Information
36
+ Please note that we log your questions for meta-analysis purposes, so avoid sharing any sensitive or personal information.
37
+
38
+ What do you want to learn ?
39
+ """
40
+
41
 
42
 
43
  # UI Layout Components
44
+ def create_chat_interface(tab):
45
+ init_prompt_message = init_prompt_poc if tab == "Beta - POC Adapt'Action" else init_prompt
46
  chatbot = gr.Chatbot(
47
+ value=[ChatMessage(role="assistant", content=init_prompt_message)],
48
  type="messages",
49
  show_copy_button=True,
50
  show_label=False,
front/tabs/main_tab.py CHANGED
@@ -3,7 +3,6 @@ from .chat_interface import create_chat_interface
3
  from .tab_examples import create_examples_tab
4
  from .tab_papers import create_papers_tab
5
  from .tab_figures import create_figures_tab
6
- from .chat_interface import create_chat_interface
7
 
8
  def cqa_tab(tab_name):
9
  # State variables
@@ -12,7 +11,7 @@ def cqa_tab(tab_name):
12
  with gr.Row(elem_id="chatbot-row"):
13
  # Left column - Chat interface
14
  with gr.Column(scale=2):
15
- chatbot, textbox, config_button = create_chat_interface()
16
 
17
  # Right column - Content panels
18
  with gr.Column(scale=2, variant="panel", elem_id="right-panel"):
 
3
  from .tab_examples import create_examples_tab
4
  from .tab_papers import create_papers_tab
5
  from .tab_figures import create_figures_tab
 
6
 
7
  def cqa_tab(tab_name):
8
  # State variables
 
11
  with gr.Row(elem_id="chatbot-row"):
12
  # Left column - Chat interface
13
  with gr.Column(scale=2):
14
+ chatbot, textbox, config_button = create_chat_interface(tab_name)
15
 
16
  # Right column - Content panels
17
  with gr.Column(scale=2, variant="panel", elem_id="right-panel"):