ankitkupadhyay's picture
Update app.py
0b14df6 verified
raw
history blame contribute delete
3.82 kB
import torch
import torch.nn as nn
from transformers import ViTImageProcessor, ViTModel, BertTokenizerFast, BertModel
from PIL import Image
import gradio as gr
# Model definition and setup
class VisionLanguageModel(nn.Module):
    """Late-fusion classifier over a ViT image encoder and a BERT text encoder.

    The pooled output of each encoder is concatenated along the feature axis
    and fed to a single linear layer that emits one logit per class.

    Args:
        num_classes: Width of the classifier output. Defaults to 2
            (benign or malignant).
        vision_checkpoint: Name or path passed to ``ViTModel.from_pretrained``.
        language_checkpoint: Name or path passed to
            ``BertModel.from_pretrained``.
    """

    def __init__(
        self,
        num_classes: int = 2,
        vision_checkpoint: str = 'google/vit-base-patch16-224-in21k',
        language_checkpoint: str = 'bert-base-uncased',
    ):
        super().__init__()
        # The attribute names below become the prefixes of the state-dict
        # keys, so they must stay stable for previously saved weights to load.
        self.vision_model = ViTModel.from_pretrained(vision_checkpoint)
        self.language_model = BertModel.from_pretrained(language_checkpoint)

        # The classifier consumes both pooled vectors side by side.
        fused_width = (
            self.vision_model.config.hidden_size
            + self.language_model.config.hidden_size
        )
        self.classifier = nn.Linear(fused_width, num_classes)

    def forward(self, input_ids, attention_mask, pixel_values):
        """Return the classifier logits for a batch of (text, image) pairs.

        Args:
            input_ids: Token ids forwarded to the language encoder.
            attention_mask: Attention mask forwarded to the language encoder.
            pixel_values: Image tensor forwarded to the vision encoder.

        Returns:
            The output of the linear classifier applied to the concatenated
            pooled features (presumably shaped ``(batch, num_classes)`` --
            the exact input shapes are not fixed by this class).
        """
        image_features = self.vision_model(pixel_values=pixel_values).pooler_output
        text_features = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).pooler_output

        # Image features first, then text features, joined on dim 1.
        fused = torch.cat((image_features, text_features), dim=1)
        return self.classifier(fused)
# Instantiate the network, restore the fine-tuned weights onto the CPU, and
# switch to inference mode. `weights_only=True` asks torch.load to refuse
# arbitrary pickled objects in the checkpoint file.
model = VisionLanguageModel()
model.load_state_dict(
    torch.load(
        'best_model.pth',
        map_location=torch.device('cpu'),
        weights_only=True,
    )
)
model.eval()

# Pre-processing front-ends, loaded from the same checkpoints as the encoders.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
feature_extractor = ViTImageProcessor.from_pretrained(
    'google/vit-base-patch16-224-in21k'
)
def predict(image, text_input):
    """Classify a single lesion image together with its clinical note.

    Args:
        image: Image handed to the module-level ``feature_extractor``.
        text_input: Clinical text handed to the module-level ``tokenizer``;
            it is truncated or padded to exactly 256 tokens.

    Returns:
        The index of the largest logit as a Python int
        (1 for Malignant, 0 for Benign).
    """
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    tokens = tokenizer(
        text_input,
        add_special_tokens=True,
        max_length=256,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(
            input_ids=tokens['input_ids'],
            attention_mask=tokens['attention_mask'],
            pixel_values=pixel_values,
        )

    predicted_class = torch.max(logits, dim=1).indices
    return predicted_class.item()  # 1 for Malignant, 0 for Benign
# Enhanced UI with black text
# Page layout: inputs on the left, two result cards on the right. The card
# matching the predicted label receives the extra "correct" CSS class, which
# the stylesheet below renders with a light-green background.
# NOTE(review): the nesting was reconstructed from a copy of the file that had
# lost its indentation; the handler and button are assumed to sit directly
# under gr.Blocks -- confirm against the deployed layout.
with gr.Blocks(css="""
body {
color: black;
}
.benign, .malignant {
background-color: white;
border: 1px solid lightgray;
padding: 10px;
border-radius: 5px;
color: black;
}
.benign.correct, .malignant.correct {
background-color: lightgreen;
color: black;
}
""") as demo:
    gr.Markdown(
        """
# 🩺 SKIN LESION CLASSIFICATION
Upload an image of a skin lesion and provide clinical details to get a prediction of benign or malignant.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Skin Lesion Image")
            text_input = gr.Textbox(label="Clinical Information (e.g., patient age, symptoms)")
        with gr.Column(scale=1):
            gr.Markdown("## PREDICTION RESULTS")
            benign_output = gr.HTML("<div class='benign'>Benign</div>")
            malignant_output = gr.HTML("<div class='malignant'>Malignant</div>")

    def display_prediction(image, text_input):
        """Run the classifier and return the HTML for both result cards.

        Args:
            image: The uploaded image, or None when nothing was uploaded.
            text_input: The clinical free-text entered by the user.

        Returns:
            A ``(benign_html, malignant_html)`` tuple; the card for the
            predicted class carries the additional ``correct`` CSS class.

        Raises:
            gr.Error: If no image has been uploaded.
        """
        # Without this guard a missing upload reaches the image processor as
        # None and fails with an error that means nothing to the end user.
        if image is None:
            raise gr.Error("Please upload a skin lesion image before requesting a prediction.")

        # Treat a missing text value as an empty clinical note.
        prediction = predict(image, text_input or "")

        benign_mark = " correct" if prediction == 0 else ""
        malignant_mark = " correct" if prediction == 1 else ""
        benign_html = f"<div class='benign{benign_mark}'>Benign</div>"
        malignant_html = f"<div class='malignant{malignant_mark}'>Malignant</div>"
        return benign_html, malignant_html

    # Submit button and prediction outputs
    submit_btn = gr.Button("Get Prediction")
    submit_btn.click(
        display_prediction,
        inputs=[image_input, text_input],
        outputs=[benign_output, malignant_output],
    )

demo.launch()