1- from typing import List , Literal , Optional
1+ from typing import Literal , Optional
22
33import torch
44
5- from torchcodec import Frame , FrameBatch
6- from torchcodec .decoders import VideoDecoder
5+ from torchcodec import FrameBatch
6+ from torchcodec .decoders . _core import get_frames_by_pts
77from torchcodec .samplers ._common import (
8- _chunk_list ,
8+ _make_5d_framebatch ,
99 _POLICY_FUNCTION_TYPE ,
1010 _POLICY_FUNCTIONS ,
11- _to_framebatch ,
1211 _validate_common_params ,
1312)
1413
@@ -147,51 +146,6 @@ def _build_all_clips_timestamps(
147146 return all_clips_timestamps
148147
149148
150- def _decode_all_clips_timestamps (
151- decoder : VideoDecoder , all_clips_timestamps : list [float ], num_frames_per_clip : int
152- ) -> list [FrameBatch ]:
153- # This is 99% the same as _decode_all_clips_indices. The only change is the
154- # call to .get_frame_displayed_at(pts) instead of .get_frame_at(idx)
155-
156- all_clips_timestamps_sorted , argsort = zip (
157- * sorted (
158- (frame_index , i ) for (i , frame_index ) in enumerate (all_clips_timestamps )
159- )
160- )
161- previous_decoded_frame = None
162- all_decoded_frames = [None ] * len (all_clips_timestamps )
163- for i , j in enumerate (argsort ):
164- frame_pts_seconds = all_clips_timestamps_sorted [i ]
165- if (
166- previous_decoded_frame is not None # then we know i > 0
167- and frame_pts_seconds == all_clips_timestamps_sorted [i - 1 ]
168- ):
169- # Avoid decoding the same frame twice.
170- # Unfortunatly this is unlikely to lead to speed-up as-is: it's
171- # pretty unlikely that 2 pts will be the same since pts are float
172- # contiguous values. Theoretically the dedup can still happen, but
173- # it would be much more efficient to implement it at the frame index
174- # level. We should do that once we implement that in C++.
175- # See also https://github.com/pytorch/torchcodec/issues/256.
176- #
177- # IMPORTANT: this is only correct because a copy of the frame will
178- # happen within `_to_framebatch` when we call torch.stack.
179- # If a copy isn't made, the same underlying memory will be used for
180- # the 2 consecutive frames. When we re-write this, we should make
181- # sure to explicitly copy the data.
182- decoded_frame = previous_decoded_frame
183- else :
184- decoded_frame = decoder .get_frame_displayed_at (seconds = frame_pts_seconds )
185- previous_decoded_frame = decoded_frame
186- all_decoded_frames [j ] = decoded_frame
187-
188- all_clips : list [list [Frame ]] = _chunk_list (
189- all_decoded_frames , chunk_size = num_frames_per_clip
190- )
191-
192- return [_to_framebatch (clip ) for clip in all_clips ]
193-
194-
195149def _generic_time_based_sampler (
196150 kind : Literal ["random" , "regular" ],
197151 decoder ,
@@ -204,7 +158,7 @@ def _generic_time_based_sampler(
204158 sampling_range_start : Optional [float ],
205159 sampling_range_end : Optional [float ], # interval is [start, end).
206160 policy : str = "repeat_last" ,
207- ) -> List [ FrameBatch ] :
161+ ) -> FrameBatch :
208162 # Note: *everywhere*, sampling_range_end denotes the upper bound of where a
209163 # clip can start. This is an *open* upper bound, i.e. we will make sure no
210164 # clip starts exactly at (or above) sampling_range_end.
@@ -246,6 +200,7 @@ def _generic_time_based_sampler(
246200 sampling_range_end , # excluded
247201 seconds_between_clip_starts ,
248202 )
203+ num_clips = len (clip_start_seconds )
249204
250205 all_clips_timestamps = _build_all_clips_timestamps (
251206 clip_start_seconds = clip_start_seconds ,
@@ -255,9 +210,17 @@ def _generic_time_based_sampler(
255210 policy_fun = _POLICY_FUNCTIONS [policy ],
256211 )
257212
258- return _decode_all_clips_timestamps (
259- decoder ,
260- all_clips_timestamps = all_clips_timestamps ,
213+ # TODO: Use public method of decoder, when it exists
214+ frames , pts_seconds , duration_seconds = get_frames_by_pts (
215+ decoder ._decoder ,
216+ stream_index = decoder .stream_index ,
217+ timestamps = all_clips_timestamps ,
218+ )
219+ return _make_5d_framebatch (
220+ data = frames ,
221+ pts_seconds = pts_seconds ,
222+ duration_seconds = duration_seconds ,
223+ num_clips = num_clips ,
261224 num_frames_per_clip = num_frames_per_clip ,
262225 )
263226
@@ -272,7 +235,7 @@ def clips_at_random_timestamps(
272235 sampling_range_start : Optional [float ] = None ,
273236 sampling_range_end : Optional [float ] = None , # interval is [start, end).
274237 policy : str = "repeat_last" ,
275- ) -> List [ FrameBatch ] :
238+ ) -> FrameBatch :
276239 return _generic_time_based_sampler (
277240 kind = "random" ,
278241 decoder = decoder ,
@@ -296,7 +259,7 @@ def clips_at_regular_timestamps(
296259 sampling_range_start : Optional [float ] = None ,
297260 sampling_range_end : Optional [float ] = None , # interval is [start, end).
298261 policy : str = "repeat_last" ,
299- ) -> List [ FrameBatch ] :
262+ ) -> FrameBatch :
300263 return _generic_time_based_sampler (
301264 kind = "regular" ,
302265 decoder = decoder ,
0 commit comments