import sys
import os
from contextlib import contextmanager

from langchain.schema import Document
from langgraph.graph import END, StateGraph
from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod

from typing_extensions import TypedDict
from typing import List

from IPython.display import display, HTML, Image

from .chains.answer_chitchat import make_chitchat_node
from .chains.answer_ai_impact import make_ai_impact_node
from .chains.query_transformation import make_query_transform_node
from .chains.translation import make_translation_node
from .chains.intent_categorization import make_intent_categorization_node
from .chains.retrieve_documents import make_retriever_node
from .chains.answer_rag import make_rag_node

class GraphState(TypedDict):
    """
    Represents the state of our graph.

    This is the state schema handed to ``StateGraph`` in make_graph_agent; every
    node and routing function in this module receives a dict shaped like this.

    NOTE(review): TypedDict does not apply the ``= value`` assignments below
    (audience, sources_input, sources_auto, min_year, max_year) as defaults when
    a state dict is created — a dict lacking those keys simply won't contain
    them. Confirm whether any caller or LangGraph version actually reads these
    values before relying on them. ``max_year: int = None`` is also inconsistent
    with its ``int`` annotation (presumably ``Optional[int]`` was intended).
    """
    user_input : str  # presumably the raw user message — TODO confirm against callers
    language : str  # read by route_translation; "english" (case-insensitive) skips translation
    intent : str  # read by route_intent; "chitchat"/"esg" route to answer_chitchat
    query: str
    remaining_questions : List[dict]  # retrieve_documents re-runs while this list is non-empty
    n_questions : int
    answer: str
    audience: str = "experts"
    sources_input: List[str] = ["IPCC","IPBES"]
    sources_auto: bool = True
    min_year: int = 1960
    max_year: int = None
    documents: List[Document]  # route_based_on_relevant_docs reads metadata["reranking_score"] on each

def search(state):
    """Placeholder "search" node: hands the graph state through untouched.

    The next step is decided by the conditional edge that leaves this node
    (route_translation), not by the node itself.
    """
    # TODO: implement the actual search step.
    unchanged_state = state
    return unchanged_state

def answer_search(state):
    """Placeholder node reached once document retrieval is finished.

    It returns the state object as-is; the choice between answer_rag and
    answer_rag_no_docs is made on its outgoing conditional edge.
    """
    # TODO: implement the post-retrieval answering step.
    unchanged_state = state
    return unchanged_state

def route_intent(state):
    """Pick the node that follows intent categorization.

    Returns "answer_chitchat" when ``state["intent"]`` is "chitchat" or "esg",
    and "search" for every other intent.
    """
    # NOTE: a dedicated "ai_impact" route is currently disabled, so that intent
    # falls through to the search route like any other.
    chitchat_intents = ("chitchat", "esg")
    return "answer_chitchat" if state["intent"] in chitchat_intents else "search"
    
def route_translation(state):
    """Decide whether the query must be translated before it is transformed.

    Returns "transform_query" when ``state["language"]`` equals "english"
    ignoring case, otherwise "translate_query".
    """
    detected_language = state["language"]
    if detected_language.lower() != "english":
        return "translate_query"
    return "transform_query"
    
def route_based_on_relevant_docs(state,threshold_docs=0.2):
    """Choose the answering node depending on whether any retrieved document is relevant.

    A document counts as relevant when its ``metadata["reranking_score"]`` is
    strictly greater than ``threshold_docs``.

    Args:
        state: graph state dict; only its "documents" entry is read.
        threshold_docs: exclusive lower bound on the reranking score for a
            document to count as relevant.

    Returns:
        "answer_rag" as soon as one relevant document is found, otherwise
        "answer_rag_no_docs". That includes a missing, None or empty
        "documents" entry.
    """
    documents = state.get("documents") or []
    for doc in documents:
        score = doc.metadata.get("reranking_score")
        # A document without a score (e.g. if reranking was skipped) cannot
        # clear the threshold; treat it as not relevant instead of raising KeyError.
        if score is not None and score > threshold_docs:
            # Short-circuit: a single relevant document is enough.
            return "answer_rag"
    return "answer_rag_no_docs"
    

def make_id_dict(values):
    """Return a dict mapping each item of *values* to itself.

    It is used to build the identity ``path_map`` for conditional edges, where
    each routing result is also the target node's name. *values* is iterated
    exactly once, so generators are accepted.
    """
    mapping = {}
    for name in values:
        mapping[name] = name
    return mapping

def make_graph_agent(llm,vectorstore,reranker,threshold_docs = 0.2):
    """Build and compile the question-answering LangGraph workflow.

    Flow wired below:
      categorize_intent -> answer_chitchat | search
      search -> translate_query | transform_query
      translate_query -> transform_query -> retrieve_documents
      retrieve_documents -> itself while remaining_questions is non-empty,
                            then answer_search
      answer_search -> answer_rag | answer_rag_no_docs
      answer_rag, answer_rag_no_docs, answer_chitchat -> END

    Args:
        llm: language model handed to every chain factory.
        vectorstore: store passed to make_retriever_node.
        reranker: reranker passed to make_retriever_node.
        threshold_docs: reranking-score threshold forwarded to
            route_based_on_relevant_docs.

    Returns:
        The result of ``StateGraph.compile()``.
    """
    workflow = StateGraph(GraphState)

    # Chain-backed node callables, all sharing the same llm.
    categorize_node = make_intent_categorization_node(llm)
    transform_node = make_query_transform_node(llm)
    translate_node = make_translation_node(llm)
    chitchat_node = make_chitchat_node(llm)
    # Built but not registered: the "answer_ai_impact" node and its edge are disabled.
    _ai_impact_node = make_ai_impact_node(llm)
    retriever_node = make_retriever_node(vectorstore,reranker,llm)
    rag_node = make_rag_node(llm,with_docs=True)
    rag_no_docs_node = make_rag_node(llm,with_docs=False)

    # Registration order is kept identical to the original wiring.
    nodes = {
        "categorize_intent": categorize_node,
        "search": search,
        "answer_search": answer_search,
        "transform_query": transform_node,
        "translate_query": translate_node,
        "answer_chitchat": chitchat_node,
        "retrieve_documents": retriever_node,
        "answer_rag": rag_node,
        "answer_rag_no_docs": rag_no_docs_node,
    }
    for node_name, node_fn in nodes.items():
        workflow.add_node(node_name, node_fn)

    workflow.set_entry_point("categorize_intent")

    # (source node, router, possible targets). Routers return the target node's
    # name, so the path map is the identity mapping over the targets.
    # The inline routers stay lambdas on purpose: LangGraph derives branch names
    # from the router's name.
    conditional_edges = [
        ("categorize_intent", route_intent, ["answer_chitchat","search"]),
        ("search", route_translation, ["translate_query","transform_query"]),
        (
            "retrieve_documents",
            # Loop on retrieval until every sub-question has been handled.
            lambda state : "retrieve_documents" if len(state["remaining_questions"]) > 0 else "answer_search",
            ["retrieve_documents","answer_search"],
        ),
        (
            "answer_search",
            lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
            ["answer_rag","answer_rag_no_docs"],
        ),
    ]
    for source, router, targets in conditional_edges:
        workflow.add_conditional_edges(source, router, make_id_dict(targets))

    static_edges = (
        ("translate_query", "transform_query"),
        ("transform_query", "retrieve_documents"),
        ("answer_rag", END),
        ("answer_rag_no_docs", END),
        ("answer_chitchat", END),
    )
    for source, target in static_edges:
        workflow.add_edge(source, target)

    return workflow.compile()




def display_graph(app):
    """Show a PNG rendering of the compiled graph *app* in an IPython display.

    The graph is fetched with ``xray=True`` and rendered via
    ``MermaidDrawMethod.API``. That method delegates rendering to an external
    Mermaid service, so it presumably needs network access.
    """
    graph = app.get_graph(xray = True)
    png_bytes = graph.draw_mermaid_png(draw_method=MermaidDrawMethod.API)
    display(Image(png_bytes))