File size: 944 Bytes
ab8b628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from PIL import Image
from .preprocess import preprocess_image
from .utils import load_model


def predict_with_model(model, inputs):
    """Return the index of the highest-scoring class for ``inputs``.

    Args:
        model: A callable torch module whose output object exposes a
            ``.logits`` tensor.
        inputs: Mapping of keyword arguments forwarded verbatim to ``model``.

    Returns:
        The argmax over the last dimension of the logits, as a Python number.

    Note:
        ``model`` is left in evaluation mode after this call. ``.item()`` only
        accepts a single-element tensor, so the argmax must contain exactly one
        element (in practice a batch of one). Otherwise torch raises.
    """
    # eval() changes the behaviour of layers such as dropout / batch norm
    # for inference; the mode switch persists after we return.
    model.eval()
    # No autograd graph is needed for a pure forward pass.
    with torch.no_grad():
        scores = model(**inputs).logits
        best_index = torch.argmax(scores, dim=-1)
    return best_index.item()


def predict(image_path):
    """Load an image, preprocess it, run the model and return the predicted class.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The predicted class index produced by ``predict_with_model``.
    """
    # Open the file in a context manager so the underlying handle is released
    # deterministically. ``convert`` returns a new, fully loaded image, so it
    # stays valid after the source image is closed.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    inputs = preprocess_image(image)

    # NOTE(review): load_model() is invoked on every call; confirm whether it
    # caches the model, otherwise repeated predictions reload weights each time.
    model = load_model()

    # Move every preprocessed value onto the model's device. This assumes
    # preprocess_image returns a mapping whose values all support ``.to``
    # (e.g. tensors), and that the model exposes a ``.device`` attribute
    # (as Hugging Face models do). TODO confirm for the model from load_model.
    device = model.device
    inputs = {key: tensor.to(device) for key, tensor in inputs.items()}

    return predict_with_model(model, inputs)