# Import necessary libraries
import os
import gradio as gr

from azure.storage.fileshare import ShareServiceClient

# Import custom modules
from climateqa.engine.embeddings import get_embeddings_function
from climateqa.engine.llm import get_llm
from climateqa.engine.vectorstore import get_pinecone_vectorstore
from climateqa.engine.reranker import get_reranker
from climateqa.engine.graph import make_graph_agent, make_graph_agent_poc
from climateqa.engine.chains.retrieve_papers import find_papers
from climateqa.chat import start_chat, chat_stream, finish_chat
from climateqa.engine.talk_to_data.main import ask_vanna
from climateqa.engine.talk_to_data.myVanna import MyVanna

from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
from front.utils import process_figures
from gradio_modal import Modal

from utils import create_user_id

import logging

logging.basicConfig(level=logging.WARNING)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Suppresses INFO and WARNING logs
logging.getLogger().setLevel(logging.WARNING)

# Load environment variables in local mode.
# This stays best-effort (a missing python-dotenv package or an unreadable
# .env file must not prevent startup), but the reason is logged instead of
# being silently discarded.
try:
    from dotenv import load_dotenv
    load_dotenv()
except Exception as exc:
    logging.info("Skipping .env loading: %s", exc)

# Set up Gradio Theme
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)

# Azure storage credentials (the environment variables are named BLOB_*,
# the client used below is the file-share client).
account_key = os.environ["BLOB_ACCOUNT_KEY"]
if len(account_key) == 86:
    # An 86-character key is completed with "==" before use.
    account_key += "=="

credential = {
    "account_key": account_key,
    "account_name": os.environ["BLOB_ACCOUNT_NAME"],
}

account_url = os.environ["BLOB_ACCOUNT_URL"]
file_share_name = "climateqa"
service = ShareServiceClient(account_url=account_url, credential=credential)
share_client = service.get_share_client(file_share_name)

user_id = create_user_id()

# Create vectorstore and retriever
embeddings_function = get_embeddings_function()
vectorstore = get_pinecone_vectorstore(
    embeddings_function,
    index_name=os.getenv("PINECONE_API_INDEX"),
)
vectorstore_graphs = get_pinecone_vectorstore(
    embeddings_function,
    index_name=os.getenv("PINECONE_API_INDEX_OWID"),
    text_key="description",
)
vectorstore_region = get_pinecone_vectorstore(
    embeddings_function,
    index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2"),
)

llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)

# GRADIO_ENV is an optional switch: read it with getenv so that an unset
# variable selects the non-local reranker instead of raising KeyError.
if os.getenv("GRADIO_ENV") == "local":
    reranker = get_reranker("nano")
else:
    reranker = get_reranker("large")

agent = make_graph_agent(
    llm=llm,
    vectorstore_ipcc=vectorstore,
    vectorstore_graphs=vectorstore_graphs,
    vectorstore_region=vectorstore_region,
    reranker=reranker,
    threshold_docs=0.2,
)
agent_poc = make_graph_agent_poc(
    llm=llm,
    vectorstore_ipcc=vectorstore,
    vectorstore_graphs=vectorstore_graphs,
    vectorstore_region=vectorstore_region,
    reranker=reranker,
    threshold_docs=0,
    version="v4",
)  # TODO put back default 0.2

# Vanna object
vn = MyVanna(config={
    "temperature": 0,
    "api_key": os.getenv('THEO_API_KEY'),
    'model': os.getenv('VANNA_MODEL'),
    'pc_api_key': os.getenv('VANNA_PINECONE_API_KEY'),
    'index_name': os.getenv('VANNA_INDEX_NAME'),
    "top_k": 4,
})
db_vanna_path = os.path.join(os.getcwd(), "data/drias/drias.db")
vn.connect_to_sqlite(db_vanna_path)


def ask_vanna_query(query):
    """Call ``ask_vanna`` for *query* with the module-level Vanna object and database path."""
    return ask_vanna(vn, db_vanna_path, query)


async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
    """Yield every event produced by ``chat_stream`` when run with the main agent.

    All arguments are forwarded unchanged to ``chat_stream``, followed by the
    module-level ``share_client`` and ``user_id``.
    """
    print("chat cqa - message received")
    async for event in chat_stream(agent, query, history, audience, sources, reports, relevant_content_sources_selection, search_only, share_client, user_id):
        yield event


async def chat_poc(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
    """Yield every event produced by ``chat_stream`` when run with the POC agent.

    Same contract as ``chat``; only the agent (``agent_poc``) and the log
    message differ.
    """
    print("chat poc - message received")
    async for event in chat_stream(agent_poc, query, history, audience, sources, reports, relevant_content_sources_selection, search_only, share_client, user_id):
        yield event
# --------------------------------------------------------------------
# Gradio
# --------------------------------------------------------------------


# Function to update modal visibility
def update_config_modal_visibility(config_open):
    """Toggle the configuration modal.

    Args:
        config_open: current open/closed flag of the modal.

    Returns:
        A ``Modal`` update carrying the flipped visibility, and the flipped flag
        itself (to be stored back in the state component).
    """
    print(config_open)
    new_config_visibility_status = not config_open
    return Modal(visible=new_config_visibility_status), new_config_visibility_status


def update_sources_number_display(sources_textbox, figures_cards, current_graphs, papers_html):
    """Build label updates showing how many items each content tab holds.

    Args:
        sources_textbox: value of the sources panel.
        figures_cards: value of the figures panel.
        current_graphs: value of the graphs state.
        papers_html: value of the papers panel.

    Returns:
        Five ``gr.update`` objects, in this order: recommended content,
        sources, figures, graphs, papers.
    """
    # NOTE(review): the marker literals below were empty/broken in the reviewed
    # copy of this file (inline HTML had been stripped) and ``papers_number``
    # was never assigned, which raised NameError. The markers are reconstructed
    # as one "<h2>" per source/figure/paper card and one "<iframe" per graph --
    # confirm against the markup actually produced for these panels.
    sources_number = sources_textbox.count("<h2>")
    figures_number = figures_cards.count("<h2>")
    graphs_number = current_graphs.count("<iframe")
    papers_number = papers_html.count("<h2>")

    sources_notif_label = f"Sources ({sources_number})"
    figures_notif_label = f"Figures ({figures_number})"
    graphs_notif_label = f"Graphs ({graphs_number})"
    papers_notif_label = f"Papers ({papers_number})"
    recommended_content_notif_label = f"Recommended content ({figures_number + graphs_number + papers_number})"

    return (
        gr.update(label=recommended_content_notif_label),
        gr.update(label=sources_notif_label),
        gr.update(label=figures_notif_label),
        gr.update(label=graphs_notif_label),
        gr.update(label=papers_notif_label),
    )


def create_drias_tab():
    """Create the "Beta - Talk to DRIAS" tab and wire its question box to ``ask_vanna_query``.

    NOTE(review): the nesting of the components was reconstructed from a copy
    of the file without indentation -- confirm the layout renders as intended.
    """
    with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab_vanna:
        vanna_direct_question = gr.Textbox(label="Direct Question", placeholder="You can write direct question here", elem_id="direct-question", interactive=True)
        with gr.Accordion("Details", elem_id='vanna-details', open=False) as vanna_details:
            vanna_sql_query = gr.Textbox(label="SQL Query Used", elem_id="sql-query", interactive=False)
            show_vanna_table = gr.Button("Show Table", elem_id="show-table")
            with Modal(visible=False) as vanna_table_modal:
                vanna_table = gr.DataFrame([], elem_id="vanna-table")
                close_vanna_modal = gr.Button("Close", elem_id="close-vanna-modal")
                close_vanna_modal.click(lambda: Modal(visible=False), None, [vanna_table_modal])
                show_vanna_table.click(lambda: Modal(visible=True), None, [vanna_table_modal])

        vanna_display = gr.Plot()
        # Submitting a question fills the SQL box, the table and the plot.
        vanna_direct_question.submit(ask_vanna_query, [vanna_direct_question], [vanna_sql_query, vanna_table, vanna_display])


# UI Layout Components
def cqa_tab(tab_name):
    """Create one chat tab (chat column plus right-hand content panels).

    Args:
        tab_name: title of the tab; also passed to ``create_chat_interface``
            and ``create_examples_tab``.

    Returns:
        A dict mapping component names to the created components, consumed by
        ``event_handling`` and ``config_event_handling``.

    NOTE(review): the nesting of the components was reconstructed from a copy
    of the file without indentation -- confirm the layout renders as intended.
    """
    # State variables
    current_graphs = gr.State([])

    with gr.Tab(tab_name):
        with gr.Row(elem_id="chatbot-row"):
            # Left column - Chat interface
            with gr.Column(scale=2):
                chatbot, textbox, config_button = create_chat_interface(tab_name)

            # Right column - Content panels
            with gr.Column(scale=2, variant="panel", elem_id="right-panel"):
                with gr.Tabs(elem_id="right_panel_tab") as tabs:
                    # Examples tab
                    with gr.TabItem("Examples", elem_id="tab-examples", id=0):
                        examples_hidden = create_examples_tab(tab_name)

                    # Sources tab
                    with gr.Tab("Sources", elem_id="tab-sources", id=1) as tab_sources:
                        sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")

                    # Recommended content tab
                    with gr.Tab("Recommended content", elem_id="tab-recommended_content", id=2) as tab_recommended_content:
                        with gr.Tabs(elem_id="group-subtabs") as tabs_recommended_content:
                            # Figures subtab
                            with gr.Tab("Figures", elem_id="tab-figures", id=3) as tab_figures:
                                sources_raw, new_figures, used_figures, gallery_component, figures_cards, figure_modal = create_figures_tab()

                            # Papers subtab
                            with gr.Tab("Papers", elem_id="tab-citations", id=4) as tab_papers:
                                papers_direct_search, papers_summary, papers_html, citations_network, papers_modal = create_papers_tab()

                            # Graphs subtab
                            with gr.Tab("Graphs", elem_id="tab-graphs", id=5) as tab_graphs:
                                # NOTE(review): the wrapping tag of this placeholder was
                                # stripped in the reviewed copy; "<h2>" is a reconstruction.
                                graphs_container = gr.HTML(
                                    "<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",
                                    elem_id="graphs-container"
                                )

    return {
        "chatbot": chatbot,
        "textbox": textbox,
        "tabs": tabs,
        "sources_raw": sources_raw,
        "new_figures": new_figures,
        "current_graphs": current_graphs,
        "examples_hidden": examples_hidden,
        "sources_textbox": sources_textbox,
        "figures_cards": figures_cards,
        "gallery_component": gallery_component,
        "config_button": config_button,
        "papers_direct_search": papers_direct_search,
        "papers_html": papers_html,
        "citations_network": citations_network,
        "papers_summary": papers_summary,
        "tab_recommended_content": tab_recommended_content,
        "tab_sources": tab_sources,
        "tab_figures": tab_figures,
        "tab_graphs": tab_graphs,
        "tab_papers": tab_papers,
        "graph_container": graphs_container,
        # "vanna_sql_query": vanna_sql_query,
        # "vanna_table" : vanna_table,
        # "vanna_display": vanna_display
    }


def config_event_handling(main_tabs_components: list[dict], config_componenets: dict):
    """Make every config button, and the modal's close button, toggle the config modal.

    Args:
        main_tabs_components: component dicts as returned by ``cqa_tab``; the
            ``"config_button"`` entry of each one is wired.
        config_componenets: config components; the keys read here are
            ``"config_open"``, ``"config_modal"`` and
            ``"close_config_modal_button"``.
    """
    config_open = config_componenets["config_open"]
    config_modal = config_componenets["config_modal"]
    close_config_modal = config_componenets["close_config_modal_button"]

    tab_config_buttons = [main_tab_component["config_button"] for main_tab_component in main_tabs_components]
    for button in [close_config_modal] + tab_config_buttons:
        button.click(
            fn=update_config_modal_visibility,
            inputs=[config_open],
            outputs=[config_modal, config_open]
        )


def event_handling(
    main_tab_components,
    config_components,
    tab_name="ClimateQ&A"
):
    """Register the event listeners of one chat tab.

    Args:
        main_tab_components: component dict returned by ``cqa_tab``.
        config_components: shared configuration components (dropdowns,
            ``search_only``, ``after``, ``output_query``, ``output_language``).
        tab_name: selects the chat callback: ``"ClimateQ&A"`` uses ``chat``,
            ``"Beta - POC Adapt'Action"`` uses ``chat_poc``. For any other name
            no chat listener is registered; the remaining listeners still are.
    """
    chatbot = main_tab_components["chatbot"]
    textbox = main_tab_components["textbox"]
    tabs = main_tab_components["tabs"]
    sources_raw = main_tab_components["sources_raw"]
    new_figures = main_tab_components["new_figures"]
    current_graphs = main_tab_components["current_graphs"]
    examples_hidden = main_tab_components["examples_hidden"]
    sources_textbox = main_tab_components["sources_textbox"]
    figures_cards = main_tab_components["figures_cards"]
    gallery_component = main_tab_components["gallery_component"]
    # config_button = main_tab_components["config_button"]
    papers_direct_search = main_tab_components["papers_direct_search"]
    papers_html = main_tab_components["papers_html"]
    citations_network = main_tab_components["citations_network"]
    papers_summary = main_tab_components["papers_summary"]
    tab_recommended_content = main_tab_components["tab_recommended_content"]
    tab_sources = main_tab_components["tab_sources"]
    tab_figures = main_tab_components["tab_figures"]
    tab_graphs = main_tab_components["tab_graphs"]
    tab_papers = main_tab_components["tab_papers"]
    graphs_container = main_tab_components["graph_container"]
    # vanna_sql_query = main_tab_components["vanna_sql_query"]
    # vanna_table = main_tab_components["vanna_table"]
    # vanna_display = main_tab_components["vanna_display"]

    # config_open = config_components["config_open"]
    # config_modal = config_components["config_modal"]
    dropdown_sources = config_components["dropdown_sources"]
    dropdown_reports = config_components["dropdown_reports"]
    dropdown_external_sources = config_components["dropdown_external_sources"]
    search_only = config_components["search_only"]
    dropdown_audience = config_components["dropdown_audience"]
    after = config_components["after"]
    output_query = config_components["output_query"]
    output_language = config_components["output_language"]
    # close_config_modal = config_components["close_config_modal_button"]

    new_sources_html = gr.State([])
    # Not referenced below; kept so the set of components created here is unchanged.
    ttd_data = gr.State([])

    # for button in [config_button, close_config_modal]:
    #     button.click(
    #         fn=update_config_modal_visibility,
    #         inputs=[config_open],
    #         outputs=[config_modal, config_open]
    #     )

    # Both tabs use the same start -> chat -> finish chain; only the log line
    # and the chat callback differ.
    chat_handlers = {
        "ClimateQ&A": ("chat cqa - message sent", chat),
        "Beta - POC Adapt'Action": ("chat poc - message sent", chat_poc),
    }
    if tab_name in chat_handlers:
        log_message, chat_fn = chat_handlers[tab_name]
        print(log_message)

        # The chain is registered first for the textbox submit, then for the
        # hidden examples component change, with that component as the query input.
        triggers = (
            (textbox, textbox.submit),
            (examples_hidden, examples_hidden.change),
        )
        for query_component, register in triggers:
            (
                register(
                    start_chat,
                    [query_component, chatbot, search_only],
                    [query_component, tabs, chatbot, sources_raw],
                    queue=False,
                    api_name=f"start_chat_{query_component.elem_id}",
                )
                .then(
                    chat_fn,
                    [query_component, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only],
                    [chatbot, new_sources_html, output_query, output_language, new_figures, current_graphs],
                    concurrency_limit=8,
                    api_name=f"chat_{query_component.elem_id}",
                )
                .then(
                    finish_chat,
                    None,
                    [textbox],
                    api_name=f"finish_chat_{query_component.elem_id}",
                )
            )

    # Mirror the state values written by the chat callback into the visible panels.
    new_sources_html.change(lambda x: x, inputs=[new_sources_html], outputs=[sources_textbox])
    current_graphs.change(lambda x: x, inputs=[current_graphs], outputs=[graphs_container])
    new_figures.change(process_figures, inputs=[sources_raw, new_figures], outputs=[sources_raw, figures_cards, gallery_component])

    # Update sources numbers
    for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
        component.change(
            update_sources_number_display,
            [sources_textbox, figures_cards, current_graphs, papers_html],
            [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers],
        )

    # Search for papers
    for component in [textbox, examples_hidden, papers_direct_search]:
        component.submit(find_papers, [component, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])

    # if tab_name == "Beta - POC Adapt'Action":  # Not until results are good enough
    #     # Drias search
    #     textbox.submit(ask_vanna, [textbox], [vanna_sql_query ,vanna_table, vanna_display])


def main_ui():
    """Build the whole Gradio application and return the ``gr.Blocks`` instance.

    Creates the config modal, the two chat tabs, the DRIAS tab and the about
    tab, registers the event listeners, and enables the queue.
    """
    # config_open = gr.State(True)
    with gr.Blocks(title="Climate Q&A", css_paths=os.path.join(os.getcwd(), "style.css"), theme=theme, elem_id="main-component") as demo:
        config_components = create_config_modal()

        with gr.Tabs():
            cqa_components = cqa_tab(tab_name="ClimateQ&A")
            local_cqa_components = cqa_tab(tab_name="Beta - POC Adapt'Action")
            create_drias_tab()

            create_about_tab()

        event_handling(cqa_components, config_components, tab_name='ClimateQ&A')
        event_handling(local_cqa_components, config_components, tab_name="Beta - POC Adapt'Action")

        config_event_handling([cqa_components, local_cqa_components], config_components)

        demo.queue()

    return demo


demo = main_ui()
demo.launch(ssr_mode=False)