Merge remote-tracking branch 'origin/feature/improve_parsing_and_retrieval' into pr/20
Browse files- .gitignore +4 -2
- app.py +2 -3
- climateqa/engine/chains/prompts.py +55 -4
- climateqa/engine/chains/query_transformation.py +2 -0
- climateqa/engine/chains/retrieve_documents.py +274 -30
- climateqa/engine/graph.py +28 -32
- climateqa/engine/talk_to_data/main.py +60 -0
- front/tabs/chat_interface.py +21 -2
- front/tabs/main_tab.py +1 -2
.gitignore
CHANGED
@@ -12,7 +12,9 @@ notebooks/
|
|
12 |
data/
|
13 |
sandbox/
|
14 |
|
|
|
15 |
*.db
|
16 |
-
|
|
|
|
|
17 |
*old/
|
18 |
-
data_ingestion/
|
|
|
12 |
data/
|
13 |
sandbox/
|
14 |
|
15 |
+
climateqa/talk_to_data/database/
|
16 |
*.db
|
17 |
+
|
18 |
+
data_ingestion/
|
19 |
+
.vscode
|
20 |
*old/
|
|
app.py
CHANGED
@@ -64,7 +64,7 @@ user_id = create_user_id()
|
|
64 |
embeddings_function = get_embeddings_function()
|
65 |
vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX"))
|
66 |
vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
|
67 |
-
vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("
|
68 |
|
69 |
llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
|
70 |
if os.getenv("ENV")=="GRADIO_ENV":
|
@@ -73,7 +73,7 @@ else:
|
|
73 |
reranker = get_reranker("large")
|
74 |
|
75 |
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
|
76 |
-
agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0)#TODO put back default 0.2
|
77 |
|
78 |
|
79 |
async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
|
@@ -268,7 +268,6 @@ def event_handling(
|
|
268 |
for component in [textbox, examples_hidden]:
|
269 |
component.submit(find_papers, [component, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
|
270 |
|
271 |
-
|
272 |
|
273 |
|
274 |
def main_ui():
|
|
|
64 |
embeddings_function = get_embeddings_function()
|
65 |
vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX"))
|
66 |
vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
|
67 |
+
vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2"))
|
68 |
|
69 |
llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
|
70 |
if os.getenv("ENV")=="GRADIO_ENV":
|
|
|
73 |
reranker = get_reranker("large")
|
74 |
|
75 |
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
|
76 |
+
agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0, version="v4")#TODO put back default 0.2
|
77 |
|
78 |
|
79 |
async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
|
|
|
268 |
for component in [textbox, examples_hidden]:
|
269 |
component.submit(find_papers, [component, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
|
270 |
|
|
|
271 |
|
272 |
|
273 |
def main_ui():
|
climateqa/engine/chains/prompts.py
CHANGED
@@ -66,10 +66,11 @@ You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a quest
|
|
66 |
Guidelines:
|
67 |
- If the passages have useful facts or numbers, use them in your answer.
|
68 |
- When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
|
69 |
-
- You will receive passages from different reports,
|
70 |
-
- The different sources are IPCC, IPBES, PPCP (for Plan Climat Air Energie Territorial de Paris), PBDP (for Plan Biodiversité de Paris), Acclimaterra.
|
|
|
71 |
- Do not mention that you are using specific extract documents, but mention only the source information. "According to IPCC, ..." rather than "According to the provided document from IPCC ..."
|
72 |
-
- Make a clear distinction between information from IPCC, IPBES, Acclimaterra that are scientific reports and PPCP, PBDP that are strategic reports. Strategic reports should not be taken
|
73 |
- If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
|
74 |
- Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
|
75 |
- If it makes sense, use bullet points and lists to make your answers easier to understand.
|
@@ -197,4 +198,54 @@ Graphs and their HTML embedding:
|
|
197 |
{format_instructions}
|
198 |
|
199 |
Output the result as json with a key "graphs" containing a list of dictionaries of the relevant graphs with keys 'embedding', 'category', and 'source'. Do not modify the graph HTML embedding, the category or the source. Do not put any message or text before or after the JSON output.
|
200 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
Guidelines:
|
67 |
- If the passages have useful facts or numbers, use them in your answer.
|
68 |
- When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
|
69 |
+
- You will receive passages from different reports, e.g., IPCC and PPCP. Make separate paragraphs and specify the source of the information in your answer, e.g., "According to IPCC, ...".
|
70 |
+
- The different sources are IPCC, IPBES, PPCP (for Plan Climat Air Energie Territorial de Paris), PBDP (for Plan Biodiversité de Paris), Acclimaterra (Rapport scientifique de la région Nouvelle Aquitaine en France).
|
71 |
+
- If the reports are local (like PPCP, PBDP, Acclimaterra), consider that the information is specific to the region and not global. If the document is about a nearby region (for example, an extract from Acclimaterra for a question about Britain), explicitly state the concerned region.
|
72 |
- Do not mention that you are using specific extract documents, but mention only the source information. "According to IPCC, ..." rather than "According to the provided document from IPCC ..."
|
73 |
+
- Make a clear distinction between information from IPCC, IPBES, Acclimaterra that are scientific reports and PPCP, PBDP that are strategic reports. Strategic reports should not be taken as verified facts, but as political or strategic decisions.
|
74 |
- If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
|
75 |
- Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
|
76 |
- If it makes sense, use bullet points and lists to make your answers easier to understand.
|
|
|
198 |
{format_instructions}
|
199 |
|
200 |
Output the result as json with a key "graphs" containing a list of dictionaries of the relevant graphs with keys 'embedding', 'category', and 'source'. Do not modify the graph HTML embedding, the category or the source. Do not put any message or text before or after the JSON output.
|
201 |
+
"""
|
202 |
+
|
203 |
+
# Prompt used to ask the LLM which level-0 chapters are relevant to a question.
# It is a format template: `{query}` and `{doc_list}` are the two placeholders,
# and every literal brace is doubled (`{{` / `}}`) so it survives formatting.
# The example response at the end is kept strictly valid JSON (no trailing
# commas) because the guidelines above it require a valid JSON answer.
retrieve_chapter_prompt_template = """Given the user question and a list of documents with their table of contents, retrieve the 5 most relevant level 0 chapters which could help to answer to the question while taking account their sub-chapters.

The table of contents is structured like that :
{{
"level": 0,
"Chapter 1": {{}},
"Chapter 2" : {{
"level": 1,
"Chapter 2.1": {{
...
}}
}},
}}

Here level is the level of the chapter. For example, Chapter 1 and Chapter 2 are at level 0, and Chapter 2.1 is at level 1.

### Guidelines ###
- Keep all the list of documents that is given to you
- Each chapter must keep **EXACTLY** its assigned level in the table of contents. **DO NOT MODIFY THE LEVELS. **
- Check systematically the level of a chapter before including it in the answer.
- Return **valid JSON** result.

--------------------
User question :
{query}

List of documents with their table of contents :
{doc_list}

--------------------

Return a JSON result with a list of relevant chapters with the following keys **WITHOUT** the json markdown indicator ```json at the beginning:
- "document" : the document in which we can find the chapter
- "chapter" : the title of the chapter

**IMPORTANT : Make sure that the levels of the answer are exactly the same as the ones in the table of contents**

Example of a JSON response:
[
{{
"document": "Document A",
"chapter": "Chapter 1"
}},
{{
"document": "Document B",
"chapter": "Chapter 5"
}}
]
"""
|
climateqa/engine/chains/query_transformation.py
CHANGED
@@ -293,6 +293,8 @@ def make_query_transform_node(llm,k_final=15):
|
|
293 |
"n_questions":n_questions,
|
294 |
"handled_questions_index":[],
|
295 |
}
|
|
|
|
|
296 |
return new_state
|
297 |
|
298 |
return transform_query
|
|
|
293 |
"n_questions":n_questions,
|
294 |
"handled_questions_index":[],
|
295 |
}
|
296 |
+
print("New questions")
|
297 |
+
print(new_questions)
|
298 |
return new_state
|
299 |
|
300 |
return transform_query
|
climateqa/engine/chains/retrieve_documents.py
CHANGED
@@ -15,6 +15,14 @@ from ..utils import log_event
|
|
15 |
from langchain_core.vectorstores import VectorStore
|
16 |
from typing import List
|
17 |
from langchain_core.documents.base import Document
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
import asyncio
|
19 |
|
20 |
from typing import Any, Dict, List, Tuple
|
@@ -119,6 +127,21 @@ def remove_duplicates_chunks(docs):
|
|
119 |
result.append(doc)
|
120 |
return result
|
121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
async def get_POC_relevant_documents(
|
123 |
query: str,
|
124 |
vectorstore:VectorStore,
|
@@ -169,6 +192,86 @@ async def get_POC_relevant_documents(
|
|
169 |
"docs_question" : docs_question,
|
170 |
"docs_images" : docs_images
|
171 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
|
174 |
async def get_IPCC_relevant_documents(
|
@@ -271,6 +374,7 @@ def concatenate_documents(index, source_type, docs_question_dict, k_by_question,
|
|
271 |
return docs_question, images_question
|
272 |
|
273 |
|
|
|
274 |
# The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
|
275 |
# @chain
|
276 |
async def retrieve_documents(
|
@@ -279,6 +383,7 @@ async def retrieve_documents(
|
|
279 |
source_type: str,
|
280 |
vectorstore: VectorStore,
|
281 |
reranker: Any,
|
|
|
282 |
search_figures: bool = False,
|
283 |
search_only: bool = False,
|
284 |
reports: list = [],
|
@@ -286,7 +391,9 @@ async def retrieve_documents(
|
|
286 |
k_images_by_question: int = 5,
|
287 |
k_before_reranking: int = 100,
|
288 |
k_by_question: int = 5,
|
289 |
-
k_summary_by_question: int = 3
|
|
|
|
|
290 |
) -> Tuple[List[Document], List[Document]]:
|
291 |
"""
|
292 |
Unpack the first question of the remaining questions, and retrieve and rerank corresponding documents, based on the question and selected_sources
|
@@ -316,6 +423,7 @@ async def retrieve_documents(
|
|
316 |
|
317 |
print(f"""---- Retrieve documents from {current_question["source_type"]}----""")
|
318 |
|
|
|
319 |
if source_type == "IPx":
|
320 |
docs_question_dict = await get_IPCC_relevant_documents(
|
321 |
query = question,
|
@@ -331,19 +439,36 @@ async def retrieve_documents(
|
|
331 |
reports = reports,
|
332 |
)
|
333 |
|
334 |
-
if source_type ==
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
347 |
|
348 |
# Rerank
|
349 |
if reranker is not None and rerank_by_question:
|
@@ -369,24 +494,44 @@ async def retrieve_documents(
|
|
369 |
return docs_question, images_question
|
370 |
|
371 |
|
372 |
-
async def retrieve_documents_for_all_questions(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
"""
|
374 |
Retrieve documents in parallel for all questions.
|
375 |
"""
|
376 |
# to_handle_questions_index = [x for x in state["questions_list"] if x["source_type"] == "IPx"]
|
377 |
|
378 |
# TODO split les questions selon le type de sources dans le state question + conditions sur le nombre de questions traités par type de source
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
|
|
|
|
388 |
k_before_reranking=100
|
389 |
|
|
|
390 |
tasks = [
|
391 |
retrieve_documents(
|
392 |
current_question=question,
|
@@ -401,9 +546,12 @@ async def retrieve_documents_for_all_questions(state, config, source_type, to_ha
|
|
401 |
k_images_by_question=k_images_by_question,
|
402 |
k_before_reranking=k_before_reranking,
|
403 |
k_by_question=k_by_question,
|
404 |
-
k_summary_by_question=k_summary_by_question
|
|
|
|
|
|
|
405 |
)
|
406 |
-
for i, question in enumerate(
|
407 |
]
|
408 |
results = await asyncio.gather(*tasks)
|
409 |
# Combine results
|
@@ -413,16 +561,50 @@ async def retrieve_documents_for_all_questions(state, config, source_type, to_ha
|
|
413 |
new_state["related_contents"].extend(images_question)
|
414 |
return new_state
|
415 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
416 |
def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
|
417 |
|
418 |
async def retrieve_IPx_docs(state, config):
|
419 |
source_type = "IPx"
|
420 |
IPx_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
|
421 |
|
422 |
-
|
423 |
-
|
|
|
|
|
|
|
|
|
424 |
state = await retrieve_documents_for_all_questions(
|
425 |
-
|
|
|
|
|
|
|
|
|
426 |
config=config,
|
427 |
source_type=source_type,
|
428 |
to_handle_questions_index=IPx_questions_index,
|
@@ -446,8 +628,18 @@ def make_POC_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_
|
|
446 |
source_type = "POC"
|
447 |
POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
|
448 |
|
|
|
|
|
|
|
|
|
|
|
|
|
449 |
state = await retrieve_documents_for_all_questions(
|
450 |
-
|
|
|
|
|
|
|
|
|
451 |
config=config,
|
452 |
source_type=source_type,
|
453 |
to_handle_questions_index=POC_questions_index,
|
@@ -462,4 +654,56 @@ def make_POC_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_
|
|
462 |
return retrieve_POC_docs_node
|
463 |
|
464 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
465 |
|
|
|
15 |
from langchain_core.vectorstores import VectorStore
|
16 |
from typing import List
|
17 |
from langchain_core.documents.base import Document
|
18 |
+
from ..llm import get_llm
|
19 |
+
from .prompts import retrieve_chapter_prompt_template
|
20 |
+
from langchain_core.prompts import ChatPromptTemplate
|
21 |
+
from langchain_core.output_parsers import StrOutputParser
|
22 |
+
from ..vectorstore import get_pinecone_vectorstore
|
23 |
+
from ..embeddings import get_embeddings_function
|
24 |
+
|
25 |
+
|
26 |
import asyncio
|
27 |
|
28 |
from typing import Any, Dict, List, Tuple
|
|
|
127 |
result.append(doc)
|
128 |
return result
|
129 |
|
130 |
+
def get_ToCs(version: str, index_name: str = "climateqa-v2"):
    """Fetch the table-of-contents chunks stored for one parsing version.

    Args:
        version: value matched against the ``version`` metadata field of the
            chunks (e.g. "v4").
        index_name: name of the Pinecone index to query. Defaults to the
            index that used to be hard-coded in this function.

    Returns:
        The de-duplicated result of ``similarity_search_with_score``. Other
        code in this file reads each entry as ``entry[0]`` to reach the
        document, so entries are presumably (Document, score) pairs.
    """
    filters_text = {
        "chunk_type": "toc",
        "version": version,
    }
    embeddings_function = get_embeddings_function()
    vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=index_name)

    # The query text is empty on purpose: the metadata filter alone selects the ToCs.
    # NOTE(review): no `k` is passed, so the vector store's default result count
    # applies -- confirm it is large enough to return every ToC.
    tocs = vectorstore.similarity_search_with_score(query="", filter=filters_text)

    # remove duplicates or almost duplicates
    return remove_duplicates_chunks(tocs)
|
144 |
+
|
145 |
async def get_POC_relevant_documents(
|
146 |
query: str,
|
147 |
vectorstore:VectorStore,
|
|
|
192 |
"docs_question" : docs_question,
|
193 |
"docs_images" : docs_images
|
194 |
}
|
195 |
+
|
196 |
+
async def get_POC_documents_by_ToC_relevant_documents(
    query: str,
    tocs: list,
    vectorstore: VectorStore,
    version: str,
    sources: list = ["Acclimaterra", "PCAET", "Plan Biodiversite"],
    search_figures: bool = False,
    search_only: bool = False,
    k_documents: int = 10,
    threshold: float = 0.6,
    k_images: int = 5,
    reports: list = [],
    min_size: int = 200,
    proportion: float = 0.5,
):
    """Retrieve local (POC) documents, partly guided by the tables of contents.

    The LLM first picks the chapters relevant to ``query``. A share of the
    document budget is then searched inside those chapters only, and the rest
    is searched without any chapter restriction.

    Args:
        query: the question to search for.
        tocs: list with the table of contents of each document.
        vectorstore: store queried with ``similarity_search_with_score``.
        version: version of the parsed documents (e.g. "v4").
        sources: currently unused (source selection is still a TODO).
        search_figures: when True, image chunks are searched as well.
        search_only: currently unused by this function.
        k_documents: total number of text chunks requested from the store.
        threshold: results whose score is not strictly above it are dropped.
        k_images: number of image chunks requested when ``search_figures``.
        reports: currently unused (source selection is still a TODO).
        min_size: text chunks whose content is not longer than this are dropped.
        proportion: share of documents retrieved using ToCs.

    Returns:
        A dict with the keys ``"docs_question"`` and ``"docs_images"``.
    """
    # Prepare base search kwargs
    filters = {}
    docs_question = []
    docs_images = []

    # TODO add source selection: filter on "short_name" when `reports` is
    # given, otherwise filter on "source" with `sources`.

    k_documents_toc = round(k_documents * proportion)
    k_documents_generic = k_documents - k_documents_toc

    relevant_tocs = await get_relevant_toc_level_for_query(query, tocs)

    print(f"Relevant ToCs : {relevant_tocs}")
    # Transform the ToC dict {"document": str, "chapter": str} into a list of string.
    # The list comes from an LLM, so entries that are not dicts or that have
    # no chapter title are skipped instead of raising.
    toc_filters = [
        toc["chapter"]
        for toc in (relevant_tocs or [])
        if isinstance(toc, dict) and toc.get("chapter")
    ]

    # Only query inside chapters when some were selected and the budget is
    # positive: an empty "$in" list or k == 0 cannot return anything useful.
    if toc_filters and k_documents_toc > 0:
        filters_text_toc = {
            **filters,
            "chunk_type": "text",
            "toc_level0": {"$in": toc_filters},
            "version": version,
            # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
        }
        docs_question = vectorstore.similarity_search_with_score(
            query=query, filter=filters_text_toc, k=k_documents_toc
        )

    if k_documents_generic > 0:
        filters_text = {
            **filters,
            "chunk_type": "text",
            "version": version,
            # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
        }
        docs_question += vectorstore.similarity_search_with_score(
            query=query, filter=filters_text, k=k_documents_generic
        )

    # remove duplicates or almost duplicates (both searches can return the same chunk)
    docs_question = remove_duplicates_chunks(docs_question)
    docs_question = [x for x in docs_question if x[1] > threshold]

    if search_figures:
        # Images
        # NOTE(review): unlike the text filters, this one has no "version"
        # constraint -- confirm that is intended.
        filters_image = {
            **filters,
            "chunk_type": "image",
        }
        docs_images = vectorstore.similarity_search_with_score(
            query=query, filter=filters_image, k=k_images
        )

    docs_question = _add_metadata_and_score(docs_question)
    docs_images = _add_metadata_and_score(docs_images)

    docs_question = [x for x in docs_question if len(x.page_content) > min_size]

    return {
        "docs_question": docs_question,
        "docs_images": docs_images,
    }
|
275 |
|
276 |
|
277 |
async def get_IPCC_relevant_documents(
|
|
|
374 |
return docs_question, images_question
|
375 |
|
376 |
|
377 |
+
|
378 |
# The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
|
379 |
# @chain
|
380 |
async def retrieve_documents(
|
|
|
383 |
source_type: str,
|
384 |
vectorstore: VectorStore,
|
385 |
reranker: Any,
|
386 |
+
version: str = "",
|
387 |
search_figures: bool = False,
|
388 |
search_only: bool = False,
|
389 |
reports: list = [],
|
|
|
391 |
k_images_by_question: int = 5,
|
392 |
k_before_reranking: int = 100,
|
393 |
k_by_question: int = 5,
|
394 |
+
k_summary_by_question: int = 3,
|
395 |
+
tocs: list = [],
|
396 |
+
by_toc=False
|
397 |
) -> Tuple[List[Document], List[Document]]:
|
398 |
"""
|
399 |
Unpack the first question of the remaining questions, and retrieve and rerank corresponding documents, based on the question and selected_sources
|
|
|
423 |
|
424 |
print(f"""---- Retrieve documents from {current_question["source_type"]}----""")
|
425 |
|
426 |
+
|
427 |
if source_type == "IPx":
|
428 |
docs_question_dict = await get_IPCC_relevant_documents(
|
429 |
query = question,
|
|
|
439 |
reports = reports,
|
440 |
)
|
441 |
|
442 |
+
if source_type == 'POC':
|
443 |
+
if by_toc == True:
|
444 |
+
print("---- Retrieve documents by ToC----")
|
445 |
+
docs_question_dict = await get_POC_documents_by_ToC_relevant_documents(
|
446 |
+
query=question,
|
447 |
+
tocs = tocs,
|
448 |
+
vectorstore=vectorstore,
|
449 |
+
version=version,
|
450 |
+
search_figures = search_figures,
|
451 |
+
sources = sources,
|
452 |
+
threshold = 0.5,
|
453 |
+
search_only = search_only,
|
454 |
+
reports = reports,
|
455 |
+
min_size= 200,
|
456 |
+
k_documents= k_before_reranking,
|
457 |
+
k_images= k_by_question
|
458 |
+
)
|
459 |
+
else :
|
460 |
+
docs_question_dict = await get_POC_relevant_documents(
|
461 |
+
query = question,
|
462 |
+
vectorstore=vectorstore,
|
463 |
+
search_figures = search_figures,
|
464 |
+
sources = sources,
|
465 |
+
threshold = 0.5,
|
466 |
+
search_only = search_only,
|
467 |
+
reports = reports,
|
468 |
+
min_size= 200,
|
469 |
+
k_documents= k_before_reranking,
|
470 |
+
k_images= k_by_question
|
471 |
+
)
|
472 |
|
473 |
# Rerank
|
474 |
if reranker is not None and rerank_by_question:
|
|
|
494 |
return docs_question, images_question
|
495 |
|
496 |
|
497 |
+
async def retrieve_documents_for_all_questions(
|
498 |
+
search_figures,
|
499 |
+
search_only,
|
500 |
+
reports,
|
501 |
+
questions_list,
|
502 |
+
n_questions,
|
503 |
+
config,
|
504 |
+
source_type,
|
505 |
+
to_handle_questions_index,
|
506 |
+
vectorstore,
|
507 |
+
reranker,
|
508 |
+
rerank_by_question=True,
|
509 |
+
k_final=15,
|
510 |
+
k_before_reranking=100,
|
511 |
+
version: str = "",
|
512 |
+
tocs: list[dict] = [],
|
513 |
+
by_toc: bool = False
|
514 |
+
):
|
515 |
"""
|
516 |
Retrieve documents in parallel for all questions.
|
517 |
"""
|
518 |
# to_handle_questions_index = [x for x in state["questions_list"] if x["source_type"] == "IPx"]
|
519 |
|
520 |
# TODO split les questions selon le type de sources dans le state question + conditions sur le nombre de questions traités par type de source
|
521 |
+
# search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
|
522 |
+
# search_only = state["search_only"]
|
523 |
+
# reports = state["reports"]
|
524 |
+
# questions_list = state["questions_list"]
|
525 |
+
|
526 |
+
# k_by_question = k_final // state["n_questions"]["total"]
|
527 |
+
# k_summary_by_question = _get_k_summary_by_question(state["n_questions"]["total"])
|
528 |
+
# k_images_by_question = _get_k_images_by_question(state["n_questions"]["total"])
|
529 |
+
k_by_question = k_final // n_questions
|
530 |
+
k_summary_by_question = _get_k_summary_by_question(n_questions)
|
531 |
+
k_images_by_question = _get_k_images_by_question(n_questions)
|
532 |
k_before_reranking=100
|
533 |
|
534 |
+
print(f"Source type here is {source_type}")
|
535 |
tasks = [
|
536 |
retrieve_documents(
|
537 |
current_question=question,
|
|
|
546 |
k_images_by_question=k_images_by_question,
|
547 |
k_before_reranking=k_before_reranking,
|
548 |
k_by_question=k_by_question,
|
549 |
+
k_summary_by_question=k_summary_by_question,
|
550 |
+
tocs=tocs,
|
551 |
+
version=version,
|
552 |
+
by_toc=by_toc
|
553 |
)
|
554 |
+
for i, question in enumerate(questions_list) if i in to_handle_questions_index
|
555 |
]
|
556 |
results = await asyncio.gather(*tasks)
|
557 |
# Combine results
|
|
|
561 |
new_state["related_contents"].extend(images_question)
|
562 |
return new_state
|
563 |
|
564 |
+
# ToC Retriever
|
565 |
+
async def get_relevant_toc_level_for_query(
    query: str,
    tocs: list[Document],
) -> list[dict]:
    """Ask the LLM which chapters of the given tables of contents match a query.

    Args:
        query: the user question.
        tocs: the tables of contents. Each entry is read as ``entry[0]``,
            whose ``metadata['name']`` and ``page_content`` are used, so the
            entries are presumably (Document, score) pairs despite the
            annotation -- TODO confirm and align the annotation.

    Returns:
        The parsed LLM answer, expected to be a list of dicts with the keys
        "document" and "chapter". An empty list is returned when the answer
        cannot be parsed or is not a list.
    """
    # Local imports keep this block self-contained.
    import ast
    import json

    doc_list = [
        {'document': doc[0].metadata['name'], 'toc': doc[0].page_content}
        for doc in tocs
    ]

    llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)

    prompt = ChatPromptTemplate.from_template(retrieve_chapter_prompt_template)
    chain = prompt | llm | StrOutputParser()
    response = chain.invoke({"query": query, "doc_list": doc_list})

    # The prompt asks for no markdown fence, but strip one if it is there anyway.
    text = response.strip()
    if text.startswith("```"):
        text = text.strip("`").strip()
        if text.startswith("json"):
            text = text[len("json"):].strip()

    # Never eval() model output: it would execute arbitrary code. Strict JSON
    # is tried first. ast.literal_eval is the fallback because it only accepts
    # literals yet tolerates Python-style answers (single quotes, trailing
    # commas) that json.loads rejects.
    try:
        relevant_tocs = json.loads(text)
    except json.JSONDecodeError:
        try:
            relevant_tocs = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError) as e:
            print(f" Failed to parse the result because of : {e}")
            return []

    if not isinstance(relevant_tocs, list):
        print(f" Failed to parse the result because of : expected a list, got {type(relevant_tocs).__name__}")
        return []

    return relevant_tocs
|
588 |
+
|
589 |
+
|
590 |
def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
|
591 |
|
592 |
async def retrieve_IPx_docs(state, config):
|
593 |
source_type = "IPx"
|
594 |
IPx_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
|
595 |
|
596 |
+
search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
|
597 |
+
search_only = state["search_only"]
|
598 |
+
reports = state["reports"]
|
599 |
+
questions_list = state["questions_list"]
|
600 |
+
n_questions=state["n_questions"]["total"]
|
601 |
+
|
602 |
state = await retrieve_documents_for_all_questions(
|
603 |
+
search_figures=search_figures,
|
604 |
+
search_only=search_only,
|
605 |
+
reports=reports,
|
606 |
+
questions_list=questions_list,
|
607 |
+
n_questions=n_questions,
|
608 |
config=config,
|
609 |
source_type=source_type,
|
610 |
to_handle_questions_index=IPx_questions_index,
|
|
|
628 |
source_type = "POC"
|
629 |
POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
|
630 |
|
631 |
+
search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
|
632 |
+
search_only = state["search_only"]
|
633 |
+
reports = state["reports"]
|
634 |
+
questions_list = state["questions_list"]
|
635 |
+
n_questions=state["n_questions"]["total"]
|
636 |
+
|
637 |
state = await retrieve_documents_for_all_questions(
|
638 |
+
search_figures=search_figures,
|
639 |
+
search_only=search_only,
|
640 |
+
reports=reports,
|
641 |
+
questions_list=questions_list,
|
642 |
+
n_questions=n_questions,
|
643 |
config=config,
|
644 |
source_type=source_type,
|
645 |
to_handle_questions_index=POC_questions_index,
|
|
|
654 |
return retrieve_POC_docs_node
|
655 |
|
656 |
|
657 |
+
def make_POC_by_ToC_retriever_node(
    vectorstore: VectorStore,
    reranker,
    llm,
    version: str = "",
    rerank_by_question=True,
    k_final=15,
    k_before_reranking=100,
    k_summary=5,
):
    """Build the graph node that retrieves local (POC) documents using ToCs.

    Args:
        vectorstore: store passed on to the retrieval of every question.
        reranker: reranker passed on to the retrieval of every question.
        llm: not used by this factory.
        version: version of the parsed documents; used both to load the
            tables of contents and to filter the retrieved chunks.
        rerank_by_question: forwarded to the retrieval.
        k_final: forwarded to the retrieval.
        k_before_reranking: forwarded to the retrieval.
        k_summary: not used by this factory.

    Returns:
        An async node function taking ``(state, config)``.
    """

    async def retrieve_POC_docs_node(state, config):
        # Nothing to do unless the user selected the local (POC) sources.
        if "POC region" not in state["relevant_content_sources_selection"]:
            return {}

        search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
        search_only = state["search_only"]
        reports = state["reports"]
        questions_list = state["questions_list"]
        n_questions = state["n_questions"]["total"]

        # The tables of contents are loaded again on every call of the node.
        tocs = get_ToCs(version=version)

        source_type = "POC"
        # Only the questions routed to the POC source are handled here.
        POC_questions_index = [
            i for i, x in enumerate(questions_list) if x["source_type"] == "POC"
        ]

        state = await retrieve_documents_for_all_questions(
            search_figures=search_figures,
            search_only=search_only,
            config=config,
            reports=reports,
            questions_list=questions_list,
            n_questions=n_questions,
            source_type=source_type,
            to_handle_questions_index=POC_questions_index,
            vectorstore=vectorstore,
            reranker=reranker,
            rerank_by_question=rerank_by_question,
            k_final=k_final,
            k_before_reranking=k_before_reranking,
            tocs=tocs,
            version=version,
            by_toc=True,
        )
        return state

    return retrieve_POC_docs_node
|
705 |
+
|
706 |
+
|
707 |
+
|
708 |
+
|
709 |
|
climateqa/engine/graph.py
CHANGED
@@ -11,7 +11,7 @@ from typing import List, Dict
|
|
11 |
|
12 |
import operator
|
13 |
from typing import Annotated
|
14 |
-
|
15 |
from IPython.display import display, HTML, Image
|
16 |
|
17 |
from .chains.answer_chitchat import make_chitchat_node
|
@@ -19,7 +19,7 @@ from .chains.answer_ai_impact import make_ai_impact_node
|
|
19 |
from .chains.query_transformation import make_query_transform_node
|
20 |
from .chains.translation import make_translation_node
|
21 |
from .chains.intent_categorization import make_intent_categorization_node
|
22 |
-
from .chains.retrieve_documents import make_IPx_retriever_node, make_POC_retriever_node
|
23 |
from .chains.answer_rag import make_rag_node
|
24 |
from .chains.graph_retriever import make_graph_retriever_node
|
25 |
from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
|
@@ -39,14 +39,14 @@ class GraphState(TypedDict):
|
|
39 |
n_questions : int
|
40 |
answer: str
|
41 |
audience: str = "experts"
|
42 |
-
sources_input: List[str] = ["IPCC","IPBES"]
|
43 |
relevant_content_sources_selection: List[str] = ["Figures (IPCC/IPBES)"]
|
44 |
sources_auto: bool = True
|
45 |
min_year: int = 1960
|
46 |
max_year: int = None
|
47 |
documents: Annotated[List[Document], operator.add]
|
48 |
-
related_contents : Annotated[List[Document], operator.add]
|
49 |
-
recommended_content : List[Document]
|
50 |
search_only : bool = False
|
51 |
reports : List[str] = []
|
52 |
|
@@ -72,7 +72,7 @@ def route_intent(state):
|
|
72 |
def chitchat_route_intent(state):
|
73 |
intent = state["search_graphs_chitchat"]
|
74 |
if intent is True:
|
75 |
-
return
|
76 |
elif intent is False:
|
77 |
return END
|
78 |
|
@@ -95,20 +95,10 @@ def route_based_on_relevant_docs(state,threshold_docs=0.2):
|
|
95 |
def route_continue_retrieve_documents(state):
|
96 |
index_question_ipx = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
|
97 |
questions_ipx_finished = all(elem in state["handled_questions_index"] for elem in index_question_ipx)
|
98 |
-
# if questions_ipx_finished and state["search_only"]:
|
99 |
-
# return END
|
100 |
if questions_ipx_finished:
|
101 |
return "end_retrieve_IPx_documents"
|
102 |
else:
|
103 |
return "retrieve_documents"
|
104 |
-
|
105 |
-
|
106 |
-
# if state["n_questions"]["IPx"] == len(state["handled_questions_index"]) and state["search_only"] :
|
107 |
-
# return END
|
108 |
-
# elif state["n_questions"]["IPx"] == len(state["handled_questions_index"]):
|
109 |
-
# return "answer_search"
|
110 |
-
# else :
|
111 |
-
# return "retrieve_documents"
|
112 |
|
113 |
def route_continue_retrieve_local_documents(state):
|
114 |
index_question_poc = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
|
@@ -120,20 +110,6 @@ def route_continue_retrieve_local_documents(state):
|
|
120 |
else:
|
121 |
return "retrieve_local_data"
|
122 |
|
123 |
-
# if state["n_questions"]["POC"] == len(state["handled_questions_index"]) and state["search_only"] :
|
124 |
-
# return END
|
125 |
-
# elif state["n_questions"]["POC"] == len(state["handled_questions_index"]):
|
126 |
-
# return "answer_search"
|
127 |
-
# else :
|
128 |
-
# return "retrieve_local_data"
|
129 |
-
|
130 |
-
# if len(state["remaining_questions"]) == 0 and state["search_only"] :
|
131 |
-
# return END
|
132 |
-
# elif len(state["remaining_questions"]) > 0:
|
133 |
-
# return "retrieve_documents"
|
134 |
-
# else:
|
135 |
-
# return "answer_search"
|
136 |
-
|
137 |
def route_retrieve_documents(state):
|
138 |
sources_to_retrieve = []
|
139 |
|
@@ -232,8 +208,23 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
|
|
232 |
app = workflow.compile()
|
233 |
return app
|
234 |
|
235 |
-
def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, threshold_docs=0.2):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
|
|
|
237 |
workflow = StateGraph(GraphState)
|
238 |
|
239 |
# Define the node functions
|
@@ -244,7 +235,8 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
|
|
244 |
answer_ai_impact = make_ai_impact_node(llm)
|
245 |
retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
|
246 |
retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
|
247 |
-
retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
|
|
|
248 |
answer_rag = make_rag_node(llm, with_docs=True)
|
249 |
answer_rag_no_docs = make_rag_node(llm, with_docs=False)
|
250 |
chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
|
@@ -315,6 +307,10 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
|
|
315 |
workflow.add_edge("retrieve_local_data", "answer_search")
|
316 |
workflow.add_edge("retrieve_documents", "answer_search")
|
317 |
|
|
|
|
|
|
|
|
|
318 |
# Compile
|
319 |
app = workflow.compile()
|
320 |
return app
|
|
|
11 |
|
12 |
import operator
|
13 |
from typing import Annotated
|
14 |
+
import pandas as pd
|
15 |
from IPython.display import display, HTML, Image
|
16 |
|
17 |
from .chains.answer_chitchat import make_chitchat_node
|
|
|
19 |
from .chains.query_transformation import make_query_transform_node
|
20 |
from .chains.translation import make_translation_node
|
21 |
from .chains.intent_categorization import make_intent_categorization_node
|
22 |
+
from .chains.retrieve_documents import make_IPx_retriever_node, make_POC_retriever_node, make_POC_by_ToC_retriever_node
|
23 |
from .chains.answer_rag import make_rag_node
|
24 |
from .chains.graph_retriever import make_graph_retriever_node
|
25 |
from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
|
|
|
39 |
n_questions : int
|
40 |
answer: str
|
41 |
audience: str = "experts"
|
42 |
+
sources_input: List[str] = ["IPCC","IPBES"] # Deprecated -> used only graphs that can only be OWID
|
43 |
relevant_content_sources_selection: List[str] = ["Figures (IPCC/IPBES)"]
|
44 |
sources_auto: bool = True
|
45 |
min_year: int = 1960
|
46 |
max_year: int = None
|
47 |
documents: Annotated[List[Document], operator.add]
|
48 |
+
related_contents : Annotated[List[Document], operator.add] # Images
|
49 |
+
recommended_content : List[Document] # OWID Graphs # TODO merge with related_contents
|
50 |
search_only : bool = False
|
51 |
reports : List[str] = []
|
52 |
|
|
|
72 |
def chitchat_route_intent(state):
    """Pick the next step after the chitchat intent has been categorized.

    Reads ``state["search_graphs_chitchat"]``. Both boolean outcomes currently
    terminate the graph (the ``True`` branch is a placeholder, see the TODO);
    any non-boolean value yields ``None``.
    """
    search_graphs = state["search_graphs_chitchat"]
    if search_graphs is True:
        return END  # TODO
    if search_graphs is False:
        return END
    return None
|
78 |
|
|
|
95 |
def route_continue_retrieve_documents(state):
    """Decide whether the IPx retrieval loop should run again or finish.

    Args:
        state: graph state. Reads ``questions_list`` (each item exposes a
            ``source_type`` key) and ``handled_questions_index`` (the indexes
            of the questions already processed).

    Returns:
        ``"end_retrieve_IPx_documents"`` once every question whose
        ``source_type`` is ``"IPx"`` has its index in
        ``handled_questions_index`` (including when there is no IPx question),
        otherwise ``"retrieve_documents"``.
    """
    # Build the lookup once instead of scanning the list for every question.
    handled = set(state["handled_questions_index"])
    ipx_pending = any(
        question["source_type"] == "IPx" and position not in handled
        for position, question in enumerate(state["questions_list"])
    )
    return "retrieve_documents" if ipx_pending else "end_retrieve_IPx_documents"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
def route_continue_retrieve_local_documents(state):
|
104 |
index_question_poc = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
|
|
|
110 |
else:
|
111 |
return "retrieve_local_data"
|
112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
def route_retrieve_documents(state):
|
114 |
sources_to_retrieve = []
|
115 |
|
|
|
208 |
app = workflow.compile()
|
209 |
return app
|
210 |
|
211 |
+
def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, version:str, threshold_docs=0.2):
|
212 |
+
"""Build and compile the POC agent graph, using the by-ToC retriever for local data.
|
213 |
+
|
214 |
+
Args:
|
215 |
+
llm: Language model handed to the node factories (answering, retrieval, intent categorization).
|
216 |
+
vectorstore_ipcc: Vector store passed to the IPx document retriever node.
|
217 |
+
vectorstore_graphs: Vector store passed to the graph retriever node.
|
218 |
+
vectorstore_region: Vector store passed to the POC (local data) retriever node.
|
219 |
+
reranker: Reranker passed to the document and graph retriever nodes.
|
220 |
+
version (str): version of the parsed documents (e.g "v4")
|
221 |
+
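threshold_docs (float, optional): Relevance-score threshold for retrieved documents (presumably forwarded to route_based_on_relevant_docs; TODO confirm). Defaults to 0.2.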
|
222 |
+
|
223 |
+
Returns:
|
224 |
+
The compiled workflow returned by ``workflow.compile()``.
|
225 |
+
"""
|
226 |
|
227 |
+
|
228 |
workflow = StateGraph(GraphState)
|
229 |
|
230 |
# Define the node functions
|
|
|
235 |
answer_ai_impact = make_ai_impact_node(llm)
|
236 |
retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
|
237 |
retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
|
238 |
+
# retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
|
239 |
+
retrieve_local_data = make_POC_by_ToC_retriever_node(vectorstore_region, reranker, llm, version=version)
|
240 |
answer_rag = make_rag_node(llm, with_docs=True)
|
241 |
answer_rag_no_docs = make_rag_node(llm, with_docs=False)
|
242 |
chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
|
|
|
307 |
workflow.add_edge("retrieve_local_data", "answer_search")
|
308 |
workflow.add_edge("retrieve_documents", "answer_search")
|
309 |
|
310 |
+
# workflow.add_edge("transform_query", "retrieve_drias_data")
|
311 |
+
# workflow.add_edge("retrieve_drias_data", END)
|
312 |
+
|
313 |
+
|
314 |
# Compile
|
315 |
app = workflow.compile()
|
316 |
return app
|
climateqa/engine/talk_to_data/main.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from climateqa.engine.talk_to_data.myVanna import MyVanna
|
2 |
+
from climateqa.engine.talk_to_data.utils import loc2coords, detect_location_with_openai, detectTable, nearestNeighbourSQL, detect_relevant_tables, replace_coordonates
|
3 |
+
import sqlite3
|
4 |
+
import os
|
5 |
+
import pandas as pd
|
6 |
+
from climateqa.engine.llm import get_llm
|
7 |
+
|
8 |
+
from dotenv import load_dotenv
|
9 |
+
import ast
|
10 |
+
|
11 |
+
load_dotenv()
|
12 |
+
|
13 |
+
|
14 |
+
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
|
15 |
+
PC_API_KEY = os.getenv('VANNA_PINECONE_API_KEY')
|
16 |
+
INDEX_NAME = os.getenv('VANNA_INDEX_NAME')
|
17 |
+
VANNA_MODEL = os.getenv('VANNA_MODEL')
|
18 |
+
|
19 |
+
|
20 |
+
#Vanna object
|
21 |
+
vn = MyVanna(config = {"temperature": 0, "api_key": OPENAI_API_KEY, 'model': VANNA_MODEL, 'pc_api_key': PC_API_KEY, 'index_name': INDEX_NAME, "top_k" : 4})
|
22 |
+
db_vanna_path = os.path.join(os.path.dirname(__file__), "database/drias.db")
|
23 |
+
vn.connect_to_sqlite(db_vanna_path)
|
24 |
+
|
25 |
+
llm = get_llm(provider="openai")
|
26 |
+
|
27 |
+
def ask_llm_to_add_table_names(sql_query, llm):
    """Ask the LLM to rewrite ``sql_query`` so each row shows its source table.

    Args:
        sql_query: SQL text to rewrite.
        llm: object exposing ``invoke(prompt)`` whose result has a ``content``
            string attribute.

    Returns:
        The rewritten SQL. The prompt asks the model not to wrap its answer in
        a Markdown fence, but that is only a request: if the reply is fenced
        anyway, the fence (and an optional language tag such as ``sql``) is
        removed. Unfenced replies are returned unchanged.
    """
    response = llm.invoke(f"Make the following sql query display the source table in the rows {sql_query}. Just answer the query. The answer should not include ```sql\n").content

    stripped = response.strip()
    if not stripped.startswith("```"):
        return response

    body = stripped[3:]
    if body.endswith("```"):
        body = body[:-3]
    # The first line of a fenced block is an optional language tag; a tag is a
    # single word, so a first line containing spaces is already SQL.
    tag, newline, rest = body.partition("\n")
    if newline and " " not in tag.strip():
        body = rest
    return body.strip()
|
30 |
+
|
31 |
+
def ask_llm_column_names(sql_query, llm):
    """Ask the LLM which columns ``sql_query`` selects and parse its answer.

    Args:
        sql_query: SQL text to analyse.
        llm: object exposing ``invoke(prompt)`` whose result has a ``content``
            string attribute.

    Returns:
        The Python list literal found in the model's reply, evaluated with
        ``ast.literal_eval`` (never ``eval``, since the text is model output).

    Raises:
        ValueError: if the reply contains no ``[...]`` span, or if
            ``ast.literal_eval`` rejects a value inside it.
        SyntaxError: if the extracted span is not valid Python syntax.
    """
    response = llm.invoke(f"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query : {sql_query}").content

    # ``str.strip("```python\n")`` strips a *set of characters*, not a prefix,
    # so it only handled some fences by accident (a ```json fence left "json"
    # behind). Take the outermost bracketed span instead, which ignores any
    # fence, language tag or surrounding prose.
    start = response.find("[")
    end = response.rfind("]")
    if start == -1 or end < start:
        raise ValueError(f"No Python list found in LLM response: {response!r}")
    return ast.literal_eval(response[start:end + 1])
|
35 |
+
|
36 |
+
def ask_vanna(query):
    """Answer a natural-language question through Vanna on the connected SQLite database.

    The location named in ``query`` is detected, converted to coordinates and
    substituted into the lower-cased question; the coordinates are then
    aligned with the nearest entries of each relevant table before Vanna is
    asked to generate and run the SQL.

    Args:
        query: user question in natural language.

    Returns:
        The ``(sql_query, result_dataframe, figure)`` triple produced by
        ``vn.ask``. When no location is detected, or when any step raises,
        the fallback ``("", <empty DataFrame>, {})`` is returned instead.
    """
    try:
        location = detect_location_with_openai(OPENAI_API_KEY, query)
        if not location:
            return "", pd.DataFrame(), {}

        coords = loc2coords(location)
        user_input = query.lower().replace(location.lower(), f"lat, long : {coords}")

        relevant_tables = detect_relevant_tables(user_input, llm)
        coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, table) for table in relevant_tables]
        user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)

        sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)
        return sql_query, result_dataframe, figure
    except Exception as e:
        # Deliberate best-effort boundary: any failure (LLM call, geocoding,
        # SQL generation) degrades to an empty result rather than crashing
        # the caller.
        print(f"Error: {e}")
        return "", pd.DataFrame(), {}
|
front/tabs/chat_interface.py
CHANGED
@@ -20,12 +20,31 @@ Please note that we log your questions for meta-analysis purposes, so avoid shar
|
|
20 |
What do you want to learn ?
|
21 |
"""
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
|
25 |
# UI Layout Components
|
26 |
-
def create_chat_interface():
|
|
|
27 |
chatbot = gr.Chatbot(
|
28 |
-
value=[ChatMessage(role="assistant", content=
|
29 |
type="messages",
|
30 |
show_copy_button=True,
|
31 |
show_label=False,
|
|
|
20 |
What do you want to learn ?
|
21 |
"""
|
22 |
|
23 |
+
# Welcome message used when the tab is "Beta - POC Adapt'Action"; it also lists the local POC sources.
# The bold span must close as "...Nouvelle-Aquitaine**" (no space before "**"), otherwise Markdown
# does not treat it as a closing delimiter and the asterisks are shown literally.
init_prompt_poc = """
Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports, PCAET of Paris, the Plan Biodiversité 2018-2024, and Acclimaterra reports from la Région Nouvelle-Aquitaine**.

❓ How to use
- **Language**: You can ask me your questions in any language.
- **Audience**: You can specify your audience (children, general public, experts) to get a more adapted answer.
- **Sources**: You can choose to search in the IPCC or IPBES reports, and POC sources for local documents (PCAET, Plan Biodiversité, Acclimaterra).
- **Relevant content sources**: You can choose to search for figures, papers, or graphs that can be relevant for your question.

⚠️ Limitations
*Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*

🛈 Information
Please note that we log your questions for meta-analysis purposes, so avoid sharing any sensitive or personal information.

What do you want to learn ?
"""
|
40 |
+
|
41 |
|
42 |
|
43 |
# UI Layout Components
|
44 |
+
def create_chat_interface(tab):
|
45 |
+
init_prompt_message = init_prompt_poc if tab == "Beta - POC Adapt'Action" else init_prompt
|
46 |
chatbot = gr.Chatbot(
|
47 |
+
value=[ChatMessage(role="assistant", content=init_prompt_message)],
|
48 |
type="messages",
|
49 |
show_copy_button=True,
|
50 |
show_label=False,
|
front/tabs/main_tab.py
CHANGED
@@ -3,7 +3,6 @@ from .chat_interface import create_chat_interface
|
|
3 |
from .tab_examples import create_examples_tab
|
4 |
from .tab_papers import create_papers_tab
|
5 |
from .tab_figures import create_figures_tab
|
6 |
-
from .chat_interface import create_chat_interface
|
7 |
|
8 |
def cqa_tab(tab_name):
|
9 |
# State variables
|
@@ -12,7 +11,7 @@ def cqa_tab(tab_name):
|
|
12 |
with gr.Row(elem_id="chatbot-row"):
|
13 |
# Left column - Chat interface
|
14 |
with gr.Column(scale=2):
|
15 |
-
chatbot, textbox, config_button = create_chat_interface()
|
16 |
|
17 |
# Right column - Content panels
|
18 |
with gr.Column(scale=2, variant="panel", elem_id="right-panel"):
|
|
|
3 |
from .tab_examples import create_examples_tab
|
4 |
from .tab_papers import create_papers_tab
|
5 |
from .tab_figures import create_figures_tab
|
|
|
6 |
|
7 |
def cqa_tab(tab_name):
|
8 |
# State variables
|
|
|
11 |
with gr.Row(elem_id="chatbot-row"):
|
12 |
# Left column - Chat interface
|
13 |
with gr.Column(scale=2):
|
14 |
+
chatbot, textbox, config_button = create_chat_interface(tab_name)
|
15 |
|
16 |
# Right column - Content panels
|
17 |
with gr.Column(scale=2, variant="panel", elem_id="right-panel"):
|