import streamlit as st
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from qdrant_client import QdrantClient
from langchain_qdrant import Qdrant
from langchain_community.embeddings import HuggingFaceEmbeddings
from transformers import pipeline
import torch
from groq import Groq
import google.generativeai as genai
from langchain_openai import ChatOpenAI
import cohere
available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro", "Ensemble"]

AI_PROMPT_TEMPLATE = """You are an AI-assisted Dermatology Chatbot, specializing in diagnosing and educating users about skin diseases.
You provide accurate, compassionate, and detailed explanations while using correct medical terminology.
Guidelines:
1. Symptoms - Explain in simple terms with proper medical definitions.
2. Causes - Include genetic, environmental, and lifestyle-related risk factors.
3. Medications & Treatments - Provide common prescription and over-the-counter treatments.
4. Warnings & Emergencies - Always recommend consulting a licensed dermatologist.
5. Emergency Note - If symptoms worsen or include difficulty breathing, advise calling 911 immediately.
Query: {question}
Relevant Information: {context}
Answer:
"""

def initialize_rag_components():
    """Build the shared RAG components: Cohere reranker, PairRM ranker, fusion model, and retriever."""
    components = {
        'cohere_client': cohere.Client(st.secrets["COHERE_API_KEY"]),
        # PairRM is used to score how well each candidate answer fits the prompt.
        'pair_ranker': pipeline(
            "text-classification",
            model="llm-blender/PairRM",
            tokenizer="llm-blender/PairRM",
            return_all_scores=True
        ),
        # gen_fuser_3b merges the top-ranked answers into a single response.
        'gen_fuser': pipeline(
            "text-generation",
            model="llm-blender/gen_fuser_3b",
            tokenizer="llm-blender/gen_fuser_3b",
            max_length=2048,
            do_sample=False
        ),
        'retriever': get_retriever()
    }
    return components

class AllModelsWrapper:
    """Ensemble wrapper: queries all models, then ranks and fuses their answers."""

    def invoke(self, messages):
        prompt = messages[0]["content"]
        rag_components = st.session_state.app_models['rag_components']  # Get shared components from the app
        responses = get_all_responses(prompt)
        fused = rank_and_fuse(prompt, responses, rag_components)
        # Return an object with a .content attribute, mirroring LangChain chat model responses.
        return type('obj', (object,), {'content': fused})()

def get_all_responses(prompt):
    """Get responses from all models for the ensemble path."""
    openai_resp = ChatOpenAI(
        model="gpt-4o", temperature=0.2,
        api_key=st.secrets["OPENAI_API_KEY"]
    ).invoke([{"role": "user", "content": prompt}]).content

    # Configure Gemini here as well, since the ensemble path can run before initialize_llm("Gemini ...").
    genai.configure(api_key=st.secrets["GEMINI_API_KEY"])
    gemini = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")
    gemini_resp = gemini.generate_content(prompt).text

    llama = Groq(api_key=st.secrets["GROQ_API_KEY"])
    llama_resp = llama.chat.completions.create(
        model="meta-llama/llama-4-maverick-17b-128e-instruct",
        messages=[{"role": "user", "content": prompt}],
        temperature=1, max_completion_tokens=1024, top_p=1, stream=False
    ).choices[0].message.content

    return [openai_resp, gemini_resp, llama_resp]

def rank_and_fuse(prompt, responses, rag_components):
    """Score each response against the prompt with PairRM, then fuse the top two with gen_fuser."""
    ranked = [(resp, rag_components['pair_ranker'](f"{prompt}\n\n{resp}")[0][1]['score'])
              for resp in responses]
    ranked.sort(key=lambda x: x[1], reverse=True)
    # Fuse the two highest-ranked responses into a single answer
    fusion_input = "\n\n".join([f"[Answer {i + 1}]: {ans}" for i, (ans, _) in enumerate(ranked[:2])])
    return rag_components['gen_fuser'](f"Fuse these responses:\n{fusion_input}",
                                       return_full_text=False)[0]['generated_text']

def get_retriever():
    # === Qdrant DB Setup ===
    # The Qdrant API key is read from Streamlit secrets rather than hard-coded in source;
    # QDRANT_API_KEY is an assumed secret name that must be configured in the deployment.
    qdrant_client = QdrantClient(
        url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
        api_key=st.secrets["QDRANT_API_KEY"]
    )
    collection_name = "ks_collection_1.5BE"

    # Local embedding model used to embed queries against the existing collection.
    local_embedding = HuggingFaceEmbeddings(
        model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
        model_kwargs={"trust_remote_code": True, "device": "cuda" if torch.cuda.is_available() else "cpu"}
    )
    print("Qwen2-1.5B local embedding model loaded.")

    vector_store = Qdrant(
        client=qdrant_client,
        collection_name=collection_name,
        embeddings=local_embedding
    )
    return vector_store.as_retriever()

def initialize_llm(_model_name):
    """Initialize the LLM based on selection."""
    print(f"Model name : {_model_name}")
    if "OpenAI" in _model_name:
        return ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"])
    elif "LLaMA" in _model_name:
        client = Groq(api_key=st.secrets["GROQ_API_KEY"])

        def get_llama_response(prompt):
            completion = client.chat.completions.create(
                model="meta-llama/llama-4-maverick-17b-128e-instruct",
                messages=[{"role": "user", "content": prompt}],
                temperature=1,
                max_completion_tokens=1024,
                top_p=1,
                stream=False
            )
            return completion.choices[0].message.content

        # Wrap the reply in an object exposing .content so callers can use
        # llm.invoke(...).content uniformly across all model types.
        return type('obj', (object,), {
            'invoke': lambda self, x: type('obj', (object,), {'content': get_llama_response(x[0]["content"])})()
        })()
    elif "Gemini" in _model_name:
        genai.configure(api_key=st.secrets["GEMINI_API_KEY"])
        gemini_model = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")

        def get_gemini_response(prompt):
            response = gemini_model.generate_content(prompt)
            return response.text

        return type('obj', (object,), {
            'invoke': lambda self, x: type('obj', (object,), {'content': get_gemini_response(x[0]["content"])})()
        })()
    elif "Ensemble" in _model_name:
        return AllModelsWrapper()
    else:
        raise ValueError("Unsupported model selected")

def load_rag_chain(llm):
    prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])
    rag_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=get_retriever(),
        chain_type="stuff",
        chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
    )
    return rag_chain

def rerank_with_cohere(query, documents, co, top_n=5):
    if not documents:
        return []
    raw_texts = [doc.page_content for doc in documents]
    results = co.rerank(query=query, documents=raw_texts, top_n=min(top_n, len(raw_texts)), model="rerank-v3.5")
    return [documents[result.index] for result in results.results]

def get_reranked_response(query, llm, rag_components):
    """Get a response using Cohere-reranked context."""
    docs = rag_components['retriever'].get_relevant_documents(query)
    reranked_docs = rerank_with_cohere(query, docs, rag_components['cohere_client'])
    context = "\n\n".join([doc.page_content for doc in reranked_docs])
    if isinstance(llm, (ChatOpenAI, AllModelsWrapper)):
        # Note: the RetrievalQA chain performs its own retrieval internally; the reranked
        # context built above is only applied in the direct-prompt branch below.
        return load_rag_chain(llm).invoke({"query": query, "context": context})['result']
    else:
        prompt = AI_PROMPT_TEMPLATE.format(question=query, context=context)
        return llm.invoke([{"role": "user", "content": prompt}]).content

if __name__ == "__main__":
    print("This is a module - import it instead of running directly")