"""FastAPI service that classifies an uploaded image of waste into a material class.

The Keras model is loaded once, at import time, from an ``hf://`` URI
(Hugging Face Hub).
"""
import io
import logging
import time

import fastapi
import numpy as np
from fastapi import File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from keras import applications, preprocessing, saving
from PIL import Image, UnidentifiedImageError

logger = logging.getLogger(__name__)

app = fastapi.FastAPI()

# Loaded once at import time; the hf:// URI is fetched from the Hugging Face Hub.
model = saving.load_model("hf://Yumeng-Liu/trash-classifier")

# Labels indexed by the model's output position.
# NOTE(review): assumed to match the label order used in training — confirm.
CLASSES = ['biological', 'cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']

# The top class is reported only if its score is strictly above this value;
# otherwise the prediction is "none".
THRESHOLD = 5.0e-1

# Spatial size (width, height) every image is resized to before inference.
INPUT_SIZE = (224, 224)


def get_prediction(img: Image.Image) -> str:
    """Classify a PIL image and return its label.

    Args:
        img: A decoded image of any size or mode.

    Returns:
        The entry of ``CLASSES`` with the highest model score if that score
        exceeds ``THRESHOLD``, otherwise the string ``"none"``.
    """
    # Normalise to 3-channel RGB. Palette, greyscale and RGBA inputs (common for
    # PNG/GIF uploads) would otherwise yield 1/2/4-channel arrays, while the
    # MobileNetV2-style pipeline presumably expects 3 channels.
    img = img.convert("RGB").resize(INPUT_SIZE)
    img_array = preprocessing.image.img_to_array(img)
    # Add a leading batch axis: (H, W, C) -> (1, H, W, C).
    img_array = np.expand_dims(img_array, axis=0)
    img_array = applications.mobilenet_v2.preprocess_input(img_array)
    # verbose=0 stops Keras from printing a progress bar on every request.
    scores = model.predict(img_array, verbose=0)[0]
    predicted_class_idx = int(np.argmax(scores))
    if scores[predicted_class_idx] > THRESHOLD:
        return CLASSES[predicted_class_idx]
    return "none"


@app.get("/")
def read_root():
    """Return a static greeting payload."""
    return {"Hello": "World"}


@app.post("/predict-image")
async def predict(received_image: UploadFile = File(...)):
    """Classify an uploaded image file.

    Returns:
        ``{"result": label}``, where ``label`` is one of ``CLASSES`` or ``"none"``.

    Raises:
        HTTPException: 400 if the upload cannot be identified as an image;
            500 for any other failure while processing it.
    """
    try:
        # Use the async read so the event loop is not blocked on the upload's file.
        contents = await received_image.read()
        try:
            image = Image.open(io.BytesIO(contents))
        except UnidentifiedImageError as e:
            raise HTTPException(
                status_code=400, detail='Uploaded file is not a valid image'
            ) from e
        logger.info(
            "Image received: format=%s size=%s mode=%s",
            image.format, image.size, image.mode,
        )
        # Decoding, resizing and model.predict are synchronous and CPU-bound.
        # Running them in a worker thread stops this coroutine from stalling
        # every other request served by the same event loop.
        prediction_result = await run_in_threadpool(get_prediction, image)
        logger.info("Prediction: %s", prediction_result)
        return {"result": prediction_result}
    except HTTPException:
        # Deliberate client errors (e.g. the 400 above) pass through unchanged.
        raise
    except Exception as e:
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail='Something went wrong') from e
    finally:
        await received_image.close()


if __name__ == "__main__":
    print("Starting app")
    # NOTE(review): this only keeps the process alive; it does not start an HTTP
    # server. The app is presumably served externally (e.g. `uvicorn module:app`)
    # — confirm.
    while True:
        time.sleep(10)