Commit c585553 — merge pull request #751 from roboflow/fix/speed-up-webrtc-pipeline: "Speed-up webrtc inference pipeline". Merge of 2 parents: 1fa4c96 + f2d6e2e.

File tree: 3 files changed (+109, −71 lines).

inference/core/interfaces/stream_manager/manager_app/inference_pipeline_manager.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,12 @@
22
import os
33
import signal
44
import threading
5-
import time
6-
from collections import deque
75
from dataclasses import asdict
86
from functools import partial
97
from multiprocessing import Process, Queue
10-
from threading import Event, Lock
8+
from threading import Event
119
from types import FrameType
12-
from typing import Deque, Dict, Optional, Tuple
10+
from typing import Dict, Optional, Tuple
1311

1412
from pydantic import ValidationError
1513

@@ -47,6 +45,7 @@
4745
WebRTCVideoFrameProducer,
4846
init_rtc_peer_connection,
4947
)
48+
from inference.core.utils.async_utils import Queue as SyncAsyncQueue
5049
from inference.core.workflows.execution_engine.entities.base import WorkflowImageData
5150

5251

@@ -199,15 +198,6 @@ def _start_webrtc(self, request_id: str, payload: dict):
199198
parsed_payload = InitialiseWebRTCPipelinePayload.model_validate(payload)
200199
watchdog = BasePipelineWatchDog()
201200

202-
webrtc_offer = parsed_payload.webrtc_offer
203-
webcam_fps = parsed_payload.webcam_fps
204-
to_inference_queue = deque()
205-
to_inference_lock = Lock()
206-
from_inference_queue = deque()
207-
from_inference_lock = Lock()
208-
209-
stop_event = Event()
210-
211201
def start_loop(loop: asyncio.AbstractEventLoop):
212202
asyncio.set_event_loop(loop)
213203
loop.run_forever()
@@ -216,15 +206,21 @@ def start_loop(loop: asyncio.AbstractEventLoop):
216206
t = threading.Thread(target=start_loop, args=(loop,), daemon=True)
217207
t.start()
218208

209+
webrtc_offer = parsed_payload.webrtc_offer
210+
webcam_fps = parsed_payload.webcam_fps
211+
to_inference_queue = SyncAsyncQueue(loop=loop)
212+
from_inference_queue = SyncAsyncQueue(loop=loop)
213+
214+
stop_event = Event()
215+
219216
future = asyncio.run_coroutine_threadsafe(
220217
init_rtc_peer_connection(
221218
webrtc_offer=webrtc_offer,
222219
to_inference_queue=to_inference_queue,
223-
to_inference_lock=to_inference_lock,
224220
from_inference_queue=from_inference_queue,
225-
from_inference_lock=from_inference_lock,
226221
webrtc_peer_timeout=parsed_payload.webrtc_peer_timeout,
227222
feedback_stop_event=stop_event,
223+
asyncio_loop=loop,
228224
webcam_fps=webcam_fps,
229225
),
230226
loop,
@@ -233,7 +229,6 @@ def start_loop(loop: asyncio.AbstractEventLoop):
233229

234230
webrtc_producer = partial(
235231
WebRTCVideoFrameProducer,
236-
to_inference_lock=to_inference_lock,
237232
to_inference_queue=to_inference_queue,
238233
stop_event=stop_event,
239234
webrtc_video_transform_track=peer_connection.video_transform_track,
@@ -242,10 +237,9 @@ def start_loop(loop: asyncio.AbstractEventLoop):
242237
def webrtc_sink(
243238
prediction: Dict[str, WorkflowImageData], video_frame: VideoFrame
244239
) -> None:
245-
with from_inference_lock:
246-
from_inference_queue.appendleft(
247-
prediction[parsed_payload.stream_output[0]].numpy_image
248-
)
240+
from_inference_queue.sync_put(
241+
prediction[parsed_payload.stream_output[0]].numpy_image
242+
)
249243

250244
buffer_sink = InMemoryBufferSink.init(
251245
queue_size=parsed_payload.sink_configuration.results_buffer_size,

inference/core/interfaces/stream_manager/manager_app/webrtc.py

Lines changed: 49 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import asyncio
22
import concurrent.futures
33
import time
4-
from collections import deque
5-
from threading import Event, Lock
6-
from typing import Deque, Dict, Optional, Tuple
4+
from threading import Event
5+
from typing import Dict, Optional, Tuple
76

87
import numpy as np
98
from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
@@ -18,17 +17,16 @@
1817
VideoFrameProducer,
1918
)
2019
from inference.core.interfaces.stream_manager.manager_app.entities import WebRTCOffer
21-
from inference.core.utils.async_utils import async_lock
20+
from inference.core.utils.async_utils import Queue as SyncAsyncQueue
2221
from inference.core.utils.function import experimental
2322

2423

2524
class VideoTransformTrack(VideoStreamTrack):
2625
def __init__(
2726
self,
28-
to_inference_queue: Deque,
29-
to_inference_lock: Lock,
30-
from_inference_queue: Deque,
31-
from_inference_lock: Lock,
27+
to_inference_queue: "SyncAsyncQueue[VideoFrame]",
28+
from_inference_queue: "SyncAsyncQueue[np.ndarray]",
29+
asyncio_loop: asyncio.AbstractEventLoop,
3230
webrtc_peer_timeout: float = 1,
3331
fps_probe_frames: int = 10,
3432
webcam_fps: Optional[float] = None,
@@ -43,10 +41,9 @@ def __init__(
4341
self.track: Optional[RemoteStreamTrack] = None
4442
self._id = time.time_ns()
4543
self._processed = 0
46-
self.to_inference_queue: Deque = to_inference_queue
47-
self.from_inference_queue: Deque = from_inference_queue
48-
self.to_inference_lock: Lock = to_inference_lock
49-
self.from_inference_lock: Lock = from_inference_lock
44+
self.to_inference_queue: "SyncAsyncQueue[VideoFrame]" = to_inference_queue
45+
self.from_inference_queue: "SyncAsyncQueue[np.ndarray]" = from_inference_queue
46+
self._asyncio_loop = asyncio_loop
5047
self._pool = concurrent.futures.ThreadPoolExecutor()
5148
self._track_active: bool = True
5249
self._fps_probe_frames = fps_probe_frames
@@ -84,40 +81,41 @@ async def recv(self):
8481
"All frames probed in the same time - could not calculate fps."
8582
)
8683
raise MediaStreamError
87-
self.incoming_stream_fps = 9 / (t2 - t1)
84+
self.incoming_stream_fps = (self._fps_probe_frames - 1) / (t2 - t1)
8885
logger.debug("Incoming stream fps: %s", self.incoming_stream_fps)
8986

90-
try:
91-
frame: VideoFrame = await asyncio.wait_for(
92-
self.track.recv(), self.webrtc_peer_timeout
93-
)
94-
except (asyncio.TimeoutError, MediaStreamError):
95-
logger.info(
96-
"Timeout while waiting to receive frames sent through webrtc peer connection; assuming peer disconnected."
97-
)
98-
self.close()
99-
raise MediaStreamError
100-
img = frame.to_ndarray(format="bgr24")
101-
102-
dropped = 0
103-
async with async_lock(lock=self.to_inference_lock, pool=self._pool):
104-
self.to_inference_queue.appendleft(img)
105-
while self._track_active and not self.from_inference_queue:
87+
while self._track_active:
10688
try:
10789
frame: VideoFrame = await asyncio.wait_for(
10890
self.track.recv(), self.webrtc_peer_timeout
10991
)
11092
except (asyncio.TimeoutError, MediaStreamError):
111-
self.close()
11293
logger.info(
11394
"Timeout while waiting to receive frames sent through webrtc peer connection; assuming peer disconnected."
11495
)
96+
self.close()
11597
raise MediaStreamError
116-
dropped += 1
117-
async with async_lock(lock=self.from_inference_lock, pool=self._pool):
118-
res = self.from_inference_queue.pop()
11998

120-
logger.debug("Dropping %s every inference", dropped)
99+
await self.to_inference_queue.async_put(frame)
100+
101+
from_inference_queue_empty = await self.from_inference_queue.async_empty()
102+
if not from_inference_queue_empty:
103+
break
104+
105+
while self._track_active:
106+
try:
107+
res: np.ndarray = await asyncio.wait_for(
108+
self.from_inference_queue.async_get(), self.webrtc_peer_timeout
109+
)
110+
break
111+
except asyncio.TimeoutError:
112+
continue
113+
if not self._track_active:
114+
logger.info(
115+
"Received close request while waiting to receive frames from inference pipeline; assuming termination."
116+
)
117+
raise MediaStreamError
118+
121119
new_frame = VideoFrame.from_ndarray(res, format="bgr24")
122120
new_frame.pts = frame.pts
123121
new_frame.time_base = frame.time_base
@@ -132,31 +130,35 @@ class WebRTCVideoFrameProducer(VideoFrameProducer):
132130
)
133131
def __init__(
134132
self,
135-
to_inference_queue: deque,
136-
to_inference_lock: Lock,
133+
to_inference_queue: "SyncAsyncQueue[VideoFrame]",
137134
stop_event: Event,
138135
webrtc_video_transform_track: VideoTransformTrack,
139136
):
140-
self.to_inference_queue: deque = to_inference_queue
141-
self.to_inference_lock: Lock = to_inference_lock
137+
self.to_inference_queue: "SyncAsyncQueue[VideoFrame]" = to_inference_queue
142138
self._stop_event = stop_event
143139
self._w: Optional[int] = None
144140
self._h: Optional[int] = None
145141
self._video_transform_track = webrtc_video_transform_track
146142
self._is_opened = True
147143

148144
def grab(self) -> bool:
149-
return self._is_opened
145+
if self._stop_event.is_set():
146+
logger.info("Received termination signal, closing.")
147+
self._is_opened = False
148+
return False
149+
150+
self.to_inference_queue.sync_get()
151+
return True
150152

151153
def retrieve(self) -> Tuple[bool, np.ndarray]:
152-
while not self._stop_event.is_set() and not self.to_inference_queue:
153-
time.sleep(0.1)
154154
if self._stop_event.is_set():
155155
logger.info("Received termination signal, closing.")
156156
self._is_opened = False
157157
return False, None
158-
with self.to_inference_lock:
159-
img = self.to_inference_queue.pop()
158+
159+
frame: VideoFrame = self.to_inference_queue.sync_get()
160+
img = frame.to_ndarray(format="bgr24")
161+
160162
return True, img
161163

162164
def release(self):
@@ -189,19 +191,17 @@ def __init__(self, video_transform_track: VideoTransformTrack, *args, **kwargs):
189191

190192
async def init_rtc_peer_connection(
191193
webrtc_offer: WebRTCOffer,
192-
to_inference_queue: Deque,
193-
to_inference_lock: Lock,
194-
from_inference_queue: Deque,
195-
from_inference_lock: Lock,
194+
to_inference_queue: "SyncAsyncQueue[VideoFrame]",
195+
from_inference_queue: "SyncAsyncQueue[np.ndarray]",
196196
webrtc_peer_timeout: float,
197197
feedback_stop_event: Event,
198+
asyncio_loop: asyncio.AbstractEventLoop,
198199
webcam_fps: Optional[float] = None,
199200
) -> RTCPeerConnectionWithFPS:
200201
video_transform_track = VideoTransformTrack(
201-
to_inference_lock=to_inference_lock,
202202
to_inference_queue=to_inference_queue,
203-
from_inference_lock=from_inference_lock,
204203
from_inference_queue=from_inference_queue,
204+
asyncio_loop=asyncio_loop,
205205
webrtc_peer_timeout=webrtc_peer_timeout,
206206
webcam_fps=webcam_fps,
207207
)

inference/core/utils/async_utils.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,57 @@
22
import concurrent.futures
33
import contextlib
44
from threading import Lock
5+
from typing import Optional, Union
56

67

78
@contextlib.asynccontextmanager
8-
async def async_lock(lock: Lock, pool: concurrent.futures.ThreadPoolExecutor):
9-
loop = asyncio.get_event_loop()
9+
async def async_lock(
10+
lock: Union[Lock],
11+
pool: concurrent.futures.ThreadPoolExecutor,
12+
loop: Optional[asyncio.AbstractEventLoop] = None,
13+
):
14+
if not loop:
15+
loop = asyncio.get_event_loop()
1016
await loop.run_in_executor(pool, lock.acquire)
1117
try:
1218
yield # the lock is held
1319
finally:
1420
lock.release()
21+
22+
23+
async def create_async_queue() -> asyncio.Queue:
24+
return asyncio.Queue()
25+
26+
27+
class Queue:
    """Thread-safe bridge to an ``asyncio.Queue`` owned by an event loop.

    Plain threads use the blocking ``sync_*`` methods; coroutines running
    on the owning loop use the ``async_*`` methods. All cross-thread
    operations are funnelled through the owning loop.
    """

    def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None):
        """Create the underlying queue on ``loop``.

        Args:
            loop: owning event loop; defaults to the currently running loop.

        NOTE(review): when ``loop`` is already running, this constructor
        must be called from a *different* thread — ``.result()`` below
        blocks until the loop executes the construction coroutine, so
        calling it from the loop's own thread would deadlock.
        """
        self._loop = loop if loop is not None else asyncio.get_running_loop()
        # Construct the asyncio.Queue on the owning loop so that, on
        # Python versions where asyncio.Queue captures the current loop at
        # construction, it binds to the loop that will service it.
        self._queue: asyncio.Queue = asyncio.run_coroutine_threadsafe(
            self._make_queue(), self._loop
        ).result()

    @staticmethod
    async def _make_queue() -> asyncio.Queue:
        # Runs on the owning loop; see __init__.
        return asyncio.Queue()

    def sync_put_nowait(self, item) -> None:
        """Schedule a non-blocking put from a foreign thread.

        Fixed: ``call_soon`` is documented as non-thread-safe and must not
        be invoked off the loop's thread; ``call_soon_threadsafe`` also
        wakes a loop blocked in its selector.
        """
        self._loop.call_soon_threadsafe(self._queue.put_nowait, item)

    def sync_put(self, item) -> None:
        """Put ``item``, blocking the calling thread until it is accepted."""
        asyncio.run_coroutine_threadsafe(self._queue.put(item), self._loop).result()

    def sync_get(self, timeout: Optional[float] = None):
        """Take an item, blocking the calling thread.

        Args:
            timeout: maximum seconds to wait; ``None`` (the default, which
                preserves the previous behavior) waits forever.

        Raises:
            concurrent.futures.TimeoutError: when ``timeout`` elapses; the
                pending get is cancelled so it cannot steal a later item.
        """
        future = asyncio.run_coroutine_threadsafe(self._queue.get(), self._loop)
        try:
            return future.result(timeout)
        except concurrent.futures.TimeoutError:
            future.cancel()  # don't leave an orphaned get() consuming items
            raise

    def sync_empty(self) -> bool:
        """Non-blocking emptiness probe from a foreign thread.

        NOTE(review): reads the queue state without synchronization — a
        benign race for a snapshot, but the answer may be stale.
        """
        return self._queue.empty()

    def async_put_nowait(self, item) -> None:
        """Non-blocking put; must be called on the owning loop."""
        self._queue.put_nowait(item)

    async def async_put(self, item) -> None:
        """Coroutine put; awaits until the item is accepted."""
        await self._queue.put(item)

    async def async_get(self):
        """Coroutine get; awaits until an item is available."""
        return await self._queue.get()

    async def async_empty(self) -> bool:
        """Coroutine emptiness probe, evaluated on the owning loop."""
        return self._queue.empty()

Commit comments: 0.