import logging
import logging.handlers
import queue
import urllib.request
from pathlib import Path
from typing import Literal

import av
import cv2
import numpy as np
import PIL
import PIL.Image  # `import PIL` alone does not load the Image submodule
import streamlit as st
from aiortc.contrib.media import MediaPlayer
from streamlit_webrtc import (
    ClientSettings,
    VideoTransformerBase,
    WebRtcMode,
    webrtc_streamer,
)

HERE = Path(__file__).parent

logger = logging.getLogger(__name__)

# Client settings shared by every demo page. `app_streaming` overrides
# "media_stream_constraints" in place for the selected media file.
WEBRTC_CLIENT_SETTINGS = ClientSettings(
    rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
    media_stream_constraints={"video": True, "audio": True},
)


# This code is based on https://github.com/streamlit/demo-self-driving/blob/230245391f2dda0cb464008195a470751c01770b/streamlit_app.py#L48  # noqa: E501
def download_file(url, download_to: Path, expected_size=None):
    """Download ``url`` to ``download_to`` while showing progress in the page.

    The download is skipped when the destination already exists and either
    its size equals ``expected_size`` or, when no expected size is given, the
    user does not press the "Download again?" button.

    Args:
        url: Address of the file to fetch.
        download_to: Destination path; parent directories are created.
        expected_size: Optional size in bytes used to validate an existing
            file. A falsy value disables the size check.

    Raises:
        Whatever ``urllib.request.urlopen`` or the file operations raise;
        errors are not swallowed here.
    """
    # Don't download the file twice.
    # (If possible, verify the download using the file length.)
    if download_to.exists():
        if expected_size:
            if download_to.stat().st_size == expected_size:
                return
        else:
            st.info(f"{url} is already downloaded.")
            if not st.button("Download again?"):
                return

    download_to.parent.mkdir(parents=True, exist_ok=True)

    # Write to a sibling ".part" file and rename on success, so a failed or
    # interrupted download never leaves a truncated file at `download_to`
    # that a later run would mistake for a finished download.
    partial_path = download_to.with_name(download_to.name + ".part")

    # These are handles to two visual elements to animate.
    weights_warning, progress_bar = None, None
    try:
        weights_warning = st.warning("Downloading %s..." % url)
        progress_bar = st.progress(0)
        # Open the URL before touching the disk so a failed request creates
        # no file at all.
        with urllib.request.urlopen(url) as response:
            # The Content-Length header may be missing (e.g. chunked
            # responses); 0 means "total size unknown".
            content_length = response.info()["Content-Length"]
            length = int(content_length) if content_length else 0
            counter = 0.0
            MEGABYTES = 2.0 ** 20.0
            with open(partial_path, "wb") as output_file:
                while True:
                    data = response.read(8192)
                    if not data:
                        break
                    counter += len(data)
                    output_file.write(data)

                    # We perform animation by overwriting the elements.
                    if length > 0:
                        weights_warning.warning(
                            "Downloading %s... (%6.2f/%6.2f MB)"
                            % (url, counter / MEGABYTES, length / MEGABYTES)
                        )
                        progress_bar.progress(min(counter / length, 1.0))
                    else:
                        # Without a total size only the running byte count
                        # can be reported.
                        weights_warning.warning(
                            "Downloading %s... (%6.2f MB)"
                            % (url, counter / MEGABYTES)
                        )
        partial_path.replace(download_to)
    # Finally, we remove these visual elements by calling .empty().
    finally:
        # After a successful rename this is a no-op; after a failure it
        # removes the incomplete download.
        partial_path.unlink(missing_ok=True)
        if weights_warning is not None:
            weights_warning.empty()
        if progress_bar is not None:
            progress_bar.empty()


def main():
    """Render the page header and mode selector, then run the chosen demo."""
    st.header("WebRTC demo")

    object_detection_page = "Real time object detection (sendrecv)"
    video_filters_page = (
        "Real time video transform with simple OpenCV filters (sendrecv)"
    )
    streaming_page = (
        "Consuming media files on server-side and streaming it to browser (recvonly)"
    )
    sendonly_page = "WebRTC is sendonly and images are shown via st.image() (sendonly)"
    loopback_page = "Simple video loopback (sendrecv)"

    # Insertion order is the order in which the modes appear in the sidebar.
    pages = {
        object_detection_page: app_object_detection,
        video_filters_page: app_video_filters,
        streaming_page: app_streaming,
        sendonly_page: app_sendonly,
        loopback_page: app_loopback,
    }

    app_mode = st.sidebar.selectbox("Choose the app mode", list(pages))
    st.subheader(app_mode)

    run_page = pages.get(app_mode)
    if run_page is not None:
        run_page()


def app_loopback():
    """ Simple video loopback """
    webrtc_streamer(
        key="loopback",
        mode=WebRtcMode.SENDRECV,
        client_settings=WEBRTC_CLIENT_SETTINGS,
        video_transformer_class=None,  # NoOp
    )


def app_video_filters():
    """ Video transforms with OpenCV """

    class OpenCVVideoTransformer(VideoTransformerBase):
        """Apply the filter named by ``self.type`` to each frame."""

        # Name of the active filter; reassigned from the radio button below.
        type: Literal["noop", "cartoon", "edges", "rotate"]

        def __init__(self) -> None:
            self.type = "noop"

        def transform(self, frame: av.VideoFrame) -> np.ndarray:
            """Return the frame as a BGR array with the selected filter applied."""
            img = frame.to_ndarray(format="bgr24")

            if self.type == "noop":
                pass
            elif self.type == "cartoon":
                # prepare color
                img_color = cv2.pyrDown(cv2.pyrDown(img))
                for _ in range(6):
                    img_color = cv2.bilateralFilter(img_color, 9, 9, 7)
                img_color = cv2.pyrUp(cv2.pyrUp(img_color))

                # prepare edges
                # The frame is decoded as bgr24, so use the BGR conversion
                # codes; RGB2GRAY would swap the red/blue luminance weights.
                img_edges = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
                img_edges = cv2.adaptiveThreshold(
                    cv2.medianBlur(img_edges, 7),
                    255,
                    cv2.ADAPTIVE_THRESH_MEAN_C,
                    cv2.THRESH_BINARY,
                    9,
                    2,
                )
                img_edges = cv2.cvtColor(img_edges, cv2.COLOR_GRAY2BGR)

                # combine color and edges
                img = cv2.bitwise_and(img_color, img_edges)
            elif self.type == "edges":
                # perform edge detection
                img = cv2.cvtColor(cv2.Canny(img, 100, 200), cv2.COLOR_GRAY2BGR)
            elif self.type == "rotate":
                # rotate image
                rows, cols, _ = img.shape
                # The angle grows with the frame timestamp, so the picture
                # keeps turning as the stream plays.
                M = cv2.getRotationMatrix2D((cols / 2, rows / 2), frame.time * 45, 1)
                img = cv2.warpAffine(img, M, (cols, rows))

            return img

    webrtc_ctx = webrtc_streamer(
        key="opencv-filter",
        mode=WebRtcMode.SENDRECV,
        client_settings=WEBRTC_CLIENT_SETTINGS,
        video_transformer_class=OpenCVVideoTransformer,
        async_transform=True,
    )

    transform_type = st.radio(
        "Select transform type", ("noop", "cartoon", "edges", "rotate")
    )
    # Push the selected filter into the transformer when one is available.
    if webrtc_ctx.video_transformer:
        webrtc_ctx.video_transformer.type = transform_type

    st.markdown(
        "This demo is based on "
        "https://github.com/aiortc/aiortc/blob/2362e6d1f0c730a0f8c387bbea76546775ad2fe8/examples/server/server.py#L34. "  # noqa: E501
        "Many thanks to the project."
    )


def app_object_detection():
    """Object detection demo with MobileNet SSD.
    This model and code are based on
    https://github.com/robmarkcole/object-detection-app
    """
    MODEL_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.caffemodel"  # noqa: E501
    MODEL_LOCAL_PATH = HERE / "./models/MobileNetSSD_deploy.caffemodel"
    PROTOTXT_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.prototxt.txt"  # noqa: E501
    PROTOTXT_LOCAL_PATH = HERE / "./models/MobileNetSSD_deploy.prototxt.txt"

    CLASSES = [
        "background",
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "pottedplant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
    ]
    # One random colour per class, used for both the box and its label.
    COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))

    download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=23147564)
    download_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=29353)

    DEFAULT_CONFIDENCE_THRESHOLD = 0.5

    class NNVideoTransformer(VideoTransformerBase):
        """Run the Caffe network on each frame and draw its detections."""

        # Detections at or below this confidence are not drawn; reassigned
        # from the slider below.
        confidence_threshold: float

        def __init__(self) -> None:
            self._net = cv2.dnn.readNetFromCaffe(
                str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH)
            )
            self.confidence_threshold = DEFAULT_CONFIDENCE_THRESHOLD

        def _annotate_image(self, image, detections):
            """Draw boxes and labels on ``image`` in place.

            Returns the same image together with the list of label strings
            for every detection above the confidence threshold.
            """
            # loop over the detections
            (h, w) = image.shape[:2]
            # Box coordinates are multiplied by the frame size to get pixels.
            scale = np.array([w, h, w, h])
            labels = []
            for i in range(detections.shape[2]):
                confidence = detections[0, 0, i, 2]

                if confidence > self.confidence_threshold:
                    # extract the index of the class label from the `detections`,
                    # then compute the (x, y)-coordinates of the bounding box for
                    # the object
                    idx = int(detections[0, 0, i, 1])
                    box = detections[0, 0, i, 3:7] * scale
                    # Native Python ints, so the drawing calls never receive
                    # NumPy scalar coordinates.
                    (startX, startY, endX, endY) = box.astype("int").tolist()

                    # display the prediction
                    label = f"{CLASSES[idx]}: {round(confidence * 100, 2)}%"
                    labels.append(label)
                    cv2.rectangle(image, (startX, startY), (endX, endY), COLORS[idx], 2)
                    # Put the label above the box unless that would run off
                    # the top of the frame.
                    y = startY - 15 if startY - 15 > 15 else startY + 15
                    cv2.putText(
                        image,
                        label,
                        (startX, y),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.5,
                        COLORS[idx],
                        2,
                    )
            return image, labels

        def transform(self, frame: av.VideoFrame) -> np.ndarray:
            """Return the frame as a BGR array annotated with detections."""
            image = frame.to_ndarray(format="bgr24")
            blob = cv2.dnn.blobFromImage(
                cv2.resize(image, (300, 300)), 0.007843, (300, 300), 127.5
            )
            self._net.setInput(blob)
            detections = self._net.forward()
            annotated_image, labels = self._annotate_image(image, detections)
            # TODO: Show labels

            return annotated_image

    webrtc_ctx = webrtc_streamer(
        key="object-detection",
        mode=WebRtcMode.SENDRECV,
        client_settings=WEBRTC_CLIENT_SETTINGS,
        video_transformer_class=NNVideoTransformer,
        async_transform=True,
    )

    confidence_threshold = st.slider(
        "Confidence threshold", 0.0, 1.0, DEFAULT_CONFIDENCE_THRESHOLD, 0.05
    )
    # Push the slider value into the transformer when one is available.
    if webrtc_ctx.video_transformer:
        webrtc_ctx.video_transformer.confidence_threshold = confidence_threshold

    st.markdown(
        "This demo uses a model and code from "
        "https://github.com/robmarkcole/object-detection-app. "
        "Many thanks to the project."
    )


def app_streaming():
    """ Media streamings """
    MEDIAFILES = {
        "big_buck_bunny_720p_2mb.mp4": {
            "url": "https://sample-videos.com/video123/mp4/720/big_buck_bunny_720p_2mb.mp4",  # noqa: E501
            "local_file_path": HERE / "data/big_buck_bunny_720p_2mb.mp4",
            "type": "video",
        },
        "big_buck_bunny_720p_10mb.mp4": {
            "url": "https://sample-videos.com/video123/mp4/720/big_buck_bunny_720p_10mb.mp4",  # noqa: E501
            "local_file_path": HERE / "data/big_buck_bunny_720p_10mb.mp4",
            "type": "video",
        },
        "file_example_MP3_700KB.mp3": {
            "url": "https://file-examples-com.github.io/uploads/2017/11/file_example_MP3_700KB.mp3",  # noqa: E501
            "local_file_path": HERE / "data/file_example_MP3_700KB.mp3",
            "type": "audio",
        },
        "file_example_MP3_5MG.mp3": {
            "url": "https://file-examples-com.github.io/uploads/2017/11/file_example_MP3_5MG.mp3",  # noqa: E501
            "local_file_path": HERE / "data/file_example_MP3_5MG.mp3",
            "type": "audio",
        },
    }
    media_file_label = st.radio(
        "Select a media file to stream", tuple(MEDIAFILES.keys())
    )
    media_file_info = MEDIAFILES[media_file_label]
    download_file(media_file_info["url"], media_file_info["local_file_path"])

    def create_player():
        """Build a player for the selected local media file."""
        return MediaPlayer(str(media_file_info["local_file_path"]))

        # NOTE: To stream the video from webcam, use the code below.
        # return MediaPlayer(
        #     "1:none",
        #     format="avfoundation",
        #     options={"framerate": "30", "video_size": "1280x720"},
        # )

    # Request only the kind of track the selected file provides.
    # (The key was previously misspelled "fmedia_stream_constraints", so the
    # override was stored under an unused key and never took effect.)
    WEBRTC_CLIENT_SETTINGS.update(
        {
            "media_stream_constraints": {
                "video": media_file_info["type"] == "video",
                "audio": media_file_info["type"] == "audio",
            }
        }
    )

    webrtc_streamer(
        key=f"media-streaming-{media_file_label}",
        mode=WebRtcMode.RECVONLY,
        client_settings=WEBRTC_CLIENT_SETTINGS,
        player_factory=create_player,
    )


def app_sendonly():
    """A sample to use WebRTC in sendonly mode to transfer frames
    from the browser to the server and to render frames via `st.image`."""
    webrtc_ctx = webrtc_streamer(
        # A key of its own: this page previously reused "loopback", the key
        # of the loopback demo.
        key="sendonly",
        mode=WebRtcMode.SENDONLY,
        client_settings=WEBRTC_CLIENT_SETTINGS,
    )

    if webrtc_ctx.video_receiver:
        # Placeholder that is overwritten with every received frame.
        image_loc = st.empty()
        while True:
            try:
                frame = webrtc_ctx.video_receiver.frames_queue.get(timeout=1)
            except queue.Empty:
                # No frame arrived within the timeout: stop the receiver and
                # leave the loop.
                logger.warning("Queue is empty. Stop the loop.")
                webrtc_ctx.video_receiver.stop()
                break

            img = frame.to_ndarray(format="bgr24")
            # Convert from BGR to RGB channel order before building the
            # PIL image.
            img = PIL.Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            image_loc.image(img)


if __name__ == "__main__":
    logging.basicConfig(
        format="[%(asctime)s] %(levelname)7s from %(name)s in %(filename)s:%(lineno)d: "
        "%(message)s",
        force=True,
    )

    logger.setLevel(level=logging.DEBUG)

    st_webrtc_logger = logging.getLogger("streamlit_webrtc")
    st_webrtc_logger.setLevel(logging.DEBUG)

    main()