# Spaces:
# Sleeping
# Sleeping
import torch | |
import torch.nn as nn | |
from torchvision import transforms, models | |
from huggingface_hub import hf_hub_download | |
from PIL import Image | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Size of each half of the single output head: classify_image reads the first
# `num_classes_school` logits as school scores and the remaining
# `num_classes_type` logits as painting-type scores.
num_classes_school = 26
num_classes_type = 10

# Fine-tuned checkpoint (a state dict, since it is fed to load_state_dict below)
# downloaded from the Hugging Face Hub.
model_path = hf_hub_download(
    repo_id="Irina1402/mobilnetv3-painting-classification",
    filename="model.pth",
)

# Build the architecture only. Requesting pretrained ImageNet weights here is
# wasted work: the strict load_state_dict below replaces every parameter and
# buffer, so those weights would be downloaded and then immediately discarded.
model = models.mobilenet_v3_large(weights=None)
num_features = model.classifier[0].in_features
# Replace the stock classifier with the custom head the checkpoint was saved
# with; layer shapes must match exactly for the strict load to succeed.
model.classifier = nn.Sequential(
    nn.Linear(num_features, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, num_classes_school + num_classes_type),
)

# NOTE(security): torch.load unpickles the file. On torch < 2.6 the default is
# weights_only=False, so a tampered checkpoint on the Hub could execute code at
# load time. Consider passing weights_only=True once the deployed torch
# version is confirmed to support it (>= 1.13).
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)
# Inference mode: disables the Dropout above (and switches any BatchNorm layers
# to their running statistics).
model.eval()
# Standard ImageNet per-channel mean / std (RGB order). Presumably the same
# normalisation used when the checkpoint was trained — TODO confirm.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Inference preprocessing applied to every PIL image before the model:
#   1. resize to exactly 224x224 (aspect ratio is not preserved),
#   2. convert to a float tensor,
#   3. normalise each channel with the statistics above.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Index -> label tables for the two halves of the model output. Position i in
# each list names class index i of the corresponding head, so the order must
# match the checkpoint's training labels. Both lists happen to be alphabetical
# (presumably produced by the training label encoder — TODO confirm).
school_labels = [
    "American",
    "Austrian",
    "Belgian",
    "Bohemian",
    "Catalan",
    "Danish",
    "Dutch",
    "English",
    "Finnish",
    "Flemish",
    "French",
    "German",
    "Greek",
    "Hungarian",
    "Irish",
    "Italian",
    "Netherlandish",
    "Norwegian",
    "Other",
    "Polish",
    "Portuguese",
    "Russian",
    "Scottish",
    "Spanish",
    "Swedish",
    "Swiss",
]
type_labels = [
    "genre",
    "historical",
    "interior",
    "landscape",
    "mythological",
    "other",
    "portrait",
    "religious",
    "still-life",
    "study",
]
def classify_image(image: Image.Image):
    """Classify a painting image by school and by type.

    Args:
        image: PIL image to classify. It is converted to RGB first because the
            preprocessing normalises exactly three channels. Grayscale ("L"),
            palette ("P"), RGBA or CMYK inputs would otherwise produce a
            tensor whose channel count does not match ``transforms.Normalize``
            and raise.

    Returns:
        dict with keys ``"school"`` (an entry of ``school_labels``) and
        ``"type"`` (an entry of ``type_labels``).
    """
    rgb_image = image.convert("RGB")
    # Add a leading batch dimension of size 1 before running the model.
    input_tensor = transform(rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
    # One output head: the first num_classes_school logits score the school,
    # the remaining logits score the painting type.
    school_output = output[:, :num_classes_school]
    type_output = output[:, num_classes_school:]
    # Reduce per sample (dim=1) instead of over the flattened tensor, so the
    # index is always relative to this head's own label list.
    school_prediction = torch.argmax(school_output, dim=1).item()
    type_prediction = torch.argmax(type_output, dim=1).item()
    return {
        "school": school_labels[school_prediction],
        "type": type_labels[type_prediction],
    }