Skip to content

Commit

Permalink
Merge pull request #751 from roboflow/fix/speed-up-webrtc-pipeline
Browse files Browse the repository at this point in the history
Speed-up webrtc inference pipeline
  • Loading branch information
PawelPeczek-Roboflow authored Oct 17, 2024
2 parents 1fa4c96 + f2d6e2e commit c585553
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
import os
import signal
import threading
import time
from collections import deque
from dataclasses import asdict
from functools import partial
from multiprocessing import Process, Queue
from threading import Event, Lock
from threading import Event
from types import FrameType
from typing import Deque, Dict, Optional, Tuple
from typing import Dict, Optional, Tuple

from pydantic import ValidationError

Expand Down Expand Up @@ -47,6 +45,7 @@
WebRTCVideoFrameProducer,
init_rtc_peer_connection,
)
from inference.core.utils.async_utils import Queue as SyncAsyncQueue
from inference.core.workflows.execution_engine.entities.base import WorkflowImageData


Expand Down Expand Up @@ -199,15 +198,6 @@ def _start_webrtc(self, request_id: str, payload: dict):
parsed_payload = InitialiseWebRTCPipelinePayload.model_validate(payload)
watchdog = BasePipelineWatchDog()

webrtc_offer = parsed_payload.webrtc_offer
webcam_fps = parsed_payload.webcam_fps
to_inference_queue = deque()
to_inference_lock = Lock()
from_inference_queue = deque()
from_inference_lock = Lock()

stop_event = Event()

def start_loop(loop: asyncio.AbstractEventLoop):
asyncio.set_event_loop(loop)
loop.run_forever()
Expand All @@ -216,15 +206,21 @@ def start_loop(loop: asyncio.AbstractEventLoop):
t = threading.Thread(target=start_loop, args=(loop,), daemon=True)
t.start()

webrtc_offer = parsed_payload.webrtc_offer
webcam_fps = parsed_payload.webcam_fps
to_inference_queue = SyncAsyncQueue(loop=loop)
from_inference_queue = SyncAsyncQueue(loop=loop)

stop_event = Event()

future = asyncio.run_coroutine_threadsafe(
init_rtc_peer_connection(
webrtc_offer=webrtc_offer,
to_inference_queue=to_inference_queue,
to_inference_lock=to_inference_lock,
from_inference_queue=from_inference_queue,
from_inference_lock=from_inference_lock,
webrtc_peer_timeout=parsed_payload.webrtc_peer_timeout,
feedback_stop_event=stop_event,
asyncio_loop=loop,
webcam_fps=webcam_fps,
),
loop,
Expand All @@ -233,7 +229,6 @@ def start_loop(loop: asyncio.AbstractEventLoop):

webrtc_producer = partial(
WebRTCVideoFrameProducer,
to_inference_lock=to_inference_lock,
to_inference_queue=to_inference_queue,
stop_event=stop_event,
webrtc_video_transform_track=peer_connection.video_transform_track,
Expand All @@ -242,10 +237,9 @@ def start_loop(loop: asyncio.AbstractEventLoop):
def webrtc_sink(
prediction: Dict[str, WorkflowImageData], video_frame: VideoFrame
) -> None:
with from_inference_lock:
from_inference_queue.appendleft(
prediction[parsed_payload.stream_output[0]].numpy_image
)
from_inference_queue.sync_put(
prediction[parsed_payload.stream_output[0]].numpy_image
)

buffer_sink = InMemoryBufferSink.init(
queue_size=parsed_payload.sink_configuration.results_buffer_size,
Expand Down
98 changes: 49 additions & 49 deletions inference/core/interfaces/stream_manager/manager_app/webrtc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import asyncio
import concurrent.futures
import time
from collections import deque
from threading import Event, Lock
from typing import Deque, Dict, Optional, Tuple
from threading import Event
from typing import Dict, Optional, Tuple

import numpy as np
from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
Expand All @@ -18,17 +17,16 @@
VideoFrameProducer,
)
from inference.core.interfaces.stream_manager.manager_app.entities import WebRTCOffer
from inference.core.utils.async_utils import async_lock
from inference.core.utils.async_utils import Queue as SyncAsyncQueue
from inference.core.utils.function import experimental


class VideoTransformTrack(VideoStreamTrack):
def __init__(
self,
to_inference_queue: Deque,
to_inference_lock: Lock,
from_inference_queue: Deque,
from_inference_lock: Lock,
to_inference_queue: "SyncAsyncQueue[VideoFrame]",
from_inference_queue: "SyncAsyncQueue[np.ndarray]",
asyncio_loop: asyncio.AbstractEventLoop,
webrtc_peer_timeout: float = 1,
fps_probe_frames: int = 10,
webcam_fps: Optional[float] = None,
Expand All @@ -43,10 +41,9 @@ def __init__(
self.track: Optional[RemoteStreamTrack] = None
self._id = time.time_ns()
self._processed = 0
self.to_inference_queue: Deque = to_inference_queue
self.from_inference_queue: Deque = from_inference_queue
self.to_inference_lock: Lock = to_inference_lock
self.from_inference_lock: Lock = from_inference_lock
self.to_inference_queue: "SyncAsyncQueue[VideoFrame]" = to_inference_queue
self.from_inference_queue: "SyncAsyncQueue[np.ndarray]" = from_inference_queue
self._asyncio_loop = asyncio_loop
self._pool = concurrent.futures.ThreadPoolExecutor()
self._track_active: bool = True
self._fps_probe_frames = fps_probe_frames
Expand Down Expand Up @@ -84,40 +81,41 @@ async def recv(self):
"All frames probed in the same time - could not calculate fps."
)
raise MediaStreamError
self.incoming_stream_fps = 9 / (t2 - t1)
self.incoming_stream_fps = (self._fps_probe_frames - 1) / (t2 - t1)
logger.debug("Incoming stream fps: %s", self.incoming_stream_fps)

try:
frame: VideoFrame = await asyncio.wait_for(
self.track.recv(), self.webrtc_peer_timeout
)
except (asyncio.TimeoutError, MediaStreamError):
logger.info(
"Timeout while waiting to receive frames sent through webrtc peer connection; assuming peer disconnected."
)
self.close()
raise MediaStreamError
img = frame.to_ndarray(format="bgr24")

dropped = 0
async with async_lock(lock=self.to_inference_lock, pool=self._pool):
self.to_inference_queue.appendleft(img)
while self._track_active and not self.from_inference_queue:
while self._track_active:
try:
frame: VideoFrame = await asyncio.wait_for(
self.track.recv(), self.webrtc_peer_timeout
)
except (asyncio.TimeoutError, MediaStreamError):
self.close()
logger.info(
"Timeout while waiting to receive frames sent through webrtc peer connection; assuming peer disconnected."
)
self.close()
raise MediaStreamError
dropped += 1
async with async_lock(lock=self.from_inference_lock, pool=self._pool):
res = self.from_inference_queue.pop()

logger.debug("Dropping %s every inference", dropped)
await self.to_inference_queue.async_put(frame)

from_inference_queue_empty = await self.from_inference_queue.async_empty()
if not from_inference_queue_empty:
break

while self._track_active:
try:
res: np.ndarray = await asyncio.wait_for(
self.from_inference_queue.async_get(), self.webrtc_peer_timeout
)
break
except asyncio.TimeoutError:
continue
if not self._track_active:
logger.info(
"Received close request while waiting to receive frames from inference pipeline; assuming termination."
)
raise MediaStreamError

new_frame = VideoFrame.from_ndarray(res, format="bgr24")
new_frame.pts = frame.pts
new_frame.time_base = frame.time_base
Expand All @@ -132,31 +130,35 @@ class WebRTCVideoFrameProducer(VideoFrameProducer):
)
def __init__(
self,
to_inference_queue: deque,
to_inference_lock: Lock,
to_inference_queue: "SyncAsyncQueue[VideoFrame]",
stop_event: Event,
webrtc_video_transform_track: VideoTransformTrack,
):
self.to_inference_queue: deque = to_inference_queue
self.to_inference_lock: Lock = to_inference_lock
self.to_inference_queue: "SyncAsyncQueue[VideoFrame]" = to_inference_queue
self._stop_event = stop_event
self._w: Optional[int] = None
self._h: Optional[int] = None
self._video_transform_track = webrtc_video_transform_track
self._is_opened = True

def grab(self) -> bool:
return self._is_opened
if self._stop_event.is_set():
logger.info("Received termination signal, closing.")
self._is_opened = False
return False

self.to_inference_queue.sync_get()
return True

def retrieve(self) -> Tuple[bool, np.ndarray]:
while not self._stop_event.is_set() and not self.to_inference_queue:
time.sleep(0.1)
if self._stop_event.is_set():
logger.info("Received termination signal, closing.")
self._is_opened = False
return False, None
with self.to_inference_lock:
img = self.to_inference_queue.pop()

frame: VideoFrame = self.to_inference_queue.sync_get()
img = frame.to_ndarray(format="bgr24")

return True, img

def release(self):
Expand Down Expand Up @@ -189,19 +191,17 @@ def __init__(self, video_transform_track: VideoTransformTrack, *args, **kwargs):

async def init_rtc_peer_connection(
webrtc_offer: WebRTCOffer,
to_inference_queue: Deque,
to_inference_lock: Lock,
from_inference_queue: Deque,
from_inference_lock: Lock,
to_inference_queue: "SyncAsyncQueue[VideoFrame]",
from_inference_queue: "SyncAsyncQueue[np.ndarray]",
webrtc_peer_timeout: float,
feedback_stop_event: Event,
asyncio_loop: asyncio.AbstractEventLoop,
webcam_fps: Optional[float] = None,
) -> RTCPeerConnectionWithFPS:
video_transform_track = VideoTransformTrack(
to_inference_lock=to_inference_lock,
to_inference_queue=to_inference_queue,
from_inference_lock=from_inference_lock,
from_inference_queue=from_inference_queue,
asyncio_loop=asyncio_loop,
webrtc_peer_timeout=webrtc_peer_timeout,
webcam_fps=webcam_fps,
)
Expand Down
48 changes: 46 additions & 2 deletions inference/core/utils/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,57 @@
import concurrent.futures
import contextlib
from threading import Lock
from typing import Optional, Union


@contextlib.asynccontextmanager
async def async_lock(
    lock: Lock,
    pool: concurrent.futures.ThreadPoolExecutor,
    loop: Optional[asyncio.AbstractEventLoop] = None,
):
    """Async context manager bridging a ``threading.Lock`` into asyncio.

    The blocking ``lock.acquire()`` runs inside *pool* so the event loop is
    never stalled while waiting for a lock held by another thread.

    Args:
        lock: Thread lock shared with non-async code.
        pool: Executor used to run the blocking ``acquire`` call.
        loop: Event loop to schedule the executor job on; defaults to the
            current event loop of the calling coroutine.

    Yields:
        Control with the lock held; the lock is always released on exit.
    """
    if not loop:
        loop = asyncio.get_event_loop()
    # Acquire in the executor so other coroutines keep running meanwhile.
    await loop.run_in_executor(pool, lock.acquire)
    try:
        yield  # the lock is held here
    finally:
        lock.release()


async def create_async_queue() -> asyncio.Queue:
    """Instantiate an ``asyncio.Queue`` on the event loop running this coroutine."""
    queue: asyncio.Queue = asyncio.Queue()
    return queue


class Queue:
    """Queue usable from both sync (foreign-thread) and async (loop-thread) code.

    Wraps an ``asyncio.Queue`` bound to *loop*. The ``sync_*`` methods are
    intended to be called from threads OTHER than the loop's thread (they
    schedule work onto the loop in a thread-safe way); the ``async_*``
    methods must run on the loop itself.

    Args:
        loop: Event loop that owns the underlying queue. If omitted, the
            loop running in the current thread is used (raises if there is
            no running loop).
    """

    def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None):
        self._loop = loop
        if not self._loop:
            self._loop = asyncio.get_running_loop()

        async def _make_queue() -> asyncio.Queue:
            return asyncio.Queue()

        # Create the asyncio.Queue from within the target loop so it is
        # bound to that loop.
        self._queue: asyncio.Queue = asyncio.run_coroutine_threadsafe(
            _make_queue(), self._loop
        ).result()

    def sync_put_nowait(self, item) -> None:
        # call_soon_threadsafe (not call_soon) is required when scheduling
        # from a thread other than the loop's thread; plain call_soon is
        # documented as loop-thread-only.
        self._loop.call_soon_threadsafe(self._queue.put_nowait, item)

    def sync_put(self, item) -> None:
        # Blocks the calling thread until the loop has enqueued the item.
        asyncio.run_coroutine_threadsafe(self._queue.put(item), self._loop).result()

    def sync_get(self):
        # Blocks the calling thread until an item is available.
        return asyncio.run_coroutine_threadsafe(self._queue.get(), self._loop).result()

    def sync_empty(self) -> bool:
        # Plain attribute read; ``empty()`` does not touch the loop, so it
        # is safe to call from any thread (value may be stale by the time
        # the caller acts on it).
        return self._queue.empty()

    def async_put_nowait(self, item) -> None:
        self._queue.put_nowait(item)

    async def async_put(self, item) -> None:
        await self._queue.put(item)

    async def async_get(self):
        return await self._queue.get()

    async def async_empty(self) -> bool:
        return self._queue.empty()

0 comments on commit c585553

Please sign in to comment.