Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
|
4 |
+
|
5 |
+
# 1) Classification model
# RoBERTa fine-tuned for ADR severity (binary: 0 = not severe, 1 = severe).
model = AutoModelForSequenceClassification.from_pretrained("calerio-uva/roberta-adr-model")
tokenizer = AutoTokenizer.from_pretrained("calerio-uva/roberta-adr-model")

# 2) Unified NER pipeline
# aggregation_strategy="simple" merges word-piece tokens into whole-word
# entities, each carrying an entity_group, character offsets, and a score.
ner = pipeline(
    "ner",
    model="d4data/biomedical-ner-all",
    tokenizer="d4data/biomedical-ner-all",
    aggregation_strategy="simple"
)

# 3) Tight tag sets
# Lower-cased entity_group names used to bucket NER spans downstream.
SYMPTOM_TAGS = {"sign_symptom", "symptom"}
DISEASE_TAGS = {"disease_disorder"}
MED_TAGS = {"medication", "administration", "therapeutic_procedure"}
|
21 |
+
|
22 |
+
# 4) Helper: drop <3-char tokens & dedupe
|
23 |
+
def dedupe_and_filter(tokens):
    """Strip tokens, drop anything shorter than 3 characters, and keep only
    the first occurrence of each token (compared case-insensitively).

    Returns the surviving tokens in their original order and original casing.
    """
    kept = []
    already_seen = set()
    for raw in tokens:
        word = raw.strip()
        key = word.lower()
        if len(word) >= 3 and key not in already_seen:
            already_seen.add(key)
            kept.append(word)
    return kept
|
34 |
+
|
35 |
+
def classify_adr(text: str):
    """Classify ADR severity and extract biomedical entities from *text*.

    Runs the severity classifier (softmax over two classes) and the NER
    pipeline, then merges/filters entity spans and buckets them into
    symptoms, diseases, and medications.

    Returns a 5-tuple of strings, one per Gradio output textbox:
    (probabilities, symptoms, diseases, medications, interpretation).
    """
    print("🔍 [DEBUG] Running classify_adr", flush=True)

    # Clean: drop literal "nan" fragments (pandas missing-value artifacts)
    # and collapse double spaces. Was a no-op `.replace(" ", " ")`; the
    # intent is clearly to collapse runs of whitespace left by the removal.
    # NOTE(review): replacing "nan" also mangles real words that contain
    # the substring (e.g. "nanogram") — confirm this is acceptable.
    clean = text.strip().replace("nan", "").replace("  ", " ")
    print("🔍 [DEBUG] clean[:50]:", clean[:50], "...", flush=True)

    # Severity: tokenize, run the classifier without gradients, softmax
    # the two logits into probabilities [not-severe, severe].
    inputs = tokenizer(clean, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()

    # Raw NER
    ents = ner(clean)
    print("🔍 [DEBUG] raw ents:", [(e["entity_group"], e["word"], e["start"], e["end"]) for e in ents], flush=True)

    # 1) Build & merge spans by offsets: adjacent/overlapping entities of
    # the same group are fused into one span, keeping the best score.
    spans = []
    for ent in ents:
        grp, start, end, score = ent["entity_group"].lower(), ent["start"], ent["end"], ent.get("score", 1.0)
        if spans and spans[-1]["group"] == grp and start <= spans[-1]["end"]:
            spans[-1]["end"] = max(spans[-1]["end"], end)
            spans[-1]["score"] = max(spans[-1]["score"], score)
        else:
            spans.append({"group": grp, "start": start, "end": end, "score": score})
    print("🔍 [DEBUG] merged spans:", spans, flush=True)

    # 2) Extend med spans out to the full word (NER sometimes stops
    # mid-word on sub-token boundaries).
    for s in spans:
        if s["group"] in MED_TAGS:
            en = s["end"]
            # extend forward while alphabetic
            while en < len(clean) and clean[en].isalpha():
                en += 1
            s["end"] = en

    # 3) Filter by confidence >= 0.6
    spans = [s for s in spans if s["score"] >= 0.6]
    print("🔍 [DEBUG] post-filter spans:", spans, flush=True)

    # 4) Extract the surface text for each surviving span.
    tokens = [clean[s["start"]:s["end"]] for s in spans]
    print("🔍 [DEBUG] tokens:", tokens, flush=True)

    # Bucket & dedupe
    symptoms = dedupe_and_filter([t for t, s in zip(tokens, spans) if s["group"] in SYMPTOM_TAGS])
    diseases = dedupe_and_filter([t for t, s in zip(tokens, spans) if s["group"] in DISEASE_TAGS])
    medications = dedupe_and_filter([t for t, s in zip(tokens, spans) if s["group"] in MED_TAGS])

    # Interpretation (emoji restored from mojibake in the pasted source;
    # original intent inferred from the message wording — confirm glyphs).
    if probs[1] > 0.9:
        comment = "❗ High confidence this is a severe ADR."
    elif probs[1] > 0.5:
        comment = "⚠️ Borderline case — may be severe."
    else:
        comment = "✅ Likely not severe."

    return (
        f"Not Severe (0): {probs[0]:.3f}\nSevere (1): {probs[1]:.3f}",
        "\n".join(symptoms) or "None detected",
        "\n".join(diseases) or "None detected",
        "\n".join(medications) or "None detected",
        comment
    )
|
100 |
+
|
101 |
+
# 5) Gradio UI
# One input textbox; five output textboxes matching, in order, the 5-tuple
# returned by classify_adr.
demo = gr.Interface(
    fn=classify_adr,
    inputs=gr.Textbox(lines=4, label="ADR Description"),
    outputs=[
        gr.Textbox(label="Predicted Probabilities"),
        gr.Textbox(label="Symptoms"),
        gr.Textbox(label="Diseases or Conditions"),
        gr.Textbox(label="Medications"),
        gr.Textbox(label="Interpretation"),
    ],
    title="ADR Severity & NER Classifier",
    description="Paste an ADR description to classify severity and extract symptoms, diseases & medications.",
    # NOTE(review): allow_flagging was deprecated in Gradio 4.x (renamed
    # flagging_mode) — confirm against the pinned gradio version.
    allow_flagging="never"
)

if __name__ == "__main__":
    demo.launch()
|