import torch
from PIL import Image

from .preprocess import preprocess_image
from .utils import load_model


def predict_with_model(model, inputs):
    """Run inference and return the predicted class index.

    Args:
        model: A callable torch model. It is switched to evaluation mode
            in place. Its output is assumed to expose a ``logits``
            attribute (Hugging Face style) -- TODO confirm for non-HF
            models.
        inputs: Mapping of keyword arguments forwarded as ``model(**inputs)``.

    Returns:
        int: ``argmax`` of the logits over the last dimension.

    Raises:
        RuntimeError: From ``Tensor.item()`` if the logits hold more than
            one sample (i.e. ``argmax`` yields more than one element).
    """
    model.eval()  # disable dropout / use running batch-norm statistics
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        outputs = model(**inputs)
    logits = outputs.logits
    # .item() requires a single-element tensor, so this assumes a batch of
    # exactly one image.
    return logits.argmax(dim=-1).item()


def _model_device(model):
    """Return the device *model* runs on, falling back gracefully.

    Uses ``model.device`` when present (e.g. Hugging Face models), otherwise
    the device of the first parameter (plain ``torch.nn.Module``), and CPU
    for a module without parameters.
    """
    device = getattr(model, "device", None)
    if device is not None:
        return device
    first_param = next(model.parameters(), None)
    if first_param is None:
        return torch.device("cpu")
    return first_param.device


def _move_to_device(inputs, device):
    """Return a new dict with every ``.to``-capable value moved to *device*.

    Values without a ``.to`` method (e.g. plain Python metadata) are passed
    through unchanged instead of raising ``AttributeError``.
    """
    return {
        key: value.to(device) if hasattr(value, "to") else value
        for key, value in inputs.items()
    }


def predict(image_path, model=None):
    """Load an image, preprocess it, run the model, and return the prediction.

    Args:
        image_path: Path or file object accepted by ``PIL.Image.open``.
        model: Optional already-loaded model. When ``None`` (the default),
            ``load_model()`` is called. Passing a model avoids reloading it
            on every call.

    Returns:
        int: The predicted class index from ``predict_with_model``.
    """
    # Image.open is lazy and keeps the underlying file open. convert() forces
    # the pixel data to load into a new image, so the handle can be closed
    # immediately afterwards instead of leaking until garbage collection.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    inputs = preprocess_image(image)

    if model is None:
        model = load_model()

    # Inputs must live on the same device as the model's weights.
    inputs = _move_to_device(inputs, _model_device(model))
    return predict_with_model(model, inputs)