from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from ultralytics import YOLO
from PIL import Image
import cv2
import mediapipe as mp
import numpy as np
import io
import tempfile
import os

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify your allowed domains
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Health Check Endpoint
@app.get("/")  # route path assumed; adjust if the Space exposes it elsewhere
def root():
    """
    Health check endpoint to verify that the API is running.
    """
    return {"message": "YOLO11 Emotion Detection API is live!"}
# Load your custom YOLO emotion detection model
try:
    emotion_model = YOLO("model/yolo11m_affectnet_best.pt")
except Exception as e:
    raise Exception(f"Error loading emotion model: {e}")
def detect_emotions(image):
    """
    Given an OpenCV BGR image, detect faces using Mediapipe and perform emotion detection
    on each face crop using the YOLO model.
    Returns a list of detections with bounding box and emotion details.
    """
    height, width, _ = image.shape
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    detections_list = []

    mp_face_detection = mp.solutions.face_detection
    with mp_face_detection.FaceDetection(model_selection=0, min_detection_confidence=0.5) as face_detection:
        results = face_detection.process(image_rgb)
        if results.detections:
            for detection in results.detections:
                # Convert Mediapipe's relative bounding box to absolute pixel coordinates
                bbox = detection.location_data.relative_bounding_box
                x_min = int(bbox.xmin * width)
                y_min = int(bbox.ymin * height)
                box_width = int(bbox.width * width)
                box_height = int(bbox.height * height)
                x_max = x_min + box_width
                y_max = y_min + box_height

                # Clamp the box to the image bounds
                x_min = max(0, x_min)
                y_min = max(0, y_min)
                x_max = min(width, x_max)
                y_max = min(height, y_max)

                face_crop = image[y_min:y_max, x_min:x_max]
                if face_crop.size == 0:
                    continue

                # Run the YOLO emotion model on the RGB face crop
                face_crop_rgb = cv2.cvtColor(face_crop, cv2.COLOR_BGR2RGB)
                face_pil = Image.fromarray(face_crop_rgb)
                emotion_results = emotion_model.predict(source=face_pil, conf=0.5)

                if len(emotion_results) > 0 and len(emotion_results[0].boxes) > 0:
                    box_detect = emotion_results[0].boxes[0]
                    emotion_label = emotion_results[0].names[int(box_detect.cls)]
                    confidence = float(box_detect.conf)
                else:
                    emotion_label = "N/A"
                    confidence = 0.0

                detection_info = {
                    "bbox": {
                        "x_min": x_min,
                        "y_min": y_min,
                        "x_max": x_max,
                        "y_max": y_max
                    },
                    "emotion": emotion_label,
                    "confidence": confidence
                }
                detections_list.append(detection_info)
    return detections_list
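
# Illustrative, hedged usage sketch: detect_emotions can also be called directly on a local
# image, outside the API. The helper and file name below are hypothetical, not part of the API.
def _example_detect_emotions(image_path="sample.jpg"):
    img = cv2.imread(image_path)  # BGR image, as expected by detect_emotions
    if img is None:
        raise ValueError(f"Could not read {image_path}")
    for det in detect_emotions(img):
        print(det["emotion"], det["confidence"], det["bbox"])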
@app.post("/predict_frame")  # endpoint path assumed from the function name
async def predict_frame(file: UploadFile = File(...)):
    """
    Accept an image file, run face and emotion detection, annotate the image with bounding boxes
    and emotion labels, and return the annotated image as PNG.
    """
    if not file.filename.lower().endswith(('.jpg', '.jpeg', '.png')):
        raise HTTPException(status_code=400, detail="Invalid file format. Only JPG, JPEG, and PNG are allowed.")
    try:
        contents = await file.read()
        nparr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if image is None:
            raise HTTPException(status_code=400, detail="Invalid image file.")

        detections = detect_emotions(image)
        for det in detections:
            bbox = det["bbox"]
            label_text = f'{det["emotion"]} ({det["confidence"]:.2f})'
            cv2.rectangle(image, (bbox["x_min"], bbox["y_min"]), (bbox["x_max"], bbox["y_max"]), (0, 255, 0), 2)
            cv2.putText(image, label_text, (bbox["x_min"], bbox["y_min"] - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

        is_success, im_buf_arr = cv2.imencode(".png", image)
        if not is_success:
            raise HTTPException(status_code=500, detail="Error encoding image.")
        buf = io.BytesIO(im_buf_arr.tobytes())
        buf.seek(0)
        return StreamingResponse(buf, media_type="image/png")
    except HTTPException:
        # Re-raise HTTP errors unchanged so 400 responses are not converted into 500s below
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
    finally:
        await file.close()
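
# Illustrative, hedged client sketch for the image-annotation endpoint above. It assumes the
# `requests` package, a server at http://localhost:7860, the "/predict_frame" path used above,
# and hypothetical file names.
def _example_predict_frame_client(image_path="face.jpg", out_path="annotated.png"):
    import requests  # extra dependency, only needed for this example

    with open(image_path, "rb") as fh:
        resp = requests.post("http://localhost:7860/predict_frame", files={"file": fh})
    resp.raise_for_status()
    with open(out_path, "wb") as out:
        out.write(resp.content)  # annotated PNG bytes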
@app.post("/predict_emotion")  # endpoint path assumed from the function name
async def predict_emotion(file: UploadFile = File(...)):
    """
    Accept an image file, run face and emotion detection, and return the results as JSON.
    The JSON response includes a list of detections with bounding box coordinates, emotion label, and confidence score.
    """
    if not file.filename.lower().endswith(('.jpg', '.jpeg', '.png')):
        raise HTTPException(status_code=400, detail="Invalid file format. Only JPG, JPEG, and PNG are allowed.")
    try:
        contents = await file.read()
        nparr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if image is None:
            raise HTTPException(status_code=400, detail="Invalid image file.")
        detections = detect_emotions(image)
        return {"detections": detections}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
    finally:
        await file.close()
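
# Illustrative, hedged client sketch for the JSON endpoint above; the host, "/predict_emotion"
# path, and file name follow the assumptions noted earlier and are hypothetical.
def _example_predict_emotion_client(image_path="face.jpg"):
    import requests  # extra dependency, only needed for this example

    with open(image_path, "rb") as fh:
        resp = requests.post("http://localhost:7860/predict_emotion", files={"file": fh})
    resp.raise_for_status()
    for det in resp.json()["detections"]:
        print(det["emotion"], det["confidence"], det["bbox"])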
@app.post("/predict_video")  # endpoint path assumed from the function name
async def predict_video(file: UploadFile = File(...)):
    """
    Accept a video file, process it frame-by-frame with face and emotion detection,
    annotate each frame with bounding boxes and emotion labels, and return the annotated
    video as an MP4 file.
    """
    if not file.filename.lower().endswith(('.mp4', '.avi', '.mov')):
        raise HTTPException(status_code=400, detail="Invalid file format. Only MP4, AVI, and MOV are allowed.")
    try:
        contents = await file.read()
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_input:
            tmp_input.write(contents)
            input_video_path = tmp_input.name

        cap = cv2.VideoCapture(input_video_path)
        if not cap.isOpened():
            raise HTTPException(status_code=400, detail="Could not open video file.")
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")

        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_output:
            output_video_path = tmp_output.name
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            detections = detect_emotions(frame)
            for det in detections:
                bbox = det["bbox"]
                label_text = f'{det["emotion"]} ({det["confidence"]:.2f})'
                cv2.rectangle(frame, (bbox["x_min"], bbox["y_min"]), (bbox["x_max"], bbox["y_max"]), (0, 255, 0), 2)
                cv2.putText(frame, label_text, (bbox["x_min"], bbox["y_min"] - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
            out.write(frame)

        cap.release()
        out.release()

        with open(output_video_path, "rb") as f:
            video_bytes = f.read()
        os.remove(input_video_path)
        os.remove(output_video_path)
        return StreamingResponse(io.BytesIO(video_bytes), media_type="video/mp4")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing video: {str(e)}")
    finally:
        await file.close()
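
# Illustrative, hedged client sketch for the video endpoint above; the host, "/predict_video"
# path, and file names follow the assumptions noted earlier and are hypothetical.
def _example_predict_video_client(video_path="clip.mp4", out_path="annotated.mp4"):
    import requests  # extra dependency, only needed for this example

    with open(video_path, "rb") as fh:
        resp = requests.post("http://localhost:7860/predict_video", files={"file": fh})
    resp.raise_for_status()
    with open(out_path, "wb") as out:
        out.write(resp.content)  # annotated MP4 bytes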
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)