# MobileNet_Fire / app.py
# Last change: commit b291d86 ("Update app.py") by Suhani-2407; file size 2.88 kB.
import os
import numpy as np
import tensorflow as tf
from PIL import Image
from io import BytesIO
import base64
# Load the Keras model once, at import time; predict_image() reuses this
# module-level object for every call. The path is relative, so it is resolved
# against the current working directory of the process.
model = tf.keras.models.load_model(filepath="MobileNet_model.h5")
# Output index -> class name. predict_image() looks labels up by the position of
# each score in the model's output vector, so this order must match training.
class_labels = dict(enumerate(("Fake", "Low", "Medium", "High")))
def preprocess_image(image, target_size=(128, 128)):
    """Turn a PIL image into a model-ready batch containing one image.

    Args:
        image: PIL image in any mode. Non-RGB modes (e.g. "L", "P", "LA",
            "RGBA") are converted to RGB so the result always has 3 channels.
        target_size: ``(width, height)`` passed to ``Image.resize``; the
            default is the 128x128 size this script has always used.

    Returns:
        ``float32`` array of shape ``(1, height, width, 3)`` with values
        scaled from the 8-bit RGB range into [0, 1].
    """
    # Without this, a grayscale image yields a (1, H, W) array and an RGBA
    # image a 4-channel one, neither of which fits a 3-channel model input.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize(target_size)
    # float32 is Keras' default float type; uint8 / 255.0 on its own would
    # produce float64 and double the memory of the batch.
    img_array = np.asarray(image, dtype=np.float32) / 255.0
    # Add the leading batch dimension expected by model.predict.
    return np.expand_dims(img_array, axis=0)
def predict_image(image):
    """Classify a single PIL image with the module-level model.

    Args:
        image: PIL image; it is resized and scaled by ``preprocess_image``.

    Returns:
        dict with:
          - ``"predicted_class"``: label of the highest-scoring output,
          - ``"confidence"``: that output's score as a Python float,
          - ``"class_probabilities"``: label -> score for every output.
        (Whether the scores are normalized probabilities depends on the
        model's final activation — presumably softmax; not visible here.)
    """
    img_array = preprocess_image(image)
    # verbose=0 silences the Keras progress bar that predict() would
    # otherwise print to stdout on every single request.
    scores = model.predict(img_array, verbose=0)[0]  # batch of one -> first row
    # Convert the NumPy index to a plain int before using it as a dict key.
    best_idx = int(np.argmax(scores))
    return {
        "predicted_class": class_labels[best_idx],
        "confidence": float(scores[best_idx]),
        "class_probabilities": {
            class_labels[i]: float(score) for i, score in enumerate(scores)
        },
    }
def _load_image(data, timeout=10.0):
    """Decode one supported input form into a PIL image.

    Args:
        data: a ``PIL.Image.Image``; raw image bytes (``bytes``,
            ``bytearray`` or ``memoryview``); or a string holding an existing
            file path, an ``http://``/``https://`` URL, a ``data:image...``
            URI, or a bare base64 payload.
        timeout: seconds to wait for a URL fetch before giving up.

    Returns:
        The decoded PIL image, in whatever mode the source provides.

    Raises:
        TypeError: ``data`` is of an unsupported type.
        ValueError: ``data`` is a string matching none of the forms above.
    """
    if isinstance(data, Image.Image):
        return data
    if isinstance(data, (bytes, bytearray, memoryview)):
        return Image.open(BytesIO(bytes(data)))
    if not isinstance(data, str):
        raise TypeError(f"Unsupported image input type: {type(data).__name__}")
    if not data:
        raise ValueError("Empty image input")
    if os.path.isfile(data):
        # Read eagerly so the file is closed here; Image.open on a path keeps
        # the file handle open until the image data is actually loaded.
        with open(data, "rb") as fh:
            return Image.open(BytesIO(fh.read()))
    if data.startswith(("http://", "https://")):
        from urllib.request import urlopen

        # NOTE(review): fetching a caller-supplied URL server-side can be abused
        # (SSRF); consider restricting allowed hosts if this endpoint is public.
        with urlopen(data, timeout=timeout) as response:
            return Image.open(BytesIO(response.read()))
    if data.startswith("data:image"):
        _, sep, payload = data.partition(",")
        if not sep:
            raise ValueError("Malformed data URI: no ',' before the base64 payload")
        return Image.open(BytesIO(base64.b64decode(payload)))
    # Bare base64 string. Strip whitespace (e.g. line wrapping) so strict
    # validation only rejects strings that are genuinely not base64.
    try:
        image_bytes = base64.b64decode("".join(data.split()), validate=True)
    except ValueError as err:  # binascii.Error is a subclass of ValueError
        raise ValueError(
            "Unrecognized image input: expected an existing file path, an "
            "http(s) URL, a data:image URI, or a base64 string"
        ) from err
    return Image.open(BytesIO(image_bytes))


def inference(data):
    """
    Inference function for Hugging Face API

    data can be:
    - File path (string)
    - URL string (http:// or https://)
    - Base64 encoded image (a "data:image...;base64," URI or bare base64)
    - Raw image bytes (bytes, bytearray or memoryview)
    - A PIL image
    - Dict with an "image" key containing any of the above

    Returns:
        The prediction dict produced by ``predict_image``.

    Raises:
        TypeError: the input (after unwrapping a dict) has an unsupported type.
        ValueError: a string input matches none of the supported forms.
    """
    # Unwrap the {"image": ...} request form.
    if isinstance(data, dict) and "image" in data:
        data = data["image"]
    image = _load_image(data)
    # The model takes 3-channel input: convert every non-RGB mode (RGBA, LA,
    # L, P, CMYK, ...), not only RGBA.
    if image.mode != "RGB":
        image = image.convert("RGB")
    return predict_image(image)
# For local testing: run this script with an optional image path argument,
# e.g. `python app.py some/image.jpg`.
if __name__ == "__main__":
    import sys

    # Use the command-line path if given, otherwise the placeholder path.
    test_image_path = sys.argv[1] if len(sys.argv) > 1 else "path/to/test/image.jpg"
    # isfile (not exists) matches how inference() recognizes path inputs.
    if os.path.isfile(test_image_path):
        result = inference(test_image_path)
        print(f"Predicted class: {result['predicted_class']}")
        print(f"Confidence: {result['confidence']:.4f}")
    else:
        # Previously a missing image made the script exit silently.
        print(f"Test image not found: {test_image_path}", file=sys.stderr)