import os import logging import json import numpy as np import faiss from typing import List, Dict, Any import gradio as gr from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate from langchain.chat_models import ChatOpenAI from langchain import OpenAI from sentence_transformers import SentenceTransformer # Configure logging logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') # Load API key from environment OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") if not OPENAI_API_KEY: raise ValueError("API key is missing. Set OPENAI_API_KEY in Hugging Face Secrets.") os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY # logging.debug(f"Using OpenAI API Key: {OPENAI_API_KEY[:5]}... (truncated for security)") # Load FAISS index and chunked data logging.debug("Loading FAISS index and chunked data...") faiss_index = faiss.read_index("fp16_faiss_embeddings.index") with open("all_chunked_data.json", "r") as f: all_chunked_data = json.load(f) logging.debug("FAISS index and chunked data loaded successfully.") # Log random FAISS index for verification random_index = np.random.randint(0, len(all_chunked_data)) logging.debug(f"Random FAISS index verification: {random_index}") logging.debug(f"Corresponding chunk: {all_chunked_data[random_index]['text'][:100]}...") logging.debug("Loading and configuring the embedding model...") model = SentenceTransformer( "dunzhang/stella_en_400M_v5", trust_remote_code=True, device="cpu", config_kwargs={"use_memory_efficient_attention": False, "unpad_inputs": False} ) logging.debug("Embedding model loaded successfully.") # Test embedding model import time start_time = time.time() logging.debug("Testing embedding model with a sample query...") try: query_embedding = model.encode(["test query"], show_progress_bar=False) logging.debug(f"Embedding shape: {query_embedding.shape}") logging.debug(f"Encoding took {time.time() - start_time:.2f} seconds") except Exception as e: logging.error(f"Error in embedding model test: {repr(e)}") logging.error(f"Error details: {str(e)}") import traceback logging.error(f"Traceback: {traceback.format_exc()}") # ======================= # Test Embeddings # ======================= # Check the size of the FAISS index # logging.debug(f"Number of embeddings in FAISS index: {faiss_index.ntotal}") # logging.debug("") # logging.debug("") # # Retrieve embeddings from FAISS index (first 'k' embeddings) # k = 2 # Number of embeddings to retrieve for verification # stored_embeddings = np.zeros((k, 1024), dtype='float32') # 1024 is the embedding dimension # faiss_index.reconstruct_n(0, k, stored_embeddings) # # Compare with original embeddings (for example, the first 5 chunks) # original_embeddings = model.encode(all_chunked_data[:k]) # # Print or compare both to check if they match # logging.debug(f"Original Embeddings: {original_embeddings}") # logging.debug(f"Stored Embeddings from FAISS index: {stored_embeddings}") # logging.debug("") # logging.debug("") # # Query one of the chunks and check if FAISS returns the correct nearest neighbor # query_embedding = model.encode([all_chunked_data[0]]) # Encode the first chunk # D, I = faiss_index.search(np.array(query_embedding, dtype='float32'), k=1) # Search for top-1 match # logging.debug(f"Distance: {D}, Index: {I}") # # Check if the index corresponds to the same chunk # logging.debug(f"Queried Chunk: {all_chunked_data[0]}") # logging.debug(f"Matched Chunk: {all_chunked_data[I[0][0]]}") # logging.debug("") # logging.debug("") # # Check the dimensionality of the FAISS index # logging.debug(f"Dimension of embeddings in FAISS index: {faiss_index.d}") CHUNK_SIZE = 400 # Roughly 400 words CHUNK_OVERLAP = 50 # 50 words overlap LLM_MODEL_NAME = "gpt-4o-mini" # Use latest model "o1-mini" much better but paid LLM_TEMPERATURE = 0 TOP_K_RETRIEVAL = 3 # ======================= # Prompt Configuration # ======================= def create_chat_prompt(): """Create a chat prompt template for the AI model.""" chat_prompt_template = """ You are AQUABOTICA, the most advanced AI assistant specializing in aquaculture information. Given a specific query, analyze the provided context extracted from academic documents, and also use your knowledge to generate a precise and concise answer. Also, If the the context contains some quantitative figures, do mention them. Avoid LaTeX or complex math formatting, use plain text for maths. **Query:** {question} **Context:** {context} **Response:** """ prompt = PromptTemplate( template=chat_prompt_template, input_variables=['context', 'question'] ) chat_prompt = ChatPromptTemplate( input_variables=['context', 'question'], metadata={ 'lc_hub_owner': 'aquabotica', 'lc_hub_repo': 'aquaculture-research', 'lc_hub_commit_hash': 'a7b9c123abc12345f6789e123456def123456789' # Adjust commit hash if required }, messages=[ HumanMessagePromptTemplate(prompt=prompt) ] ) return chat_prompt # ======================= # Metadata Formatting # ======================= def format_metadata(chunk_id: int, all_chunked_data: List[Dict[str, Any]]) -> str: """Format metadata directly from the chunked data for a given chunk ID.""" chunk = all_chunked_data[chunk_id] logging.debug(f"Chunk Retrieved: {chunk['text'][:100]}...") # Print first 100 characters logging.debug(f"Metadata: {chunk['metadata']}") metadata = chunk.get('metadata', {}) return f"Chunk {chunk_id}: {metadata}" # ======================= # Language Model and Retrieval Setup # ======================= def initialize_llm(model_name=LLM_MODEL_NAME, temperature=LLM_TEMPERATURE): """Initialize the language model.""" logging.debug("Initializing LLM model...") return ChatOpenAI(model_name=model_name, temperature=temperature,openai_api_key=OPENAI_API_KEY) def main(QUESTION=""): logging.debug(f"Received user query: {QUESTION}") chat_prompt = create_chat_prompt() llm = initialize_llm() # Query FAISS Index try: logging.debug("Encoding query for FAISS retrieval...") query_embedding = model.encode([QUESTION]) logging.debug(f"Query embedding: {query_embedding[:5]}... (truncated)") D, I = faiss_index.search(np.array(query_embedding, dtype='float32'), k=3) relevant_chunk_ids = I[0] logging.debug(f"Retrieved chunk IDs: {relevant_chunk_ids}, Distances: {D}") relevant_chunks = [all_chunked_data[i]['text'] for i in relevant_chunk_ids] #### #### context_display = "\n\n".join([ f"Chunk {idx+1}: {chunk[:]}...\nMetadata: {all_chunked_data[i]['metadata']}" for idx, (i, chunk) in enumerate(zip(relevant_chunk_ids, relevant_chunks)) ]) #### #### # context = "\n\n".join([f"Retrieved Chunk: {chunk}\nMetadata: {all_chunked_data[i]['metadata']}" for i, chunk in zip(relevant_chunk_ids, relevant_chunks)]) context = " ".join(relevant_chunks) except Exception as e: logging.error(f"Error during FAISS search: {e}") return f"Error during FAISS search: {e}" # Generate Response try: logging.debug("Formatting input for LLM...") prompt_input = chat_prompt.format(context=context, question=QUESTION) logging.debug(f"Formatted prompt: {prompt_input}") result = llm.invoke(prompt_input) answer = result.content if hasattr(result, 'content') else "No answer found." logging.debug("LLM successfully generated response.") except Exception as e: logging.error(f"Error during LLM execution: {e}") return f"Error during LLM execution: {e}" return answer, context_display # relevant_chunks_metadata = [format_metadata(chunk_id, all_chunked_data) for chunk_id in relevant_chunk_ids] # return f"\n{answer}\n\n" + context # return f"\n{answer}\n\n" + "\n"+ "\n".join(relevant_chunks_metadata) # iface = gr.Interface( # fn=main, # inputs="text", # outputs="text", # title="Aquabotica: Aquaculture Chatbot", # description="Ask questions about aquaculture and get answers based on scientific manuals." # ) # if __name__ == "__main__": # logging.debug("Launching Gradio UI...") # iface.launch() # # Updated CSS # custom_css = """ # /* Style for labels across all components */ # .question-input label span, # .solution-output label span, # .metadata-output label span { # font-size: 20px !important; # font-weight: bold !important; # } # /* Style for the submit button */ # .submit-btn button { # background-color: orange !important; # color: black !important; # font-weight: bold !important; # } # /* Preserve newlines and enable horizontal scrolling */ # .metadata-output textarea { # white-space: pre !important; # overflow-x: auto !important; # padding: 8px !important; # } # """ # with gr.Blocks(css=custom_css) as demo: # with gr.Column(): # question_input = gr.Textbox( # label="Ask a Question relevant to provided Aquaculture documents", # lines=2, # placeholder="Enter your question here", # elem_classes="question-input" # ) # submit_btn = gr.Button("Submit", elem_classes="submit-btn") # solution_output = gr.Textbox( # label="Response", # interactive=False, # lines=5, # elem_classes="solution-output" # Added missing class # ) # retrieved_chunks = gr.Textbox( # label="Retrieved Data", # interactive=False, # lines=5, # elem_classes="metadata-output" # ) # submit_btn.click(main, inputs=question_input, outputs=[solution_output, retrieved_chunks]) # demo.launch() custom_css = """ /* Style for labels across all components */ .question-input label span, .solution-output label span, .metadata-output label span { font-size: 20px !important; font-weight: bold !important; color: orange !important; } /* Correct style for the submit button */ .submit-btn button { background-color: orange !important; color: black !important; font-weight: bold !important; border: none !important; border-radius: 8px !important; padding: 10px 20px !important; cursor: pointer !important; } /* Hover effect for submit button */ .submit-btn button:hover { background-color: darkorange !important; } /* Preserve newlines and enable horizontal scrolling in retrieved documents */ .metadata-output textarea { white-space: pre !important; overflow-x: auto !important; padding: 8px !important; } """ with gr.Blocks(css=custom_css) as demo: with gr.Column(): question_input = gr.Textbox( label="Ask a Question", lines=2, placeholder="Enter your question here", elem_classes="question-input" ) submit_btn = gr.Button( "Submit", elem_classes="submit-btn" ) solution_output = gr.Textbox( label="Response", interactive=False, lines=5, elem_classes="solution-output" ) retrieved_chunks = gr.Textbox( label="Retrieved Data/Documents", interactive=False, lines=5, elem_classes="metadata-output" ) submit_btn.click( main, inputs=question_input, outputs=[solution_output, retrieved_chunks] ) demo.launch()