import re

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

# 1) Severity classification model (binary: 0 = not severe, 1 = severe)
model = AutoModelForSequenceClassification.from_pretrained("calerio-uva/roberta-adr-model")
tokenizer = AutoTokenizer.from_pretrained("calerio-uva/roberta-adr-model")

# 2) Unified biomedical NER pipeline; "simple" aggregation merges word-piece
#    tokens into whole-word entities with character offsets.
ner = pipeline(
    "ner",
    model="d4data/biomedical-ner-all",
    tokenizer="d4data/biomedical-ner-all",
    aggregation_strategy="simple",
)

# 3) Tight tag sets — entity_group names from the NER model, lowercased.
SYMPTOM_TAGS = {"sign_symptom", "symptom"}
DISEASE_TAGS = {"disease_disorder"}
MED_TAGS = {"medication", "administration", "therapeutic_procedure"}

# Post-processing thresholds (were magic numbers inline).
MIN_TOKEN_LEN = 3     # drop extracted strings shorter than this
MIN_SPAN_SCORE = 0.6  # drop NER spans below this confidence


def dedupe_and_filter(tokens):
    """Drop tokens shorter than MIN_TOKEN_LEN chars and case-insensitively
    dedupe, preserving first-seen order and original casing.

    Args:
        tokens: iterable of extracted entity strings.
    Returns:
        list[str]: filtered, deduplicated tokens.
    """
    seen, out = set(), []
    for tok in tokens:
        w = tok.strip()
        if len(w) < MIN_TOKEN_LEN:
            continue
        lw = w.lower()
        if lw not in seen:
            seen.add(lw)
            out.append(w)
    return out


def _merge_spans(ents):
    """Merge overlapping/abutting entity spans of the same group into one span.

    Spans are processed in offset order (sorted defensively; the pipeline
    normally emits them in order already). The merged span keeps the max
    end offset and the max confidence of its members.

    Args:
        ents: list of dicts from the NER pipeline with keys
              "entity_group", "start", "end", and optionally "score".
    Returns:
        list[dict]: spans with keys "group", "start", "end", "score".
    """
    spans = []
    for ent in sorted(ents, key=lambda e: e["start"]):
        grp = ent["entity_group"].lower()
        start, end = ent["start"], ent["end"]
        score = ent.get("score", 1.0)
        if spans and spans[-1]["group"] == grp and start <= spans[-1]["end"]:
            # Same group and overlapping/touching the previous span: merge.
            spans[-1]["end"] = max(spans[-1]["end"], end)
            spans[-1]["score"] = max(spans[-1]["score"], score)
        else:
            spans.append({"group": grp, "start": start, "end": end, "score": score})
    return spans


def _extend_medication_spans(spans, text):
    """Extend medication-type spans forward to the end of the current word.

    The NER model sometimes cuts a drug name mid-word; walk forward over
    alphabetic characters so the extracted slice covers the whole word.
    Mutates `spans` in place.
    """
    for s in spans:
        if s["group"] in MED_TAGS:
            en = s["end"]
            while en < len(text) and text[en].isalpha():
                en += 1
            s["end"] = en


def classify_adr(text: str):
    """Classify ADR-description severity and extract named entities.

    Args:
        text: free-text adverse-drug-reaction description.
    Returns:
        5-tuple of strings for the Gradio outputs:
        (probabilities, symptoms, diseases, medications, interpretation).
    """
    # Clean the input: strip, remove literal "nan" tokens, collapse whitespace.
    # BUG FIX: the original chained .replace(" ", " ") — a no-op; the intent
    # was clearly to collapse doubled spaces, done here via split/join.
    # BUG FIX: the original .replace("nan", "") removed "nan" as a bare
    # substring, mangling words like "banana"; a word-boundary regex removes
    # only standalone "nan" (presumably pandas NaN artifacts — TODO confirm
    # against the upstream data source).
    clean = re.sub(r"\bnan\b", "", text.strip())
    clean = " ".join(clean.split())

    # Severity probabilities from the classification head.
    inputs = tokenizer(clean, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()

    # NER, then span post-processing: merge -> extend med words -> filter.
    spans = _merge_spans(ner(clean))
    _extend_medication_spans(spans, clean)
    spans = [s for s in spans if s["score"] >= MIN_SPAN_SCORE]

    # Slice the surviving spans out of the cleaned text and bucket by group.
    tokens = [clean[s["start"]:s["end"]] for s in spans]
    symptoms = dedupe_and_filter(
        [t for t, s in zip(tokens, spans) if s["group"] in SYMPTOM_TAGS]
    )
    diseases = dedupe_and_filter(
        [t for t, s in zip(tokens, spans) if s["group"] in DISEASE_TAGS]
    )
    medications = dedupe_and_filter(
        [t for t, s in zip(tokens, spans) if s["group"] in MED_TAGS]
    )

    # Human-readable interpretation of the severe-class probability.
    if probs[1] > 0.9:
        comment = "❗ High confidence this is a severe ADR."
    elif probs[1] > 0.5:
        comment = "⚠️ Borderline case — may be severe."
    else:
        comment = "✅ Likely not severe."

    return (
        f"Not Severe (0): {probs[0]:.3f}\nSevere (1): {probs[1]:.3f}",
        "\n".join(symptoms) or "None detected",
        "\n".join(diseases) or "None detected",
        "\n".join(medications) or "None detected",
        comment,
    )


# 5) Gradio UI — one text input, five text outputs matching classify_adr's tuple.
demo = gr.Interface(
    fn=classify_adr,
    inputs=gr.Textbox(lines=4, label="ADR Description"),
    outputs=[
        gr.Textbox(label="Predicted Probabilities"),
        gr.Textbox(label="Symptoms"),
        gr.Textbox(label="Diseases or Conditions"),
        gr.Textbox(label="Medications"),
        gr.Textbox(label="Interpretation"),
    ],
    title="ADR Severity & NER Classifier",
    description="Paste an ADR description to classify severity and extract symptoms, diseases & medications.",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()