# SkinCancerDiagnosis/rag_pipeline.py
import streamlit as st
import torch
import cohere
import google.generativeai as genai
from groq import Groq
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain_qdrant import Qdrant
from qdrant_client import QdrantClient
from transformers import pipeline

available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro", "Ensemble"]
AI_PROMPT_TEMPLATE = """You are an AI-assisted Dermatology Chatbot, specializing in diagnosing and educating users about skin diseases.
You provide accurate, compassionate, and detailed explanations while using correct medical terminology.
Guidelines:
1. Symptoms - Explain in simple terms with proper medical definitions.
2. Causes - Include genetic, environmental, and lifestyle-related risk factors.
3. Medications & Treatments - Provide common prescription and over-the-counter treatments.
4. Warnings & Emergencies - Always recommend consulting a licensed dermatologist.
5. Emergency Note - If symptoms worsen or include difficulty breathing, **advise calling 911 immediately.**
Query: {question}
Relevant Information: {context}
Answer:
"""
@st.cache_resource(show_spinner=False)
def initialize_rag_components():
    """Build the shared RAG components once per Streamlit session."""
    components = {
        'cohere_client': cohere.Client(st.secrets["COHERE_API_KEY"]),
        # Pairwise ranker used to score candidate answers against the prompt.
        'pair_ranker': pipeline(
            "text-classification",
            model="llm-blender/PairRM",
            tokenizer="llm-blender/PairRM",
            return_all_scores=True,
        ),
        # Generative fuser that merges the top-ranked answers into one response.
        'gen_fuser': pipeline(
            "text-generation",
            model="llm-blender/gen_fuser_3b",
            tokenizer="llm-blender/gen_fuser_3b",
            max_length=2048,
            do_sample=False,
        ),
        'retriever': get_retriever(),
    }
    return components

class AllModelsWrapper:
    """Ensemble wrapper: queries all backends, then ranks and fuses their answers."""

    def invoke(self, messages):
        prompt = messages[0]["content"]
        rag_components = st.session_state.app_models['rag_components']
        responses = get_all_responses(prompt)
        fused = rank_and_fuse(prompt, responses, rag_components)
        # Mimic the LangChain message interface (an object exposing .content).
        return type('obj', (object,), {'content': fused})()

def get_all_responses(prompt):
    """Collect one response per backend: OpenAI, Gemini, and LLaMA (via Groq)."""
    openai_resp = ChatOpenAI(
        model="gpt-4o", temperature=0.2,
        api_key=st.secrets["OPENAI_API_KEY"],
    ).invoke([{"role": "user", "content": prompt}]).content

    # Configure Gemini here as well, so the ensemble path works even when the
    # Gemini branch of initialize_llm() was never run.
    genai.configure(api_key=st.secrets["GEMINI_API_KEY"])
    gemini = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")
    gemini_resp = gemini.generate_content(prompt).text

    llama = Groq(api_key=st.secrets["GROQ_API_KEY"])
    llama_resp = llama.chat.completions.create(
        model="meta-llama/llama-4-maverick-17b-128e-instruct",
        messages=[{"role": "user", "content": prompt}],
        temperature=1, max_completion_tokens=1024, top_p=1, stream=False,
    ).choices[0].message.content

    return [openai_resp, gemini_resp, llama_resp]

def rank_and_fuse(prompt, responses, rag_components):
    """Score each response with the pairwise ranker, then fuse the top two."""
    ranked = [
        (resp, rag_components['pair_ranker'](f"{prompt}\n\n{resp}")[0][1]['score'])
        for resp in responses
    ]
    ranked.sort(key=lambda x: x[1], reverse=True)

    # Fuse the two highest-ranked responses into a single answer.
    fusion_input = "\n\n".join(
        f"[Answer {i + 1}]: {ans}" for i, (ans, _) in enumerate(ranked[:2])
    )
    return rag_components['gen_fuser'](
        f"Fuse these responses:\n{fusion_input}",
        return_full_text=False,
    )[0]['generated_text']

def get_retriever():
    """Connect to the Qdrant collection and return it as a LangChain retriever."""
    # The API key is read from Streamlit secrets instead of being hardcoded;
    # this assumes a QDRANT_API_KEY entry exists alongside the other keys.
    qdrant_client = QdrantClient(
        url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
        api_key=st.secrets["QDRANT_API_KEY"],
    )
    collection_name = "ks_collection_1.5BE"

    # The embedding model must match the one used to build the collection.
    local_embedding = HuggingFaceEmbeddings(
        model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
        model_kwargs={
            "trust_remote_code": True,
            "device": "cuda" if torch.cuda.is_available() else "cpu",
        },
    )
    print("Qwen2-1.5B local embedding model loaded.")

    vector_store = Qdrant(
        client=qdrant_client,
        collection_name=collection_name,
        embeddings=local_embedding,
    )
    return vector_store.as_retriever()

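# Usage sketch (illustrative, not called on import): a minimal retrieval
# round-trip against the collection above. The query string and the number of
# documents printed are assumptions for demonstration only.
def _demo_retriever():
    retriever = get_retriever()
    docs = retriever.get_relevant_documents("What are the early signs of melanoma?")
    for doc in docs[:3]:
        print(doc.page_content[:200], "\n---")
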
def initialize_llm(_model_name):
    """Initialize the LLM backend based on the selected model name."""
    print(f"Model name : {_model_name}")
    if "OpenAI" in _model_name:
        return ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"])
    elif "LLaMA" in _model_name:
        client = Groq(api_key=st.secrets["GROQ_API_KEY"])

        def get_llama_response(prompt):
            completion = client.chat.completions.create(
                model="meta-llama/llama-4-maverick-17b-128e-instruct",
                messages=[{"role": "user", "content": prompt}],
                temperature=1,
                max_completion_tokens=1024,
                top_p=1,
                stream=False,
            )
            return completion.choices[0].message.content

        # Wrap the raw string in an object exposing .content so every backend
        # shares the same invoke(...).content interface (fixes an AttributeError
        # in get_reranked_response, which reads .content off the result).
        return type('obj', (object,), {
            'invoke': lambda self, x: type('obj', (object,), {'content': get_llama_response(x[0]["content"])})()
        })()
    elif "Gemini" in _model_name:
        genai.configure(api_key=st.secrets["GEMINI_API_KEY"])
        gemini_model = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")

        def get_gemini_response(prompt):
            return gemini_model.generate_content(prompt).text

        return type('obj', (object,), {
            'invoke': lambda self, x: type('obj', (object,), {'content': get_gemini_response(x[0]["content"])})()
        })()
    elif "Ensemble" in _model_name:
        return AllModelsWrapper()
    else:
        raise ValueError("Unsupported model selected")

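# Usage sketch (illustrative): selecting a backend by its display name from
# available_models. Assumes the matching API key is present in st.secrets.
def _demo_initialize_llm():
    llm = initialize_llm("Gemini Pro")
    reply = llm.invoke([{"role": "user", "content": "Briefly, what is psoriasis?"}])
    print(reply.content)
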
def load_rag_chain(llm):
    """Build a RetrievalQA chain that stuffs retrieved documents into the prompt."""
    prompt_template = PromptTemplate(
        template=AI_PROMPT_TEMPLATE,
        input_variables=["question", "context"],
    )
    rag_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=get_retriever(),
        chain_type="stuff",
        chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"},
    )
    return rag_chain

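# Usage sketch (illustrative): invoking the chain directly. RetrievalQA expects
# a "query" key and performs its own retrieval internally; the query below is
# an assumption for demonstration.
def _demo_rag_chain():
    llm = initialize_llm("OpenAI GPT-4o")
    chain = load_rag_chain(llm)
    print(chain.invoke({"query": "How is basal cell carcinoma treated?"})['result'])
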
def rerank_with_cohere(query, documents, co, top_n=5):
    """Re-order retrieved documents by relevance using Cohere's reranker."""
    if not documents:
        return []
    raw_texts = [doc.page_content for doc in documents]
    results = co.rerank(
        query=query,
        documents=raw_texts,
        top_n=min(top_n, len(raw_texts)),
        model="rerank-v3.5",
    )
    return [documents[result.index] for result in results.results]

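# Usage sketch (illustrative): retrieve, then rerank. Assumes COHERE_API_KEY is
# present in st.secrets; the query string is an assumption for demonstration.
def _demo_rerank():
    co = cohere.Client(st.secrets["COHERE_API_KEY"])
    retriever = get_retriever()
    query = "sunscreen and skin cancer prevention"
    docs = retriever.get_relevant_documents(query)
    top_docs = rerank_with_cohere(query, docs, co, top_n=3)
    print(f"Kept {len(top_docs)} of {len(docs)} documents after reranking")
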
def get_reranked_response(query, llm, rag_components):
    """Retrieve, rerank with Cohere, and answer using the selected LLM."""
    docs = rag_components['retriever'].get_relevant_documents(query)
    reranked_docs = rerank_with_cohere(query, docs, rag_components['cohere_client'])
    context = "\n\n".join(doc.page_content for doc in reranked_docs)
    if isinstance(llm, (ChatOpenAI, AllModelsWrapper)):
        # Note: RetrievalQA performs its own retrieval and ignores the extra
        # "context" key, so the reranked context is not used on this path.
        return load_rag_chain(llm).invoke({"query": query, "context": context})['result']
    else:
        prompt = AI_PROMPT_TEMPLATE.format(question=query, context=context)
        return llm.invoke([{"role": "user", "content": prompt}]).content

if __name__ == "__main__":
    print("This is a module - import it instead of running directly")