initial add files commit
- app.py +210 -0
- model/yolo11m_affectnet_best.pt +3 -0
- requirements.txt +8 -0
app.py
ADDED
@@ -0,0 +1,210 @@
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from ultralytics import YOLO
from PIL import Image
import cv2
import mediapipe as mp
import numpy as np
import io
import tempfile
import os
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify your allowed domains
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Health check endpoint
@app.get("/")
def root():
    """
    Health check endpoint to verify that the API is running.
    """
    return {"message": "YOLO11 Emotion Detection API is live!"}

# Load the custom YOLO emotion detection model
try:
    emotion_model = YOLO("model/yolo11m_affectnet_best.pt")
except Exception as e:
    raise Exception(f"Error loading emotion model: {e}")


def detect_emotions(image):
    """
    Given an OpenCV BGR image, detect faces using Mediapipe and perform emotion
    detection on each face crop using the YOLO model.
    Returns a list of detections with bounding box and emotion details.
    """
    height, width, _ = image.shape
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    detections_list = []

    mp_face_detection = mp.solutions.face_detection
    with mp_face_detection.FaceDetection(model_selection=0, min_detection_confidence=0.5) as face_detection:
        results = face_detection.process(image_rgb)

        if results.detections:
            for detection in results.detections:
                # Convert the relative Mediapipe box to absolute pixel coordinates
                bbox = detection.location_data.relative_bounding_box
                x_min = int(bbox.xmin * width)
                y_min = int(bbox.ymin * height)
                box_width = int(bbox.width * width)
                box_height = int(bbox.height * height)
                x_max = x_min + box_width
                y_max = y_min + box_height

                # Clamp the box to the image bounds
                x_min = max(0, x_min)
                y_min = max(0, y_min)
                x_max = min(width, x_max)
                y_max = min(height, y_max)

                face_crop = image[y_min:y_max, x_min:x_max]
                if face_crop.size == 0:
                    continue
                face_crop_rgb = cv2.cvtColor(face_crop, cv2.COLOR_BGR2RGB)
                face_pil = Image.fromarray(face_crop_rgb)

                # Classify the emotion on the cropped face
                emotion_results = emotion_model.predict(source=face_pil, conf=0.5)
                if len(emotion_results) > 0 and len(emotion_results[0].boxes) > 0:
                    box_detect = emotion_results[0].boxes[0]
                    emotion_label = emotion_results[0].names[int(box_detect.cls)]
                    confidence = float(box_detect.conf)
                else:
                    emotion_label = "N/A"
                    confidence = 0.0

                detection_info = {
                    "bbox": {
                        "x_min": x_min,
                        "y_min": y_min,
                        "x_max": x_max,
                        "y_max": y_max
                    },
                    "emotion": emotion_label,
                    "confidence": confidence
                }
                detections_list.append(detection_info)
    return detections_list


@app.post("/predict_frame")
async def predict_frame(file: UploadFile = File(...)):
    """
    Accept an image file, run face and emotion detection, annotate the image with
    bounding boxes and emotion labels, and return the annotated image as PNG.
    """
    if not file.filename.lower().endswith(('.jpg', '.jpeg', '.png')):
        raise HTTPException(status_code=400, detail="Invalid file format. Only JPG, JPEG, and PNG are allowed.")
    try:
        contents = await file.read()
        nparr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if image is None:
            raise HTTPException(status_code=400, detail="Invalid image file.")

        detections = detect_emotions(image)

        # Draw each detection onto the image
        for det in detections:
            bbox = det["bbox"]
            label_text = f'{det["emotion"]} ({det["confidence"]:.2f})'
            cv2.rectangle(image, (bbox["x_min"], bbox["y_min"]), (bbox["x_max"], bbox["y_max"]), (0, 255, 0), 2)
            cv2.putText(image, label_text, (bbox["x_min"], bbox["y_min"] - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

        is_success, im_buf_arr = cv2.imencode(".png", image)
        if not is_success:
            raise HTTPException(status_code=500, detail="Error encoding image.")
        byte_im = im_buf_arr.tobytes()
        buf = io.BytesIO(byte_im)
        buf.seek(0)
        return StreamingResponse(buf, media_type="image/png")
    except HTTPException:
        # Re-raise HTTPExceptions so their original status codes are preserved
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
    finally:
        await file.close()


@app.post("/predict_emotion")
async def predict_emotion(file: UploadFile = File(...)):
    """
    Accept an image file, run face and emotion detection, and return the results as JSON.
    The JSON response includes a list of detections with bounding box coordinates,
    emotion label, and confidence score.
    """
    if not file.filename.lower().endswith(('.jpg', '.jpeg', '.png')):
        raise HTTPException(status_code=400, detail="Invalid file format. Only JPG, JPEG, and PNG are allowed.")
    try:
        contents = await file.read()
        nparr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if image is None:
            raise HTTPException(status_code=400, detail="Invalid image file.")
        detections = detect_emotions(image)
        return {"detections": detections}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
    finally:
        await file.close()


@app.post("/predict_video")
async def predict_video(file: UploadFile = File(...)):
    """
    Accept a video file, process it frame-by-frame with face and emotion detection,
    annotate each frame with bounding boxes and emotion labels, and return the
    annotated video as an MP4 file.
    """
    if not file.filename.lower().endswith(('.mp4', '.avi', '.mov')):
        raise HTTPException(status_code=400, detail="Invalid file format. Only MP4, AVI, and MOV are allowed.")
    try:
        contents = await file.read()
        # Persist the upload to a temporary file so OpenCV can read it
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_input:
            tmp_input.write(contents)
            input_video_path = tmp_input.name

        cap = cv2.VideoCapture(input_video_path)
        if not cap.isOpened():
            raise HTTPException(status_code=400, detail="Could not open video file.")

        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_output:
            output_video_path = tmp_output.name

        out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            detections = detect_emotions(frame)
            for det in detections:
                bbox = det["bbox"]
                label_text = f'{det["emotion"]} ({det["confidence"]:.2f})'
                cv2.rectangle(frame, (bbox["x_min"], bbox["y_min"]), (bbox["x_max"], bbox["y_max"]), (0, 255, 0), 2)
                cv2.putText(frame, label_text, (bbox["x_min"], bbox["y_min"] - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
            out.write(frame)

        cap.release()
        out.release()

        with open(output_video_path, "rb") as f:
            video_bytes = f.read()

        os.remove(input_video_path)
        os.remove(output_video_path)

        return StreamingResponse(io.BytesIO(video_bytes), media_type="video/mp4")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing video: {str(e)}")
    finally:
        await file.close()


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
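
For a quick check of the new endpoints, here is a minimal client sketch using the requests library. The base URL (a local run on port 7860), sample.jpg, and annotated.png are illustrative placeholders, not files in this commit.

# Minimal client sketch for the two image endpoints.
# Assumptions: the API is reachable at http://localhost:7860 and
# sample.jpg is any local test image you supply yourself.
import requests

BASE_URL = "http://localhost:7860"

# /predict_emotion returns JSON: bounding boxes, emotion labels, confidences
with open("sample.jpg", "rb") as f:
    resp = requests.post(f"{BASE_URL}/predict_emotion",
                         files={"file": ("sample.jpg", f, "image/jpeg")})
print(resp.json())

# /predict_frame returns the annotated image as PNG bytes
with open("sample.jpg", "rb") as f:
    resp = requests.post(f"{BASE_URL}/predict_frame",
                         files={"file": ("sample.jpg", f, "image/jpeg")})
with open("annotated.png", "wb") as out:
    out.write(resp.content)
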
model/yolo11m_affectnet_best.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9d9103ead25cfa9307e1580c95d19d744251b3cbc03b8945a7c74b239309d105
size 40481573
requirements.txt
ADDED
@@ -0,0 +1,8 @@
fastapi
uvicorn
python-multipart
Pillow
opencv-python
mediapipe
numpy
ultralytics
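
The dependency list is unpinned. Outside of Spaces, running pip install -r requirements.txt followed by python app.py should bring the API up on port 7860, since app.py starts uvicorn itself in its __main__ block; opencv-python and mediapipe are the heavyweight installs here.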