import threading
import queue
import time
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.chat_models import ChatOpenAI
from langsmith import traceable
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.prompts import ChatPromptTemplate
from langchain.prompts.chat import SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain.callbacks.base import BaseCallbackHandler
import gradio as gr
# --------------------------
# Custom Streaming Callback Handler that uses a queue.
# --------------------------
class CustomStreamingCallbackHandler(BaseCallbackHandler):
    def __init__(self):
        self.token_queue = queue.Queue()

    def on_llm_new_token(self, token: str, **kwargs):
        # Simply put the new token in the queue.
        self.token_queue.put(token)

# Instantiate the custom streaming callback (do not pass the chatbot here).
stream_handler = CustomStreamingCallbackHandler()
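# The handler above pushes every streamed token onto a thread-safe queue.Queue;
# the generator defined inside the Gradio Blocks further below drains that queue
# to update the chat window incrementally. Note that a single module-level
# handler (and queue) is shared by the whole app.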
# --------------------------
# Setup vector database and embeddings
# --------------------------
embeddings = OpenAIEmbeddings()
vectordb = Chroma(
    persist_directory="./ai_internacionalidad_v2",
    embedding_function=embeddings
)
# --------------------------
# Define prompt templates
# --------------------------
general_system_template = r"""
Cuando indique qué carrera quiere estudiar, responde que el 100% de las carreras de UPC cuenta con un programa internacional: https://www.youtube.com/watch?v=-TgOaphgwYg
e indica que ellos pueden vivir experiencias como las de alguno de nuestros estudiantes: https://www.youtube.com/watch?v=qFRN6ruRrfc
Luego indica que si quiere estudiar en Inglaterra, Australia, Canadá, Nueva Zelandia o Estados Unidos puede hacerlo con el programa NCUK y muéstrales el video: https://youtu.be/C9cKwDt8CwA?si=1NfbsBzaItDIJd5m
Luego, si quieren estudiar alguna carrera de las siguientes facultades: Arquitectura, Comunicaciones, Derecho, Economía, Ingeniería, Negocios o Psicología, indica que si quieren estudiar desde el Perú en una universidad de Estados Unidos o en Estados Unidos también lo pueden hacer con el programa Arizona: https://youtu.be/jbvMRNEuZUA?si=6FJifohlnoge4VJb
Además, con base en la carrera indicada, presenta los siguientes videos indicando "conoce más sobre cómo se vive la internacionalización en la carrera de tu interés", usando el link específico de la carrera en caso lo tengas, o el de la facultad en caso no tengas el link hacia la carrera:
Toma los siguientes documentos de contexto {context} y responde únicamente basado en este contexto.
"""
general_user_template = "Pregunta:```{question}```"
messages = [
    SystemMessagePromptTemplate.from_template(general_system_template),
    HumanMessagePromptTemplate.from_template(general_user_template)
]
qa_prompt = ChatPromptTemplate.from_messages(messages)
# --------------------------
# Create conversation memory
# --------------------------
def create_memory():
    return ConversationBufferMemory(memory_key='chat_history', return_messages=True)

# --------------------------
# Define the chain function that uses the LLM to answer queries
# --------------------------
def pdf_qa(query, memory, llm):
    chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vectordb.as_retriever(search_kwargs={'k': 28}),
        combine_docs_chain_kwargs={'prompt': qa_prompt},
        memory=memory
    )
    return chain({"question": query})
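# pdf_qa builds a fresh ConversationalRetrievalChain on every call but reuses the
# session's ConversationBufferMemory, so follow-up questions are first condensed
# against the prior chat history before retrieval (LangChain's default
# condense-question behavior for this chain).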
# --------------------------
# Build the Gradio Interface with custom CSS for the "Enviar" button.
# --------------------------
with gr.Blocks() as demo:
    # Inject custom CSS via HTML.
    gr.HTML(
        """
        <style>
        /* Target the button inside the container with id "enviar_button" */
        #enviar_button button {
            background-color: #E50A17 !important;
            color: white !important;
        }
        </style>
        """
    )
    # Chatbot component with an initial greeting.
    chatbot = gr.Chatbot(
        label="Internacionalidad",
        value=[[None,
                '''¡Hola!
Dime la carrera que te interesa y te contaré qué experiencia puedes vivir en el extranjero y cómo otros alumnos UPC ya están viviendo esa experiencia.
¡Hazme cualquier pregunta y descubramos juntos todas las posibilidades!'''
                ]]
    )
    msg = gr.Textbox(placeholder="Escribe aquí", label='')
    submit = gr.Button("Enviar", elem_id="enviar_button")
    memory_state = gr.State(create_memory)

    # Create the ChatOpenAI model with streaming enabled and our custom callback.
    llm = ChatOpenAI(
        temperature=0,
        model_name='gpt-4o',
        streaming=True,
        callbacks=[stream_handler]
    )
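    # With streaming=True, every token the model generates fires
    # on_llm_new_token on stream_handler, which places the token on the shared
    # queue that the generator below drains to update the chat UI.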
    # --------------------------
    # Generator function that runs the chain in a separate thread and polls the token queue.
    # --------------------------
    def user(query, chat_history, memory):
        # Append the user's message with an empty bot response.
        chat_history.append((query, ""))
        # Immediately yield an update so the user's message appears.
        yield "", chat_history, memory

        # Container for the final chain result.
        final_result = [None]

        # Define a helper function to run the chain.
        def run_chain():
            result = pdf_qa(query, memory, llm)
            final_result[0] = result
            # Signal end-of-stream by putting a sentinel value.
            stream_handler.token_queue.put(None)

        # Run the chain in a separate thread.
        thread = threading.Thread(target=run_chain)
        thread.start()

        # Poll the token queue for new tokens and yield updated chat history.
        current_response = ""
        while True:
            try:
                token = stream_handler.token_queue.get(timeout=0.1)
            except queue.Empty:
                token = None
            # None comes either from the end-of-stream sentinel or from a timeout;
            # in both cases, stop only once the worker thread has finished.
            if token is None:
                if not thread.is_alive():
                    break
                else:
                    continue
            current_response += token
            chat_history[-1] = (query, current_response)
            yield "", chat_history, memory

        thread.join()
        # Optionally, update the final answer if it differs from the streaming tokens.
        if final_result[0] and "answer" in final_result[0]:
            chat_history[-1] = (query, final_result[0]["answer"])
            yield "", chat_history, memory

    # Wire up the generator function to Gradio components with queue enabled.
    submit.click(user, [msg, chatbot, memory_state], [msg, chatbot, memory_state], queue=True)
    msg.submit(user, [msg, chatbot, memory_state], [msg, chatbot, memory_state], queue=True)
if __name__ == "__main__":
    demo.queue().launch()