calerio commited on
Commit
8a41839
ยท
verified ยท
1 Parent(s): ce71b36

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -0
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
4
+
5
+ # 1) Classification model
6
+ model = AutoModelForSequenceClassification.from_pretrained("calerio-uva/roberta-adr-model")
7
+ tokenizer = AutoTokenizer.from_pretrained("calerio-uva/roberta-adr-model")
8
+
9
+ # 2) Unified NER pipeline
10
+ ner = pipeline(
11
+ "ner",
12
+ model="d4data/biomedical-ner-all",
13
+ tokenizer="d4data/biomedical-ner-all",
14
+ aggregation_strategy="simple"
15
+ )
16
+
17
+ # 3) Tight tag sets
18
+ SYMPTOM_TAGS = {"sign_symptom", "symptom"}
19
+ DISEASE_TAGS = {"disease_disorder"}
20
+ MED_TAGS = {"medication", "administration", "therapeutic_procedure"}
21
+
22
+ # 4) Helper: drop <3โ€‘char & dedupe
23
+ def dedupe_and_filter(tokens):
24
+ seen, out = set(), []
25
+ for tok in tokens:
26
+ w = tok.strip()
27
+ if len(w) < 3:
28
+ continue
29
+ lw = w.lower()
30
+ if lw not in seen:
31
+ seen.add(lw)
32
+ out.append(w)
33
+ return out
34
+
35
+ def classify_adr(text: str):
36
+ print("๐Ÿ” [DEBUG] Running classify_adr", flush=True)
37
+
38
+ # Clean
39
+ clean = text.strip().replace("nan", "").replace(" ", " ")
40
+ print("๐Ÿ” [DEBUG] clean[:50]:", clean[:50], "...", flush=True)
41
+
42
+ # Severity
43
+ inputs = tokenizer(clean, return_tensors="pt", truncation=True, padding=True, max_length=512)
44
+ with torch.no_grad():
45
+ logits = model(**inputs).logits
46
+ probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
47
+
48
+ # Raw NER
49
+ ents = ner(clean)
50
+ print("๐Ÿ” [DEBUG] raw ents:", [(e["entity_group"], e["word"], e["start"], e["end"]) for e in ents], flush=True)
51
+
52
+ # 1) Build & merge spans by offsets
53
+ spans = []
54
+ for ent in ents:
55
+ grp, start, end, score = ent["entity_group"].lower(), ent["start"], ent["end"], ent.get("score", 1.0)
56
+ if spans and spans[-1]["group"] == grp and start <= spans[-1]["end"]:
57
+ spans[-1]["end"] = max(spans[-1]["end"], end)
58
+ spans[-1]["score"] = max(spans[-1]["score"], score)
59
+ else:
60
+ spans.append({"group": grp, "start": start, "end": end, "score": score})
61
+ print("๐Ÿ” [DEBUG] merged spans:", spans, flush=True)
62
+
63
+ # 2) Extend med spans out to full word
64
+ for s in spans:
65
+ if s["group"] in MED_TAGS:
66
+ st, en = s["start"], s["end"]
67
+ # extend forward while alphabetic
68
+ while en < len(clean) and clean[en].isalpha():
69
+ en += 1
70
+ s["end"] = en
71
+
72
+ # 3) Filter by confidence โ‰ฅ0.6
73
+ spans = [s for s in spans if s["score"] >= 0.6]
74
+ print("๐Ÿ” [DEBUG] postโ€‘filter spans:", spans, flush=True)
75
+
76
+ # 4) Extract text
77
+ tokens = [clean[s["start"]:s["end"]] for s in spans]
78
+ print("๐Ÿ” [DEBUG] tokens:", tokens, flush=True)
79
+
80
+ # Bucket & dedupe
81
+ symptoms = dedupe_and_filter([t for t, s in zip(tokens, spans) if s["group"] in SYMPTOM_TAGS])
82
+ diseases = dedupe_and_filter([t for t, s in zip(tokens, spans) if s["group"] in DISEASE_TAGS])
83
+ medications = dedupe_and_filter([t for t, s in zip(tokens, spans) if s["group"] in MED_TAGS])
84
+
85
+ # Interpretation
86
+ if probs[1] > 0.9:
87
+ comment = "โ— High confidence this is a severe ADR."
88
+ elif probs[1] > 0.5:
89
+ comment = "โš ๏ธ Borderline case โ€” may be severe."
90
+ else:
91
+ comment = "โœ… Likely not severe."
92
+
93
+ return (
94
+ f"Not Severe (0): {probs[0]:.3f}\nSevere (1): {probs[1]:.3f}",
95
+ "\n".join(symptoms) or "None detected",
96
+ "\n".join(diseases) or "None detected",
97
+ "\n".join(medications) or "None detected",
98
+ comment
99
+ )
100
+
101
+ # 5) Gradio UI
102
+ demo = gr.Interface(
103
+ fn=classify_adr,
104
+ inputs=gr.Textbox(lines=4, label="ADR Description"),
105
+ outputs=[
106
+ gr.Textbox(label="Predicted Probabilities"),
107
+ gr.Textbox(label="Symptoms"),
108
+ gr.Textbox(label="Diseases or Conditions"),
109
+ gr.Textbox(label="Medications"),
110
+ gr.Textbox(label="Interpretation"),
111
+ ],
112
+ title="ADR Severity & NER Classifier",
113
+ description="Paste an ADR description to classify severity and extract symptoms, diseases & medications.",
114
+ allow_flagging="never"
115
+ )
116
+
117
+ if __name__ == "__main__":
118
+ demo.launch()