Skip to content

Commit ebc809c

Browse files
authored
Speed-up time-based samplers by 20X and index-based by 1.5X (#284)
1 parent 1bdf928 commit ebc809c

File tree

4 files changed

+85
-156
lines changed

4 files changed

+85
-156
lines changed

src/torchcodec/samplers/_common.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Callable, Union
22

3-
import torch
4-
from torchcodec import Frame, FrameBatch
3+
from torch import Tensor
4+
from torchcodec import FrameBatch
55

66
_LIST_OF_INT_OR_FLOAT = Union[list[int], list[float]]
77

@@ -42,22 +42,6 @@ def _error_policy(
4242
}
4343

4444

45-
def _chunk_list(lst, chunk_size):
46-
# return list of sublists of length chunk_size
47-
return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
48-
49-
50-
def _to_framebatch(frames: list[Frame]) -> FrameBatch:
51-
# IMPORTANT: see other IMPORTANT note in _decode_all_clips_indices and
52-
# _decode_all_clips_timestamps
53-
data = torch.stack([frame.data for frame in frames])
54-
pts_seconds = torch.tensor([frame.pts_seconds for frame in frames])
55-
duration_seconds = torch.tensor([frame.duration_seconds for frame in frames])
56-
return FrameBatch(
57-
data=data, pts_seconds=pts_seconds, duration_seconds=duration_seconds
58-
)
59-
60-
6145
def _validate_common_params(*, decoder, num_frames_per_clip, policy):
6246
if len(decoder) < 1:
6347
raise ValueError(
@@ -72,3 +56,19 @@ def _validate_common_params(*, decoder, num_frames_per_clip, policy):
7256
raise ValueError(
7357
f"Invalid policy ({policy}). Supported values are {_POLICY_FUNCTIONS.keys()}."
7458
)
59+
60+
61+
def _make_5d_framebatch(
62+
*,
63+
data: Tensor,
64+
pts_seconds: Tensor,
65+
duration_seconds: Tensor,
66+
num_clips: int,
67+
num_frames_per_clip: int,
68+
) -> FrameBatch:
69+
last_3_dims = data.shape[-3:]
70+
return FrameBatch(
71+
data=data.view(num_clips, num_frames_per_clip, *last_3_dims),
72+
pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip),
73+
duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip),
74+
)

src/torchcodec/samplers/_index_based.py

Lines changed: 19 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
from typing import List, Literal, Optional
1+
from typing import Literal, Optional
22

33
import torch
44

5-
from torchcodec import Frame, FrameBatch
5+
from torchcodec import FrameBatch
66
from torchcodec.decoders import VideoDecoder
7+
from torchcodec.decoders._core import get_frames_at_indices
78
from torchcodec.samplers._common import (
8-
_chunk_list,
9+
_make_5d_framebatch,
910
_POLICY_FUNCTION_TYPE,
1011
_POLICY_FUNCTIONS,
11-
_to_framebatch,
1212
_validate_common_params,
1313
)
1414

@@ -117,51 +117,6 @@ def _build_all_clips_indices(
117117
return all_clips_indices
118118

119119

120-
def _decode_all_clips_indices(
121-
decoder: VideoDecoder, all_clips_indices: list[int], num_frames_per_clip: int
122-
) -> list[FrameBatch]:
123-
# This takes the list of all the frames to decode (in arbitrary order),
124-
# decode all the frames, and then packs them into clips of length
125-
# num_frames_per_clip.
126-
#
127-
# To avoid backwards seeks (which are slow), we:
128-
# - sort all the frame indices to be decoded
129-
# - dedup them
130-
# - decode all unique frames in sorted order
131-
# - re-assemble the decoded frames back to their original order
132-
#
133-
# TODO: Write this in C++ so we can avoid the copies that happen in `_to_framebatch`
134-
135-
all_clips_indices_sorted, argsort = zip(
136-
*sorted((frame_index, i) for (i, frame_index) in enumerate(all_clips_indices))
137-
)
138-
previous_decoded_frame = None
139-
all_decoded_frames = [None] * len(all_clips_indices)
140-
for i, j in enumerate(argsort):
141-
frame_index = all_clips_indices_sorted[i]
142-
if (
143-
previous_decoded_frame is not None # then we know i > 0
144-
and frame_index == all_clips_indices_sorted[i - 1]
145-
):
146-
# Avoid decoding the same frame twice.
147-
# IMPORTANT: this is only correct because a copy of the frame will
148-
# happen within `_to_framebatch` when we call torch.stack.
149-
# If a copy isn't made, the same underlying memory will be used for
150-
# the 2 consecutive frames. When we re-write this, we should make
151-
# sure to explicitly copy the data.
152-
decoded_frame = previous_decoded_frame
153-
else:
154-
decoded_frame = decoder.get_frame_at(index=frame_index)
155-
previous_decoded_frame = decoded_frame
156-
all_decoded_frames[j] = decoded_frame
157-
158-
all_clips: list[list[Frame]] = _chunk_list(
159-
all_decoded_frames, chunk_size=num_frames_per_clip
160-
)
161-
162-
return [_to_framebatch(clip) for clip in all_clips]
163-
164-
165120
def _generic_index_based_sampler(
166121
kind: Literal["random", "regular"],
167122
decoder: VideoDecoder,
@@ -174,7 +129,7 @@ def _generic_index_based_sampler(
174129
# Important note: sampling_range_end defines the upper bound of where a clip
175130
# can *start*, not where a clip can end.
176131
policy: Literal["repeat_last", "wrap", "error"],
177-
) -> List[FrameBatch]:
132+
) -> FrameBatch:
178133

179134
_validate_common_params(
180135
decoder=decoder,
@@ -221,9 +176,18 @@ def _generic_index_based_sampler(
221176
num_frames_in_video=len(decoder),
222177
policy_fun=_POLICY_FUNCTIONS[policy],
223178
)
224-
return _decode_all_clips_indices(
225-
decoder,
226-
all_clips_indices=all_clips_indices,
179+
180+
# TODO: Use public method of decoder, when it exists
181+
frames, pts_seconds, duration_seconds = get_frames_at_indices(
182+
decoder._decoder,
183+
stream_index=decoder.stream_index,
184+
frame_indices=all_clips_indices,
185+
)
186+
return _make_5d_framebatch(
187+
data=frames,
188+
pts_seconds=pts_seconds,
189+
duration_seconds=duration_seconds,
190+
num_clips=num_clips,
227191
num_frames_per_clip=num_frames_per_clip,
228192
)
229193

@@ -237,7 +201,7 @@ def clips_at_random_indices(
237201
sampling_range_start: int = 0,
238202
sampling_range_end: Optional[int] = None, # interval is [start, end).
239203
policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
240-
) -> List[FrameBatch]:
204+
) -> FrameBatch:
241205
return _generic_index_based_sampler(
242206
kind="random",
243207
decoder=decoder,
@@ -259,7 +223,7 @@ def clips_at_regular_indices(
259223
sampling_range_start: int = 0,
260224
sampling_range_end: Optional[int] = None, # interval is [start, end).
261225
policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
262-
) -> List[FrameBatch]:
226+
) -> FrameBatch:
263227

264228
return _generic_index_based_sampler(
265229
kind="regular",

src/torchcodec/samplers/_time_based.py

Lines changed: 19 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
1-
from typing import List, Literal, Optional
1+
from typing import Literal, Optional
22

33
import torch
44

5-
from torchcodec import Frame, FrameBatch
6-
from torchcodec.decoders import VideoDecoder
5+
from torchcodec import FrameBatch
6+
from torchcodec.decoders._core import get_frames_by_pts
77
from torchcodec.samplers._common import (
8-
_chunk_list,
8+
_make_5d_framebatch,
99
_POLICY_FUNCTION_TYPE,
1010
_POLICY_FUNCTIONS,
11-
_to_framebatch,
1211
_validate_common_params,
1312
)
1413

@@ -147,51 +146,6 @@ def _build_all_clips_timestamps(
147146
return all_clips_timestamps
148147

149148

150-
def _decode_all_clips_timestamps(
151-
decoder: VideoDecoder, all_clips_timestamps: list[float], num_frames_per_clip: int
152-
) -> list[FrameBatch]:
153-
# This is 99% the same as _decode_all_clips_indices. The only change is the
154-
# call to .get_frame_displayed_at(pts) instead of .get_frame_at(idx)
155-
156-
all_clips_timestamps_sorted, argsort = zip(
157-
*sorted(
158-
(frame_index, i) for (i, frame_index) in enumerate(all_clips_timestamps)
159-
)
160-
)
161-
previous_decoded_frame = None
162-
all_decoded_frames = [None] * len(all_clips_timestamps)
163-
for i, j in enumerate(argsort):
164-
frame_pts_seconds = all_clips_timestamps_sorted[i]
165-
if (
166-
previous_decoded_frame is not None # then we know i > 0
167-
and frame_pts_seconds == all_clips_timestamps_sorted[i - 1]
168-
):
169-
# Avoid decoding the same frame twice.
170-
# Unfortunately this is unlikely to lead to speed-up as-is: it's
171-
# pretty unlikely that 2 pts will be the same since pts are float
172-
# contiguous values. Theoretically the dedup can still happen, but
173-
# it would be much more efficient to implement it at the frame index
174-
# level. We should do that once we implement that in C++.
175-
# See also https://github.com/pytorch/torchcodec/issues/256.
176-
#
177-
# IMPORTANT: this is only correct because a copy of the frame will
178-
# happen within `_to_framebatch` when we call torch.stack.
179-
# If a copy isn't made, the same underlying memory will be used for
180-
# the 2 consecutive frames. When we re-write this, we should make
181-
# sure to explicitly copy the data.
182-
decoded_frame = previous_decoded_frame
183-
else:
184-
decoded_frame = decoder.get_frame_displayed_at(seconds=frame_pts_seconds)
185-
previous_decoded_frame = decoded_frame
186-
all_decoded_frames[j] = decoded_frame
187-
188-
all_clips: list[list[Frame]] = _chunk_list(
189-
all_decoded_frames, chunk_size=num_frames_per_clip
190-
)
191-
192-
return [_to_framebatch(clip) for clip in all_clips]
193-
194-
195149
def _generic_time_based_sampler(
196150
kind: Literal["random", "regular"],
197151
decoder,
@@ -204,7 +158,7 @@ def _generic_time_based_sampler(
204158
sampling_range_start: Optional[float],
205159
sampling_range_end: Optional[float], # interval is [start, end).
206160
policy: str = "repeat_last",
207-
) -> List[FrameBatch]:
161+
) -> FrameBatch:
208162
# Note: *everywhere*, sampling_range_end denotes the upper bound of where a
209163
# clip can start. This is an *open* upper bound, i.e. we will make sure no
210164
# clip starts exactly at (or above) sampling_range_end.
@@ -246,6 +200,7 @@ def _generic_time_based_sampler(
246200
sampling_range_end, # excluded
247201
seconds_between_clip_starts,
248202
)
203+
num_clips = len(clip_start_seconds)
249204

250205
all_clips_timestamps = _build_all_clips_timestamps(
251206
clip_start_seconds=clip_start_seconds,
@@ -255,9 +210,17 @@ def _generic_time_based_sampler(
255210
policy_fun=_POLICY_FUNCTIONS[policy],
256211
)
257212

258-
return _decode_all_clips_timestamps(
259-
decoder,
260-
all_clips_timestamps=all_clips_timestamps,
213+
# TODO: Use public method of decoder, when it exists
214+
frames, pts_seconds, duration_seconds = get_frames_by_pts(
215+
decoder._decoder,
216+
stream_index=decoder.stream_index,
217+
timestamps=all_clips_timestamps,
218+
)
219+
return _make_5d_framebatch(
220+
data=frames,
221+
pts_seconds=pts_seconds,
222+
duration_seconds=duration_seconds,
223+
num_clips=num_clips,
261224
num_frames_per_clip=num_frames_per_clip,
262225
)
263226

@@ -272,7 +235,7 @@ def clips_at_random_timestamps(
272235
sampling_range_start: Optional[float] = None,
273236
sampling_range_end: Optional[float] = None, # interval is [start, end).
274237
policy: str = "repeat_last",
275-
) -> List[FrameBatch]:
238+
) -> FrameBatch:
276239
return _generic_time_based_sampler(
277240
kind="random",
278241
decoder=decoder,
@@ -296,7 +259,7 @@ def clips_at_regular_timestamps(
296259
sampling_range_start: Optional[float] = None,
297260
sampling_range_end: Optional[float] = None, # interval is [start, end).
298261
policy: str = "repeat_last",
299-
) -> List[FrameBatch]:
262+
) -> FrameBatch:
300263
return _generic_time_based_sampler(
301264
kind="regular",
302265
decoder=decoder,

0 commit comments

Comments
 (0)