# painting-assistant / classification.py
# Last commit: "Add application file" by Irina (fb1f781), 2.02 kB
import torch
import torch.nn as nn
from torchvision import transforms, models
from huggingface_hub import hf_hub_download
from PIL import Image
# Inference device: use CUDA when PyTorch reports it as available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Sizes of the two groups of outputs that share the model's final linear layer.
num_classes_type = 10
num_classes_school = 26
# Fetch the fine-tuned checkpoint from the Hugging Face Hub and get its local path.
model_path = hf_hub_download(
    repo_id="Irina1402/mobilnetv3-painting-classification",
    filename="model.pth",
)

# Build the bare MobileNetV3-Large architecture. No pretrained ImageNet weights
# are requested: load_state_dict() below runs in its default strict mode, so
# every parameter and buffer is replaced by the checkpoint anyway and a
# pretrained download would only add startup time and a network failure point.
model = models.mobilenet_v3_large(weights=None)

# Replace the stock classifier with a single head whose output concatenates the
# school scores (first num_classes_school values) and the type scores (rest).
num_features = model.classifier[0].in_features
model.classifier = nn.Sequential(
    nn.Linear(num_features, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, num_classes_school + num_classes_type),
)

# NOTE(review): torch.load() unpickles a file downloaded from the Hub. Consider
# passing weights_only=True (available in recent PyTorch releases) so that only
# tensor data can be deserialized - confirm the minimum supported torch version.
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device)

# Inference mode: disables dropout and uses the stored batch-norm statistics.
model.eval()
# Preprocessing applied to every input image before it reaches the model:
# resize to 224x224, convert to a tensor, then normalize each of the three
# channels with the mean / std values given below.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Human-readable names for the school outputs; classify_image() indexes this
# list with the position of the highest school score.
school_labels = [
    "American", "Austrian", "Belgian", "Bohemian", "Catalan",
    "Danish", "Dutch", "English", "Finnish", "Flemish",
    "French", "German", "Greek", "Hungarian", "Irish",
    "Italian", "Netherlandish", "Norwegian", "Other", "Polish",
    "Portuguese", "Russian", "Scottish", "Spanish", "Swedish",
    "Swiss",
]

# Human-readable names for the type outputs, indexed the same way.
type_labels = [
    "genre",
    "historical",
    "interior",
    "landscape",
    "mythological",
    "other",
    "portrait",
    "religious",
    "still-life",
    "study",
]
def classify_image(image: Image.Image):
    """Classify the uploaded image and return type and school predictions.

    The image is converted to RGB, preprocessed with the module-level
    ``transform`` and passed through ``model`` as a batch of one. The first
    ``num_classes_school`` output values are treated as school scores and the
    remaining values as type scores; the highest-scoring index in each group
    selects the label.

    Args:
        image: A PIL image in any mode. It is converted to three-channel RGB
            before preprocessing, so RGBA, grayscale and palette images are
            accepted.

    Returns:
        A dict with two string entries: ``"school"`` (an item of
        ``school_labels``) and ``"type"`` (an item of ``type_labels``).
    """
    # The normalization step is configured for exactly three channels; without
    # this conversion RGBA / grayscale / palette inputs would raise there.
    rgb_image = image.convert("RGB")

    # unsqueeze(0) adds the leading batch dimension.
    input_tensor = transform(rgb_image).unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(input_tensor)

    # Split the shared output into its school part and its type part.
    school_output = output[:, :num_classes_school]
    type_output = output[:, num_classes_school:]

    # The batch holds a single image, so each argmax over the class axis
    # yields one index.
    school_prediction = torch.argmax(school_output, dim=1).item()
    type_prediction = torch.argmax(type_output, dim=1).item()

    return {
        "school": school_labels[school_prediction],
        "type": type_labels[type_prediction],
    }