"""FastAPI service exposing a dynamically quantized image classifier.

At import time the module loads the processor from the local
``quantized_model`` directory, builds the model architecture from the
``dima806/deepfake_vs_real_image_detection`` config, applies dynamic int8
quantization to its ``torch.nn.Linear`` layers and loads the saved
quantized weights. The ``/predict`` endpoint then classifies one uploaded
image per request.
"""

import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import AutoConfig, AutoModelForImageClassification, AutoProcessor

app = FastAPI()

# Load processor and config
processor = AutoProcessor.from_pretrained("quantized_model")
config = AutoConfig.from_pretrained("dima806/deepfake_vs_real_image_detection")

# Load model architecture and quantized weights.
# The architecture is quantized *before* loading the state dict so that the
# module structure matches the keys stored in the quantized checkpoint.
model = AutoModelForImageClassification.from_config(config)
model_quantized = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
# NOTE(review): torch.load unpickles the checkpoint; only point this at a
# trusted file.
model_quantized.load_state_dict(
    torch.load("quantized_model/model_quantized.pt", map_location="cpu")
)
model_quantized.eval()


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and report the predicted label.

    Args:
        file: The uploaded image file (multipart form field ``file``).

    Returns:
        A dict with ``"label"`` (looked up in the model config's
        ``id2label`` mapping) and ``"confidence"`` (softmax probability of
        the predicted class).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await file.read()

    # The upload is untrusted: an undecodable payload is a client error,
    # not a server failure. Pillow's UnidentifiedImageError subclasses OSError.
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except OSError as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        logits = model_quantized(**inputs).logits

    predicted_idx = logits.argmax(-1).item()
    # Index [0] selects the single image in the batch.
    confidence = logits.softmax(-1)[0][predicted_idx].item()

    label = model_quantized.config.id2label[predicted_idx]
    return {"label": label, "confidence": confidence}