|
import os |
|
import json |
|
import glob |
|
from pathlib import Path |
|
import torch |
|
import streamlit as st |
|
from dotenv import load_dotenv |
|
from langchain_groq import ChatGroq |
|
from langchain_community.embeddings import HuggingFaceEmbeddings |
|
from langchain_text_splitters import RecursiveCharacterTextSplitter |
|
from langchain_community.vectorstores import FAISS |
|
from langchain_core.documents import Document |
|
from langchain_core.prompts import ChatPromptTemplate |
|
from langchain.chains import create_retrieval_chain |
|
from langchain.chains.combine_documents import create_stuff_documents_chain |
|
import numpy as np |
|
from sentence_transformers import util |
|
import time |
|
|
|
|
|
# Pick the torch device once at import time: the GPU when CUDA is available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
# Load python-dotenv settings before GROQ_API_KEY is read from os.environ
# further down in this module.
load_dotenv()
|
|
|
|
|
|
|
class MockLLM:
    """Stand-in used when the real Groq chat model cannot be created.

    The reply text is fixed when the object is built, so it never depends on
    names that only exist while an ``except`` block is running.

    The instance is callable and returns the plain reply text.
    NOTE(review): this relies on LangChain coercing plain callables into
    runnables when they are piped into a chain - confirm against the
    installed LangChain version.
    """

    def __init__(self, message):
        self.message = message

    def __call__(self, prompt):
        """Return the canned reply as plain text (used inside chains)."""
        return self.message

    def invoke(self, prompt):
        """Return the canned reply in the original ``{"answer": ...}`` shape."""
        return {"answer": self.message}


# Create the module-level ``llm``: the real Groq model when an API key is
# available, otherwise a MockLLM so the UI can still be demonstrated.
try:
    groq_api_key = os.environ.get('GROQ_API_KEY')

    # Fall back to Streamlit secrets when the environment has no key.
    if not groq_api_key and hasattr(st, 'secrets') and 'GROQ_API_KEY' in st.secrets:
        groq_api_key = st.secrets['GROQ_API_KEY']

    if not groq_api_key:
        st.warning("API Key is not set in the secrets. Using a placeholder for UI demonstration.")
        llm = MockLLM("This is a placeholder response. Please set up your GROQ_API_KEY to get real responses.")
    else:
        llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile")
except Exception as e:
    # Build the text now: Python unbinds ``e`` when this block ends, so a
    # message formatted later (inside the mock) would raise NameError.
    _llm_setup_error = f"Error setting up LLM: {str(e)}"
    st.error(_llm_setup_error)
    llm = MockLLM(f"{_llm_setup_error}. Please check your API key configuration.")
|
|
|
|
|
# Shared embedding model (Bio_ClinicalBERT), placed on the device chosen in
# the module-level ``device`` variable.
_EMBEDDING_MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"
embeddings = HuggingFaceEmbeddings(
    model_name=_EMBEDDING_MODEL_NAME,
    model_kwargs=dict(device=device),
)
|
|
|
def _load_flowchart_documents(flowchart_dir):
    """Return one Document per readable ``*.json`` flowchart in *flowchart_dir*.

    A missing directory or an unreadable file is reported with ``st.warning``
    and skipped; it never aborts the load.
    """
    if not os.path.exists(flowchart_dir):
        st.warning(f"Flowchart directory not found at {flowchart_dir}")
        return []

    documents = []
    # Sorted so documents are produced in the same order on every run
    # (glob order is otherwise arbitrary).
    for fpath in sorted(glob.glob(os.path.join(flowchart_dir, "*.json"))):
        try:
            with open(fpath, 'r', encoding='utf-8') as f:
                data = json.load(f)
            content = f"""
            DIAGNOSTIC FLOWCHART: {Path(fpath).stem}
            Diagnostic Path: {data.get('diagnostic', 'N/A')}
            Key Criteria: {data.get('knowledge', 'N/A')}
            """
            documents.append(Document(
                page_content=content,
                metadata={"source": fpath, "type": "flowchart"}
            ))
        except Exception as e:
            st.warning(f"Error loading flowchart file {fpath}: {str(e)}")
    return documents


def _load_patient_case_documents(finished_dir):
    """Return one Document per readable case file under *finished_dir*.

    Every sub-directory is treated as a category; each ``*.json`` file inside
    it is a case whose keys starting with ``"input"`` become the notes.
    """
    if not os.path.exists(finished_dir):
        st.warning(f"Finished directory not found at {finished_dir}")
        return []

    documents = []
    for category_dir in sorted(glob.glob(os.path.join(finished_dir, "*"))):
        if not os.path.isdir(category_dir):
            continue
        for case_file in sorted(glob.glob(os.path.join(category_dir, "*.json"))):
            try:
                with open(case_file, 'r', encoding='utf-8') as f:
                    case_data = json.load(f)
                notes = "\n".join(
                    f"{k}: {v}" for k, v in case_data.items() if k.startswith("input")
                )
                documents.append(Document(
                    page_content=f"""
                    PATIENT CASE: {Path(case_file).stem}
                    Category: {Path(category_dir).name}
                    Notes: {notes}
                    """,
                    metadata={"source": case_file, "type": "patient_case"}
                ))
            except Exception as e:
                st.warning(f"Error loading case file {case_file}: {str(e)}")
    return documents


def load_clinical_data():
    """Load both flowcharts and patient cases.

    Reads ``Diagnosis_flowchart/*.json`` and ``Finished/<category>/*.json``
    from the directory containing this file.

    Returns:
        A non-empty list of Documents. When nothing could be loaded a single
        sample document is returned; when an unexpected error occurs a
        fallback document is appended instead of raising.
    """
    docs = []
    current_dir = os.path.dirname(os.path.abspath(__file__))

    try:
        docs.extend(_load_flowchart_documents(os.path.join(current_dir, "Diagnosis_flowchart")))
        docs.extend(_load_patient_case_documents(os.path.join(current_dir, "Finished")))

        if not docs:
            st.warning("No clinical data files found. Using sample data for demonstration.")
            docs.append(Document(
                page_content="""SAMPLE CLINICAL DATA: This is sample data for demonstration purposes.
                This application requires clinical data files to be present in the correct directories.
                Please ensure the Diagnosis_flowchart and Finished directories exist with proper JSON files.""",
                metadata={"source": "sample", "type": "sample"}
            ))
    except Exception as e:
        # Top-level boundary: keep the app usable even if loading fails.
        st.error(f"Error loading clinical data: {str(e)}")
        docs.append(Document(
            page_content="Error loading clinical data. This is a fallback document for demonstration purposes.",
            metadata={"source": "error", "type": "error"}
        ))
    return docs
|
|
|
def build_vectorstore():
    """Create a FAISS index from the clinical documents.

    The documents are cut into 1000-character chunks with a 200-character
    overlap and embedded with the module-level ``embeddings`` model.
    """
    source_docs = load_clinical_data()
    chunker = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = chunker.split_documents(source_docs)
    return FAISS.from_documents(chunks, embeddings)
|
|
|
|
|
def get_vectorstore_path():
    """Return the absolute path of the ``vectorstore`` folder beside this file."""
    module_dir = Path(os.path.abspath(__file__)).parent
    return str(module_dir / "vectorstore")
|
|
|
|
|
@st.cache_resource(show_spinner="Loading clinical knowledge base...")
def get_vectorstore():
    """Return the FAISS store, loading it from disk when a saved copy exists.

    If no saved copy exists, or loading it fails, the store is rebuilt from
    the clinical data and then saved on a best-effort basis (a failed save
    only produces a warning).
    """
    store_dir = get_vectorstore_path()

    if os.path.exists(store_dir):
        try:
            st.info("Loading vectorstore from disk...")
            # NOTE(review): the flag below opts in to unsafe deserialization.
            # This function only loads the directory it writes itself; do not
            # point it at files from an untrusted source.
            return FAISS.load_local(store_dir, embeddings, allow_dangerous_deserialization=True)
        except Exception as load_error:
            st.warning(f"Could not load vectorstore from disk: {str(load_error)}. Building new vectorstore.")

    st.info("Building new vectorstore...")
    vectorstore = build_vectorstore()

    try:
        os.makedirs(store_dir, exist_ok=True)
        vectorstore.save_local(store_dir)
    except Exception as save_error:
        st.warning(f"Could not save vectorstore to disk: {str(save_error)}")
    else:
        st.success("Vectorstore saved to disk for future use")

    return vectorstore
|
|
|
def run_rag_chat(query, vectorstore, k=3):
    """Run the Retrieval-Augmented Generation (RAG) for clinical questions.

    Args:
        query: The user's question.
        vectorstore: Vector store whose ``as_retriever`` supplies the context.
        k: Number of documents to retrieve (default 3).

    Returns:
        The retrieval chain's output dict, read by callers through the
        ``"input"``, ``"context"`` (retrieved documents) and ``"answer"``
        keys. On any failure the error is shown with ``st.error`` and a dict
        with the same three keys is returned (empty context, error text as
        the answer).
    """
    try:
        # A single retriever is used for both the answer and the returned
        # sources, so the sources shown are exactly what the LLM was given.
        retriever = vectorstore.as_retriever(search_kwargs={"k": k})

        prompt_template = ChatPromptTemplate.from_template("""
        You are a clinical assistant AI. Based on the following clinical context, provide a reasoned and medically sound answer to the question.

        <context>
        {context}
        </context>

        Question: {input}

        Answer:
        """)

        document_chain = create_stuff_documents_chain(llm, prompt_template)
        chain = create_retrieval_chain(retriever, document_chain)
        return chain.invoke({"input": query})
    except Exception as e:
        st.error(f"Error in RAG processing: {str(e)}")
        return {
            "answer": f"I encountered an error processing your query: {str(e)}",
            "context": [],
            "input": query
        }
|
|
|
def calculate_hit_rate(retriever, query, expected_docs, k=3):
    """Calculate the hit rate for top-k retrieved documents.

    Args:
        retriever: Object exposing ``invoke(query, k=...)``; objects that only
            provide the older ``get_relevant_documents`` are still accepted.
        query: Query text passed to the retriever.
        expected_docs: Strings expected to appear (as substrings) in the
            ``page_content`` of at least one retrieved document.
        k: Number of top documents to consider.

    Returns:
        Fraction of ``expected_docs`` found, or 0.0 when it is empty.
    """
    if not expected_docs:
        # Nothing to look for, so skip the retrieval entirely.
        return 0.0

    fetch = getattr(retriever, "invoke", None)
    if fetch is None:
        fetch = retriever.get_relevant_documents

    # Slice as well as passing ``k``: a retriever may ignore the keyword and
    # return more documents than requested.
    retrieved_contents = [doc.page_content for doc in fetch(query, k=k)[:k]]

    hits = sum(
        1
        for expected in expected_docs
        if any(expected in content for content in retrieved_contents)
    )
    return hits / len(expected_docs)
|
|
|
def evaluate_rag_response(response, embeddings):
    """Score a RAG response for faithfulness and hit rate.

    Args:
        response: Dict with ``"answer"``, ``"context"`` (documents),
            ``"input"`` and ``"retriever"`` entries.
        embeddings: Object providing ``embed_query``.

    Returns:
        Dict with ``"faithfulness"`` (mean cosine similarity between the
        answer and each context document, 0.0 without context) and
        ``"hit_rate"``. The hit rate re-runs the retriever on the same input
        and looks for the response's own context documents, so it compares
        the retrieval with itself rather than with labelled ground truth.
    """
    context_docs = response["context"]

    answer_vector = embeddings.embed_query(response["answer"])
    similarities = []
    for doc in context_docs:
        doc_vector = embeddings.embed_query(doc.page_content)
        similarities.append(util.cos_sim(answer_vector, doc_vector).item())

    if similarities:
        faithfulness = float(np.mean(similarities))
    else:
        faithfulness = 0.0

    hit_rate = calculate_hit_rate(
        response["retriever"],
        query=response["input"],
        expected_docs=[doc.page_content for doc in context_docs],
        k=3,
    )

    return {"faithfulness": faithfulness, "hit_rate": hit_rate}
|
|
|
# Emoji are written as escapes so the glyphs survive the file being re-saved
# in a non-UTF-8 encoding (the previous literals had been mis-decoded).
# NOTE(review): the magnifier, books and half-moon icons are best-effort
# restorations; their original bytes were not recoverable - confirm.
_ICON_STETHOSCOPE = "\U0001FA7A"
_ICON_MAGNIFIER = "\U0001F50D"
_ICON_BRAIN = "\U0001F9E0"
_ICON_BOOKS = "\U0001F4DA"
_ICON_HALF_MOON = "\U0001F313"
_ICON_WASTEBASKET = "\U0001F5D1\uFE0F"
_ICON_HEALTH_WORKER = "\U0001F9D1\u200D\u2695\uFE0F"

# Page-wide CSS injected once per run by main().
_GLOBAL_CSS = """
    <style>
    .stApp {max-width: 1200px; margin: 0 auto;}
    .css-18e3th9 {padding-top: 2rem;}
    .stButton>button {background-color: #3498db; color: white;}
    .stButton>button:hover {background-color: #2980b9;}
    .chat-message {border-radius: 10px; padding: 10px; margin-bottom: 10px;}
    .chat-message-user {background-color: rgba(52, 152, 219, 0.2); color: inherit;}
    .chat-message-assistant {background-color: rgba(240, 240, 240, 0.2); color: inherit;}
    .source-box {background-color: rgba(255, 255, 255, 0.1); color: inherit; border-radius: 5px; padding: 15px; margin-bottom: 10px; border-left: 5px solid #3498db;}
    .metrics-box {background-color: rgba(255, 255, 255, 0.1); color: inherit; border-radius: 5px; padding: 15px; margin-top: 20px;}
    .features-container {display: flex; flex-wrap: wrap; gap: 20px; justify-content: center; margin-top: 30px;}
    .feature-item {flex: 1 1 calc(50% - 20px); min-width: 300px; display: flex; align-items: center; padding: 20px; border-radius: 10px; background: linear-gradient(135deg, rgba(72, 126, 176, 0.1), rgba(72, 126, 176, 0.2)); transition: transform 0.3s, box-shadow 0.3s; border: 1px solid rgba(255, 255, 255, 0.1);}
    .feature-item:hover {transform: translateY(-5px); box-shadow: 0 10px 20px rgba(0, 0, 0, 0.1);}
    .feature-icon {width: 60px; height: 60px; border-radius: 50%; background: linear-gradient(135deg, #3498db, #2980b9); display: flex; align-items: center; justify-content: center; margin-right: 20px; box-shadow: 0 5px 15px rgba(52, 152, 219, 0.3);}
    .feature-icon i {font-size: 24px; color: white;}
    .feature-content {flex: 1;}
    .feature-content h3 {margin-top: 0; margin-bottom: 10px; color: inherit;}
    .feature-content p {margin: 0; font-size: 0.9em; color: inherit; opacity: 0.8;}
    .input-container {margin-bottom: 20px; padding: 15px; border-radius: 10px; background-color: rgba(255, 255, 255, 0.05); border: 1px solid rgba(255, 255, 255, 0.1);}
    </style>
"""


def _render_sidebar():
    """Draw the sidebar: branding, page navigation and model information."""
    with st.sidebar:
        st.image("https://img.icons8.com/color/96/000000/caduceus.png", width=80)
        st.title("DiReCT")
        st.markdown("### Diagnostic Reasoning for Clinical Text")
        st.markdown("---")

        if st.button("Home", key="home_btn"):
            st.session_state.page = 'cover'
        if st.button("Diagnostic Assistant", key="assistant_btn"):
            st.session_state.page = 'chat'
        if st.button("About", key="about_btn"):
            st.session_state.page = 'about'

        st.markdown("---")
        st.markdown("### Model Information")
        st.markdown("**Embedding Model:** Bio_ClinicalBERT")
        st.markdown("**LLM:** Llama-3.3-70B")
        st.markdown("**Vector Store:** FAISS")


def _feature_card_html(icon, title, description, extra_style=""):
    """Return the HTML for one "Key Features" card on the cover page."""
    return f"""
    <div style="background: linear-gradient(135deg, rgba(72, 126, 176, 0.1), rgba(72, 126, 176, 0.2));
                padding: 20px; border-radius: 10px; height: 100%;
                border: 1px solid rgba(255, 255, 255, 0.1);{extra_style}">
        <h3>{icon} {title}</h3>
        <p>{description}</p>
    </div>
    """


def _render_cover_page():
    """Draw the landing page: headline, "Get Started" button and feature cards."""
    col1, col2 = st.columns([2, 1])
    with col1:
        st.markdown("<h1 style='font-size:3.5em;'>DiReCT</h1>", unsafe_allow_html=True)
        st.markdown("<h2 style='font-size:1.8em;color:#3498db;'>Diagnostic Reasoning for Clinical Text</h2>", unsafe_allow_html=True)
        st.markdown("""<p style='font-size:1.2em;'>A powerful RAG-based clinical diagnostic assistant that leverages the MIMIC-IV-Ext dataset to provide accurate medical insights and diagnostic reasoning.</p>""", unsafe_allow_html=True)

        st.markdown("""<br>""", unsafe_allow_html=True)
        if st.button("Get Started", key="get_started"):
            st.session_state.page = 'chat'
            st.rerun()

    with col2:
        st.markdown("""
        <div style='display:flex;justify-content:center;align-items:center;height:100%;'>
            <img src="https://img.icons8.com/color/240/000000/healthcare-and-medical.png" style='max-width:90%;'>
        </div>
        """, unsafe_allow_html=True)

    st.markdown("<br><br>", unsafe_allow_html=True)
    st.markdown("<h2 style='text-align:center;'>Key Features</h2>", unsafe_allow_html=True)

    # Cards fill the two columns alternately (left, right, left, right); only
    # the first row carries a bottom margin.
    row_gap = " margin-bottom: 20px;"
    features = [
        (_ICON_MAGNIFIER, "Intelligent Retrieval",
         "Finds the most relevant clinical information from the MIMIC-IV-Ext dataset", row_gap),
        (_ICON_BRAIN, "Advanced Reasoning",
         "Applies clinical knowledge to generate accurate diagnostic insights", row_gap),
        (_ICON_BOOKS, "Source Transparency",
         "Provides references to all clinical sources used in generating responses", ""),
        (_ICON_HALF_MOON, "Dark/Light Theme Compatible",
         "Optimized interface that works seamlessly in both dark and light themes", ""),
    ]
    columns = st.columns(2)
    for position, (icon, title, description, extra_style) in enumerate(features):
        column = columns[position % 2]
        with column:
            st.markdown(_feature_card_html(icon, title, description, extra_style), unsafe_allow_html=True)


def _render_chat_page():
    """Draw the assistant page: question box, answer, sources and metrics."""
    import html  # local import: only this page builds HTML from runtime text

    if 'user_input' not in st.session_state:
        st.session_state.user_input = ""

    col1, col2 = st.columns([3, 1])
    with col1:
        st.markdown("<h1>Clinical Diagnostic Assistant</h1>", unsafe_allow_html=True)
    with col2:
        if st.button(f"{_ICON_WASTEBASKET} Clear Chat"):
            st.session_state.chat_history = []
            st.session_state.user_input = ""
            st.rerun()

    st.markdown("Ask any clinical diagnostic question and get insights based on medical knowledge and patient cases.")

    with st.container():
        st.markdown("<div class='input-container'>", unsafe_allow_html=True)
        user_input = st.text_area("Ask a clinical question:", st.session_state.user_input, height=100, key="question_input")
        col1, col2 = st.columns([1, 5])
        with col1:
            submit_button = st.button("Submit")
        st.markdown("</div>", unsafe_allow_html=True)

    # Created before the query is processed so the conversation is laid out
    # above any error message produced while answering.
    chat_container = st.container()

    if submit_button and user_input:
        if st.session_state.vectorstore is None:
            st.error("Knowledge base not loaded. Please refresh the page and try again.")
        else:
            with st.spinner("Analyzing clinical data..."):
                try:
                    time.sleep(0.5)

                    response = run_rag_chat(user_input, st.session_state.vectorstore)
                    # evaluate_rag_response reads the retriever from the response.
                    response["retriever"] = st.session_state.vectorstore.as_retriever()

                    # Replaces (does not append to) the history, so only the
                    # latest exchange is shown.
                    # NOTE(review): confirm this is intended.
                    st.session_state.chat_history = [(user_input, response)]
                    st.session_state.user_input = ""
                except Exception as e:
                    st.error(f"Error processing query: {str(e)}")
                else:
                    # Kept outside the try so the rerun request can never be
                    # mistaken for a processing error.
                    st.rerun()

    with chat_container:
        for query, response in st.session_state.chat_history:
            # The question, the answer and the document text are runtime data
            # placed inside raw HTML, so they are escaped first.
            safe_query = html.escape(str(query), quote=False)
            safe_answer = html.escape(str(response['answer']), quote=False)

            st.markdown(f"<div class='chat-message chat-message-user'><b>{_ICON_HEALTH_WORKER} You:</b> {safe_query}</div>", unsafe_allow_html=True)
            st.markdown(f"<div class='chat-message chat-message-assistant'><b>{_ICON_STETHOSCOPE} DiReCT:</b> {safe_answer}</div>", unsafe_allow_html=True)

            with st.expander("View Sources"):
                for doc in response["context"]:
                    source_name = html.escape(Path(doc.metadata['source']).stem, quote=False)
                    source_type = html.escape(str(doc.metadata['type']), quote=False)
                    excerpt = html.escape(doc.page_content[:300], quote=False)
                    st.markdown(f"<div class='source-box'>"
                                f"<b>Source:</b> {source_name}<br>"
                                f"<b>Type:</b> {source_type}<br>"
                                f"<b>Content:</b> {excerpt}...</div>",
                                unsafe_allow_html=True)

            try:
                eval_scores = evaluate_rag_response(response, embeddings)
                with st.expander("View Evaluation Metrics"):
                    col1, col2 = st.columns(2)
                    with col1:
                        st.metric("Hit Rate (Top-3)", f"{eval_scores['hit_rate']:.2f}")
                    with col2:
                        st.metric("Faithfulness", f"{eval_scores['faithfulness']:.2f}")
            except Exception as e:
                st.warning(f"Evaluation metrics unavailable: {str(e)}")


def _render_about_page():
    """Draw the static "About" page."""
    st.markdown("<h1>About DiReCT</h1>", unsafe_allow_html=True)

    st.markdown("""
    ### Project Overview

    DiReCT (Diagnostic Reasoning for Clinical Text) is a Retrieval-Augmented Generation (RAG) system designed to assist medical professionals with diagnostic reasoning based on clinical notes and medical knowledge.

    ### Data Sources

    This application uses the MIMIC-IV-Ext dataset, which contains de-identified clinical notes and medical records. The system processes:

    - Diagnostic flowcharts
    - Patient cases
    - Clinical guidelines

    ### Technical Implementation

    - **Embedding Model**: Bio_ClinicalBERT for domain-specific text understanding
    - **Vector Database**: FAISS for efficient similarity search
    - **LLM**: Llama-3.3-70B for generating medically accurate responses
    - **Framework**: Built with LangChain and Streamlit

    ### Evaluation Metrics

    The system evaluates responses using:

    - **Hit Rate**: Measures how many relevant documents were retrieved
    - **Faithfulness**: Measures how well the response aligns with the retrieved context

    ### Ethical Considerations

    This system is designed as a clinical decision support tool and not as a replacement for professional medical judgment. All patient data used has been properly de-identified in compliance with healthcare privacy regulations.
    """)

    st.markdown("<br>", unsafe_allow_html=True)
    st.markdown("### Developers")
    st.markdown("This project was developed as part of an academic assignment on RAG systems for clinical applications.")


def main():
    """Main function to run the Streamlit app.

    Configures the page, loads the knowledge base into the session on the
    first run, injects the CSS, draws the sidebar and then renders whichever
    page is stored in ``st.session_state.page``.
    """
    st.set_page_config(
        page_title="DiReCT - Clinical Diagnostic Assistant",
        page_icon=_ICON_STETHOSCOPE,
        layout="wide",
        initial_sidebar_state="expanded"
    )

    if 'vectorstore' not in st.session_state:
        with st.spinner("Loading clinical knowledge base... This may take a minute."):
            try:
                st.session_state.vectorstore = get_vectorstore()
                st.markdown("<div style='padding:10px 15px;background-color:rgba(40,167,69,0.2);border-radius:5px;border-left:5px solid rgba(40,167,69,0.8);'>Clinical knowledge base loaded successfully!</div>", unsafe_allow_html=True)
            except Exception as e:
                st.error(f"Error loading knowledge base: {str(e)}")
                # The chat page checks for None before answering.
                st.session_state.vectorstore = None

    st.markdown(_GLOBAL_CSS, unsafe_allow_html=True)

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'page' not in st.session_state:
        st.session_state.page = 'cover'

    # The sidebar buttons may change the page, so it is drawn before dispatch.
    _render_sidebar()

    current_page = st.session_state.page
    if current_page == 'cover':
        _render_cover_page()
    elif current_page == 'chat':
        _render_chat_page()
    elif current_page == 'about':
        _render_about_page()
|
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()
|
|