import asyncio
import concurrent.futures
import time
- from collections import deque
- from threading import Event, Lock
- from typing import Deque, Dict, Optional, Tuple
+ from threading import Event
+ from typing import Dict, Optional, Tuple

import numpy as np
from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
@@ -18,17 +17,16 @@
    VideoFrameProducer,
)
from inference.core.interfaces.stream_manager.manager_app.entities import WebRTCOffer
- from inference.core.utils.async_utils import async_lock
+ from inference.core.utils.async_utils import Queue as SyncAsyncQueue
from inference.core.utils.function import experimental


class VideoTransformTrack(VideoStreamTrack):
    def __init__(
        self,
-         to_inference_queue: Deque,
-         to_inference_lock: Lock,
-         from_inference_queue: Deque,
-         from_inference_lock: Lock,
+         to_inference_queue: "SyncAsyncQueue[VideoFrame]",
+         from_inference_queue: "SyncAsyncQueue[np.ndarray]",
+         asyncio_loop: asyncio.AbstractEventLoop,
        webrtc_peer_timeout: float = 1,
        fps_probe_frames: int = 10,
        webcam_fps: Optional[float] = None,
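Note on the new queue type: this diff only exercises SyncAsyncQueue (imported above as Queue from inference.core.utils.async_utils) through the methods it calls later on (sync_get, async_put, async_get, async_empty). A minimal sketch of such a sync/async bridge, assuming it wraps a thread-safe queue.Queue and offloads blocking calls to an executor; the real class in the repository may differ:

import asyncio
import queue
from typing import Generic, Optional, TypeVar

T = TypeVar("T")


class SyncAsyncQueueSketch(Generic[T]):
    """Illustrative stand-in for inference.core.utils.async_utils.Queue."""

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._queue: "queue.Queue[T]" = queue.Queue()

    # Synchronous side, called from worker threads (e.g. the inference pipeline).
    def sync_get(self, timeout: Optional[float] = None) -> T:
        return self._queue.get(timeout=timeout)

    def sync_put(self, item: T) -> None:
        # Not called in this diff; assumed to exist for the producing side.
        self._queue.put(item)

    # Asynchronous side, awaited on the asyncio loop; the blocking queue calls
    # run on the default thread-pool executor so the loop itself never blocks.
    async def async_put(self, item: T) -> None:
        await self._loop.run_in_executor(None, self._queue.put, item)

    async def async_get(self) -> T:
        return await self._loop.run_in_executor(None, self._queue.get)

    async def async_empty(self) -> bool:
        return self._queue.empty()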
@@ -43,10 +41,9 @@ def __init__(
        self.track: Optional[RemoteStreamTrack] = None
        self._id = time.time_ns()
        self._processed = 0
-         self.to_inference_queue: Deque = to_inference_queue
-         self.from_inference_queue: Deque = from_inference_queue
-         self.to_inference_lock: Lock = to_inference_lock
-         self.from_inference_lock: Lock = from_inference_lock
+         self.to_inference_queue: "SyncAsyncQueue[VideoFrame]" = to_inference_queue
+         self.from_inference_queue: "SyncAsyncQueue[np.ndarray]" = from_inference_queue
+         self._asyncio_loop = asyncio_loop
        self._pool = concurrent.futures.ThreadPoolExecutor()
        self._track_active: bool = True
        self._fps_probe_frames = fps_probe_frames
@@ -84,40 +81,41 @@ async def recv(self):
                    "All frames probed in the same time - could not calculate fps."
                )
                raise MediaStreamError
-             self.incoming_stream_fps = 9 / (t2 - t1)
+             self.incoming_stream_fps = (self._fps_probe_frames - 1) / (t2 - t1)
            logger.debug("Incoming stream fps: %s", self.incoming_stream_fps)

-         try:
-             frame: VideoFrame = await asyncio.wait_for(
-                 self.track.recv(), self.webrtc_peer_timeout
-             )
-         except (asyncio.TimeoutError, MediaStreamError):
-             logger.info(
-                 "Timeout while waiting to receive frames sent through webrtc peer connection; assuming peer disconnected."
-             )
-             self.close()
-             raise MediaStreamError
-         img = frame.to_ndarray(format="bgr24")
-
-         dropped = 0
-         async with async_lock(lock=self.to_inference_lock, pool=self._pool):
-             self.to_inference_queue.appendleft(img)
-         while self._track_active and not self.from_inference_queue:
+         while self._track_active:
            try:
                frame: VideoFrame = await asyncio.wait_for(
                    self.track.recv(), self.webrtc_peer_timeout
                )
            except (asyncio.TimeoutError, MediaStreamError):
-                 self.close()
                logger.info(
                    "Timeout while waiting to receive frames sent through webrtc peer connection; assuming peer disconnected."
                )
+                 self.close()
                raise MediaStreamError
-             dropped += 1
-             async with async_lock(lock=self.from_inference_lock, pool=self._pool):
-                 res = self.from_inference_queue.pop()

-         logger.debug("Dropping %s every inference", dropped)
+             await self.to_inference_queue.async_put(frame)
+
+             from_inference_queue_empty = await self.from_inference_queue.async_empty()
+             if not from_inference_queue_empty:
+                 break
+
+         while self._track_active:
+             try:
+                 res: np.ndarray = await asyncio.wait_for(
+                     self.from_inference_queue.async_get(), self.webrtc_peer_timeout
+                 )
+                 break
+             except asyncio.TimeoutError:
+                 continue
+         if not self._track_active:
+             logger.info(
+                 "Received close request while waiting to receive frames from inference pipeline; assuming termination."
+             )
+             raise MediaStreamError
+
        new_frame = VideoFrame.from_ndarray(res, format="bgr24")
        new_frame.pts = frame.pts
        new_frame.time_base = frame.time_base
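The fps fix above replaces the hard-coded 9 with self._fps_probe_frames - 1: probing N frames yields N - 1 inter-frame intervals between the first and last timestamps, so the old constant was only correct for the default fps_probe_frames=10. A quick sanity check of the arithmetic, with illustrative values only:

# 10 probe timestamps from a 30 fps source span 9 frame intervals.
fps_probe_frames = 10
frame_interval = 1 / 30
t1 = 0.0
t2 = (fps_probe_frames - 1) * frame_interval
assert round((fps_probe_frames - 1) / (t2 - t1)) == 30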
@@ -132,31 +130,35 @@ class WebRTCVideoFrameProducer(VideoFrameProducer):
    )
    def __init__(
        self,
-         to_inference_queue: deque,
-         to_inference_lock: Lock,
+         to_inference_queue: "SyncAsyncQueue[VideoFrame]",
        stop_event: Event,
        webrtc_video_transform_track: VideoTransformTrack,
    ):
-         self.to_inference_queue: deque = to_inference_queue
-         self.to_inference_lock: Lock = to_inference_lock
+         self.to_inference_queue: "SyncAsyncQueue[VideoFrame]" = to_inference_queue
        self._stop_event = stop_event
        self._w: Optional[int] = None
        self._h: Optional[int] = None
        self._video_transform_track = webrtc_video_transform_track
        self._is_opened = True

    def grab(self) -> bool:
-         return self._is_opened
+         if self._stop_event.is_set():
+             logger.info("Received termination signal, closing.")
+             self._is_opened = False
+             return False
+
+         self.to_inference_queue.sync_get()
+         return True

    def retrieve(self) -> Tuple[bool, np.ndarray]:
-         while not self._stop_event.is_set() and not self.to_inference_queue:
-             time.sleep(0.1)
        if self._stop_event.is_set():
            logger.info("Received termination signal, closing.")
            self._is_opened = False
            return False, None
-         with self.to_inference_lock:
-             img = self.to_inference_queue.pop()
+
+         frame: VideoFrame = self.to_inference_queue.sync_get()
+         img = frame.to_ndarray(format="bgr24")
+
        return True, img

    def release(self):
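For context, not part of this diff: grab() now pops one frame from the shared queue and discards it, so stale frames can be skipped cheaply, while retrieve() pops the next frame and decodes it into a BGR ndarray. A hypothetical consumer loop under those assumptions; consume_frames and handle_frame are illustrative names, not repository code:

from typing import Callable

import numpy as np


def consume_frames(
    producer: "WebRTCVideoFrameProducer",
    handle_frame: Callable[[np.ndarray], None],
) -> None:
    # retrieve() returns (False, None) once the stop event is set; otherwise it
    # blocks on the shared queue until the WebRTC track delivers the next frame.
    while True:
        success, image = producer.retrieve()
        if not success:
            break
        handle_frame(image)
    producer.release()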
@@ -189,19 +191,17 @@ def __init__(self, video_transform_track: VideoTransformTrack, *args, **kwargs):

async def init_rtc_peer_connection(
    webrtc_offer: WebRTCOffer,
-     to_inference_queue: Deque,
-     to_inference_lock: Lock,
-     from_inference_queue: Deque,
-     from_inference_lock: Lock,
+     to_inference_queue: "SyncAsyncQueue[VideoFrame]",
+     from_inference_queue: "SyncAsyncQueue[np.ndarray]",
    webrtc_peer_timeout: float,
    feedback_stop_event: Event,
+     asyncio_loop: asyncio.AbstractEventLoop,
    webcam_fps: Optional[float] = None,
) -> RTCPeerConnectionWithFPS:
    video_transform_track = VideoTransformTrack(
-         to_inference_lock=to_inference_lock,
        to_inference_queue=to_inference_queue,
-         from_inference_lock=from_inference_lock,
        from_inference_queue=from_inference_queue,
+         asyncio_loop=asyncio_loop,
        webrtc_peer_timeout=webrtc_peer_timeout,
        webcam_fps=webcam_fps,
    )
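A sketch of how a caller might wire the updated signature together; the SyncAsyncQueue constructor arguments, the webrtc_offer value, and the variable names are assumptions for illustration, not taken from this change:

import asyncio
from threading import Event

loop = asyncio.new_event_loop()
# Constructor arguments of the real queue class are not shown in this diff; a
# loop-aware constructor is assumed here purely for illustration.
to_inference_queue = SyncAsyncQueue(loop=loop)
from_inference_queue = SyncAsyncQueue(loop=loop)
stop_event = Event()

peer_connection = loop.run_until_complete(
    init_rtc_peer_connection(
        webrtc_offer=webrtc_offer,  # a WebRTCOffer built from the client's SDP offer
        to_inference_queue=to_inference_queue,
        from_inference_queue=from_inference_queue,
        webrtc_peer_timeout=1.0,
        feedback_stop_event=stop_event,
        asyncio_loop=loop,
    )
)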