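"""Gradio app for ADR (adverse drug reaction) severity classification.

Classifies an ADR description as severe or not severe with a fine-tuned RoBERTa model,
extracts symptoms, diseases, and medications with a biomedical NER pipeline, and can
optionally generate a SHAP explanation of the prediction. Runs on CPU.
"""
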
import os

import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

import shap
from shap.maskers import Text
from shap.explainers import Permutation

device = torch.device("cpu") |
|
print(f"β
Running on device: {device}") |
|
|
|
|
|
model = AutoModelForSequenceClassification.from_pretrained("calerio-uva/roberta-adr-model").to(device).eval() |
|
tokenizer = AutoTokenizer.from_pretrained("calerio-uva/roberta-adr-model") |
|
|
|
|
|
# Biomedical NER pipeline used to extract symptoms, diseases, and medications
ner = pipeline(
    "ner",
    model="d4data/biomedical-ner-all",
    tokenizer="d4data/biomedical-ner-all",
    aggregation_strategy="simple",
    device=-1
)

# Text-classification pipeline over the ADR model; SHAP calls this for predictions
clf_pipeline = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,
    device=-1
)

def shap_predict(texts):
    """Prediction function for SHAP: returns an (n_texts, 2) array of class probabilities."""
    texts = [str(t) for t in texts]
    results = clf_pipeline(texts, truncation=True, padding=True, max_length=512)
    scores = []
    for i, text in enumerate(texts):
        if isinstance(results[i], dict):
            scores.append([1 - results[i]["score"], results[i]["score"]])
        else:
            # With top_k=None the pipeline returns all labels sorted by score, not by label id,
            # so re-order the scores to match ["Not Severe", "Severe"]
            by_label = {entry["label"]: entry["score"] for entry in results[i]}
            scores.append([by_label[model.config.id2label[j]] for j in range(len(by_label))])
    return np.array(scores)

# SHAP permutation explainer with a text masker built from the model tokenizer
masker = Text(tokenizer)
explainer = Permutation(shap_predict, masker, output_names=["Not Severe", "Severe"])

# NER entity groups mapped to the three output buckets
SYMPTOM_TAGS = {"sign_symptom", "symptom"}
DISEASE_TAGS = {"disease_disorder"}
MED_TAGS = {"medication", "administration", "therapeutic_procedure"}

def dedupe_and_filter(tokens):
    """Drop tokens shorter than 3 characters and case-insensitive duplicates, preserving order."""
    seen, out = set(), []
    for tok in tokens:
        w = tok.strip()
        if len(w) < 3:
            continue
        lw = w.lower()
        if lw not in seen:
            seen.add(lw)
            out.append(w)
    return out

def classify_adr(text, explain=False):
    # Basic cleanup: strip whitespace, drop literal "nan" strings, collapse double spaces,
    # and cap the raw text length before tokenization
    clean = text.strip().replace("nan", "").replace("  ", " ")[:512]

    # Severity prediction
    inputs = tokenizer(clean, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()

    # Named-entity extraction, merging adjacent entities of the same group into single spans
    ents = ner(clean)
    spans = []
    for ent in ents:
        grp, start, end, score = ent["entity_group"].lower(), ent["start"], ent["end"], ent.get("score", 1.0)
        if spans and spans[-1]["group"] == grp and start <= spans[-1]["end"] + 1:
            spans[-1]["end"] = max(spans[-1]["end"], end)
            spans[-1]["score"] = max(spans[-1]["score"], score)
        else:
            spans.append({"group": grp, "start": start, "end": end, "score": score})

    # Extend medication spans to the end of the current word
    for s in spans:
        if s["group"] in MED_TAGS:
            en = s["end"]
            while en < len(clean) and clean[en].isalpha():
                en += 1
            s["end"] = en

    # Keep only reasonably confident entities
    spans = [s for s in spans if s["score"] >= 0.6]

    # Pull the surviving entity text out of the cleaned input and bucket it by type
    tokens = []
    for s in spans:
        chunk = clean[s["start"]:s["end"]].strip()
        if len(chunk) >= 3:
            tokens.append((chunk, s["group"]))

    symptoms = dedupe_and_filter([t for t, g in tokens if g in SYMPTOM_TAGS])
    diseases = dedupe_and_filter([t for t, g in tokens if g in DISEASE_TAGS])
    medications = dedupe_and_filter([t for t, g in tokens if g in MED_TAGS])

    # Human-readable interpretation of the severity probability
    if probs[1] > 0.9:
        comment = "High confidence this is a severe ADR."
    elif probs[1] > 0.5:
        comment = "Borderline case - may be severe."
    else:
        comment = "Likely not severe."

    # Optional SHAP explanation: bar plot of token attributions saved to a temp file
    shap_path = None
    if explain:
        try:
            shap_values = explainer([clean], max_evals=min(400, len(clean.split()) * 5))
            plt.figure()
            shap.plots.bar(shap_values[0], show=False)
            shap_path = "/tmp/shap_expl.png"
            plt.savefig(shap_path, bbox_inches="tight")
            plt.close()
        except Exception as e:
            print(f"[SHAP Error] {e}")
            shap_path = None

    return (
        f"Not Severe (0): {probs[0]:.3f}\nSevere (1): {probs[1]:.3f}",
        "\n".join(symptoms) or "None detected",
        "\n".join(diseases) or "None detected",
        "\n".join(medications) or "None detected",
        comment,
        shap_path
    )

# Gradio UI
demo = gr.Interface(
    fn=classify_adr,
    inputs=[
        gr.Textbox(lines=5, label="ADR Description"),
        gr.Checkbox(label="Generate SHAP Explanation (VERY slow)", value=False)
    ],
    outputs=[
        gr.Textbox(label="Predicted Probabilities"),
        gr.Textbox(label="Symptoms"),
        gr.Textbox(label="Diseases or Conditions"),
        gr.Textbox(label="Medications"),
        gr.Textbox(label="Interpretation"),
        gr.Image(label="SHAP Explanation")
    ],
    title="ADR Severity & NER Classifier 2",
    description="Paste an ADR description to classify severity, extract symptoms, diseases, medications, and visualize SHAP explanations.",
    allow_flagging="never"
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))