Skip to content

Commit 1fffd02

Browse files
authored
Allow option to use the swscale library for color conversion instead of filtergraph (#205)
1 parent 6bebade commit 1fffd02

File tree

12 files changed

+465
-74
lines changed

12 files changed

+465
-74
lines changed

.github/workflows/cpp_tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,5 +65,5 @@ jobs:
6565
Torch_DIR="${TORCH_PATH}/share/cmake/Torch"
6666
cmake .. -DTorch_DIR=$Torch_DIR -DCMAKE_BUILD_TYPE=Debug -DBUILD_TESTS=ON -DCMAKE_VERBOSE_MAKEFILE=ON
6767
cmake --build .
68-
ctest
68+
ctest --output-on-failure
6969
popd

benchmarks/decoders/benchmark_decoders.py

Lines changed: 123 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from torchcodec.decoders import SimpleVideoDecoder
1717

1818
from torchcodec.decoders._core import (
19-
add_video_stream,
19+
_add_video_stream,
2020
create_from_file,
2121
get_frames_at_indices,
2222
get_json_metadata,
@@ -86,38 +86,63 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
8686
class TVNewAPIDecoderWithBackend(AbstractDecoder):
    """Benchmark adapter around ``torchvision.io.VideoReader``.

    The backend name given at construction is passed to
    ``torchvision.set_video_backend`` before each reader is created.
    """

    def __init__(self, backend):
        # torchvision is imported here rather than at module import time, so
        # it is only required when this decoder is instantiated.
        import torchvision  # noqa: F401

        self.torchvision = torchvision
        self._backend = backend
        # Set to True to print a timing breakdown (milliseconds) per call.
        self._print_each_iteration_time = False

    def get_frames_from_video(self, video_file, pts_list):
        """Seek to every value in ``pts_list`` and collect the frame read after each seek.

        Returns a list with one entry per requested pts; each entry is the
        frame's ``"data"`` tensor after ``permute(1, 2, 0)``.
        """
        t_begin = timeit.default_timer()
        tv = self.torchvision
        tv.set_video_backend(self._backend)
        reader = tv.io.VideoReader(video_file, "video")
        t_reader_ready = timeit.default_timer()

        frames = []
        for pts in pts_list:
            reader.seek(pts)
            frames.append(next(reader)["data"].permute(1, 2, 0))
        t_frames_ready = timeit.default_timer()

        if not self._print_each_iteration_time:
            return frames

        # The names below are printed verbatim by the f-string "=" specifier.
        create_duration = 1000 * round(t_reader_ready - t_begin, 3)
        frames_duration = 1000 * round(t_frames_ready - t_reader_ready, 3)
        total_duration = 1000 * round(t_frames_ready - t_begin, 3)
        print(f"TV: {create_duration=} {frames_duration=} {total_duration=}")
        return frames

    def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
        """Read ``numFramesToDecode`` frames in sequence from a newly created reader.

        Returns the list of ``"data"`` tensors, each after ``permute(1, 2, 0)``.
        """
        t_begin = timeit.default_timer()
        tv = self.torchvision
        tv.set_video_backend(self._backend)
        reader = tv.io.VideoReader(video_file, "video")
        t_reader_ready = timeit.default_timer()

        frames = [
            next(reader)["data"].permute(1, 2, 0) for _ in range(numFramesToDecode)
        ]
        t_frames_ready = timeit.default_timer()

        if not self._print_each_iteration_time:
            return frames

        # The names below are printed verbatim by the f-string "=" specifier.
        create_duration = 1000 * round(t_reader_ready - t_begin, 3)
        frames_duration = 1000 * round(t_frames_ready - t_reader_ready, 3)
        total_duration = 1000 * round(t_frames_ready - t_begin, 3)
        print(
            f"TV: consecutive: {create_duration=} {frames_duration=} {total_duration=} {frames[0].shape=}"
        )
        return frames
111131

112132

113-
class TorchCodecDecoderNonCompiledWithOptions(AbstractDecoder):
114-
def __init__(self, num_threads=None):
133+
class TorchcodecNonCompiledWithOptions(AbstractDecoder):
134+
def __init__(self, num_threads=None, color_conversion_library=None):
115135
self._print_each_iteration_time = False
116-
self._num_threads = num_threads
136+
self._num_threads = int(num_threads) if num_threads else None
137+
self._color_conversion_library = color_conversion_library
117138

118139
def get_frames_from_video(self, video_file, pts_list):
119140
decoder = create_from_file(video_file)
120-
add_video_stream(decoder, num_threads=self._num_threads)
141+
_add_video_stream(
142+
decoder,
143+
num_threads=self._num_threads,
144+
color_conversion_library=self._color_conversion_library,
145+
)
121146
frames = []
122147
times = []
123148
for pts in pts_list:
@@ -127,35 +152,57 @@ def get_frames_from_video(self, video_file, pts_list):
127152
end = timeit.default_timer()
128153
times.append(round(end - start, 3))
129154
frames.append(frame)
155+
130156
if self._print_each_iteration_time:
131157
print("torchcodec times=", times, sum(times))
132158
return frames
133159

134160
def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
    """Decode ``numFramesToDecode`` frames by calling ``get_next_frame`` repeatedly.

    A new decoder is created from ``video_file`` and a video stream is added
    with this instance's ``num_threads`` and ``color_conversion_library``
    settings. Returns the list of values produced by ``get_next_frame``.
    """
    t_start = timeit.default_timer()
    decoder = create_from_file(video_file)
    t_decoder_created = timeit.default_timer()
    _add_video_stream(
        decoder,
        num_threads=self._num_threads,
        color_conversion_library=self._color_conversion_library,
    )
    frames = []
    times = []
    t_stream_added = timeit.default_timer()
    for _ in range(numFramesToDecode):
        tick = timeit.default_timer()
        frame = get_next_frame(decoder)
        tock = timeit.default_timer()
        # Per-frame decode time in seconds, rounded to the millisecond.
        times.append(round(tock - tick, 3))
        frames.append(frame)

    if not self._print_each_iteration_time:
        return frames

    t_done = timeit.default_timer()
    # The names below are printed verbatim by the f-string "=" specifier.
    create_duration = 1000 * round(t_decoder_created - t_start, 3)
    add_stream_duration = 1000 * round(t_stream_added - t_decoder_created, 3)
    frames_duration = 1000 * round(t_done - t_stream_added, 3)
    total_duration = 1000 * round(t_done - t_start, 3)
    print(
        f"{numFramesToDecode=} {create_duration=} {add_stream_duration=} {frames_duration=} {total_duration=} {frames[0][0].shape=}"
    )
    print("torchcodec times=", times, sum(times))
    return frames
148190

149191

150-
class TorchCodecDecoderNonCompiledBatch(AbstractDecoder):
151-
def __init__(self, num_threads=None):
192+
class TorchCodecNonCompiledBatch(AbstractDecoder):
193+
def __init__(self, num_threads=None, color_conversion_library=None):
152194
self._print_each_iteration_time = False
153-
self._num_threads = num_threads
195+
self._num_threads = int(num_threads) if num_threads else None
196+
self._color_conversion_library = color_conversion_library
154197

155198
def get_frames_from_video(self, video_file, pts_list):
156199
decoder = create_from_file(video_file)
157200
scan_all_streams_to_update_metadata(decoder)
158-
add_video_stream(decoder, num_threads=self._num_threads)
201+
_add_video_stream(
202+
decoder,
203+
num_threads=self._num_threads,
204+
color_conversion_library=self._color_conversion_library,
205+
)
159206
metadata = json.loads(get_json_metadata(decoder))
160207
average_fps = metadata["averageFps"]
161208
best_video_stream = metadata["bestVideoStreamIndex"]
@@ -169,7 +216,11 @@ def get_frames_from_video(self, video_file, pts_list):
169216
def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
170217
decoder = create_from_file(video_file)
171218
scan_all_streams_to_update_metadata(decoder)
172-
add_video_stream(decoder, num_threads=self._num_threads)
219+
_add_video_stream(
220+
decoder,
221+
num_threads=self._num_threads,
222+
color_conversion_library=self._color_conversion_library,
223+
)
173224
metadata = json.loads(get_json_metadata(decoder))
174225
best_video_stream = metadata["bestVideoStreamIndex"]
175226
frames = []
@@ -191,13 +242,13 @@ def compiled_next(decoder):
191242
return get_next_frame(decoder)
192243

193244

194-
class TorchCodecDecoderCompiled(AbstractDecoder):
245+
class TorchcodecCompiled(AbstractDecoder):
195246
def __init__(self):
196247
pass
197248

198249
def get_frames_from_video(self, video_file, pts_list):
199250
decoder = create_from_file(video_file)
200-
add_video_stream(decoder)
251+
_add_video_stream(decoder)
201252
frames = []
202253
for pts in pts_list:
203254
frame = compiled_seek_and_next(decoder, pts)
@@ -206,7 +257,7 @@ def get_frames_from_video(self, video_file, pts_list):
206257

207258
def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
208259
decoder = create_from_file(video_file)
209-
add_video_stream(decoder)
260+
_add_video_stream(decoder)
210261
frames = []
211262
for _ in range(numFramesToDecode):
212263
frame = compiled_next(decoder)
@@ -259,7 +310,7 @@ def get_test_resource_path(filename: str) -> str:
259310

260311
def create_torchcodec_decoder_from_file(video_file):
    """Create a decoder for ``video_file``, add a video stream, and decode one frame.

    The decoded frame is discarded; the decoder itself is returned.
    """
    decoder = create_from_file(video_file)
    _add_video_stream(decoder)
    # One get_next_frame call is made and its result dropped, so a call to this
    # function covers creation plus a first decode ("create()+next()").
    get_next_frame(decoder)
    return decoder
265316

@@ -294,9 +345,13 @@ def main() -> None:
294345
)
295346
parser.add_argument(
296347
"--decoders",
297-
help="Comma-separated list of decoders to benchmark. Choices are torchcodec, torchaudio, torchvision, decord, torchcodec1, torchcodec_compiled. torchcodec1 means torchcodec with num_threads=1. torchcodec_compiled means torch.compiled torchcodec. torchcodec_batch means torchcodec using batch methods.",
348+
help=(
349+
"Comma-separated list of decoders to benchmark. "
350+
"Choices are torchcodec, torchaudio, torchvision, decord, tcoptions:num_threads=1+color_conversion_library=filtergraph, torchcodec_compiled"
351+
"For torchcodec, you can specify options with tcoptions:<plus-separated-options>. "
352+
),
298353
type=str,
299-
default="decord,torchcodec,torchvision,torchaudio,torchcodec1,torchcodec_compiled,torchcodec_batch",
354+
default="decord,tcoptions:,torchvision,torchaudio,torchcodec_compiled,tcoptions:num_threads=1",
300355
)
301356

302357
args = parser.parse_args()
@@ -306,38 +361,51 @@ def main() -> None:
306361
num_uniform_samples = 10
307362

308363
decoder_dict = {}
309-
if "decord" in decoders:
310-
decoder_dict["DecordNonBatchDecoderAccurateSeek"] = (
311-
DecordNonBatchDecoderAccurateSeek()
312-
)
313-
if "torchcodec" in decoders:
314-
decoder_dict["TorchCodecDecoderNonCompiled"] = (
315-
TorchCodecDecoderNonCompiledWithOptions()
316-
)
317-
if "torchcodec_compiled" in decoders:
318-
decoder_dict["TorchCodecDecoderCompiled"] = TorchCodecDecoderCompiled()
319-
if "torchcodec1" in decoders:
320-
decoder_dict["TCNonCompiled:ffmpeg_thread_count=1"] = (
321-
TorchCodecDecoderNonCompiledWithOptions(num_threads=1)
322-
)
323-
# We don't compare TorchVision's "pyav" backend because it doesn't support
324-
# accurate seeks.
325-
if "torchvision" in decoders:
326-
decoder_dict["TVNewAPIDecoderWithBackendVideoReader"] = (
327-
TVNewAPIDecoderWithBackend("video_reader")
328-
)
329-
if "torchaudio" in decoders:
330-
decoder_dict["TorchAudioDecoder"] = TorchAudioDecoder()
331-
if "torchcodec_batch" in decoders:
332-
decoder_dict["TorchCodecDecoderNonCompiledBatch"] = (
333-
TorchCodecDecoderNonCompiledBatch()
334-
)
335-
336-
decoder_dict["TVNewAPIDecoderWithBackendVideoReader"]
364+
for decoder in decoders:
365+
if decoder == "decord":
366+
decoder_dict["DecordNonBatchDecoderAccurateSeek"] = (
367+
DecordNonBatchDecoderAccurateSeek()
368+
)
369+
elif decoder == "torchcodec":
370+
decoder_dict["TorchCodecNonCompiled"] = TorchcodecNonCompiledWithOptions()
371+
elif decoder == "torchcodec_compiled":
372+
decoder_dict["TorchcodecCompiled"] = TorchcodecCompiled()
373+
elif decoder == "torchvision":
374+
decoder_dict["TVNewAPIDecoderWithBackendVideoReader"] = (
375+
# We don't compare TorchVision's "pyav" backend because it doesn't support
376+
# accurate seeks.
377+
TVNewAPIDecoderWithBackend("video_reader")
378+
)
379+
elif decoder == "torchaudio":
380+
decoder_dict["TorchAudioDecoder"] = TorchAudioDecoder()
381+
elif decoder.startswith("tcbatchoptions:"):
382+
options = decoder[len("tcbatchoptions:") :]
383+
kwargs_dict = {}
384+
for item in options.split("+"):
385+
if item.strip() == "":
386+
continue
387+
k, v = item.split("=")
388+
kwargs_dict[k] = v
389+
decoder_dict["TorchCodecNonCompiledBatch:" + options] = (
390+
TorchCodecNonCompiledBatch(**kwargs_dict)
391+
)
392+
elif decoder.startswith("tcoptions:"):
393+
options = decoder[len("tcoptions:") :]
394+
kwargs_dict = {}
395+
for item in options.split("+"):
396+
if item.strip() == "":
397+
continue
398+
k, v = item.split("=")
399+
kwargs_dict[k] = v
400+
decoder_dict["TorchcodecNonCompiled:" + options] = (
401+
TorchcodecNonCompiledWithOptions(**kwargs_dict)
402+
)
337403

338404
results = []
339405
for decoder_name, decoder in decoder_dict.items():
340406
for video_path in args.bm_video_paths.split(","):
407+
# We only use the SimpleVideoDecoder to get the metadata and get
408+
# the list of PTS values to seek to.
341409
simple_decoder = SimpleVideoDecoder(video_path)
342410
duration = simple_decoder.metadata.duration_seconds
343411
pts_list = [
@@ -365,7 +433,7 @@ def main() -> None:
365433
min_run_time=args.bm_video_speed_min_run_seconds
366434
)
367435
)
368-
for num_consecutive_nexts in [1, 10, 100]:
436+
for num_consecutive_nexts in [1, 10]:
369437
consecutive_frames_result = benchmark.Timer(
370438
stmt="decoder.get_consecutive_frames_from_video(video_file, consecutive_frames_to_extract)",
371439
globals={
@@ -385,17 +453,24 @@ def main() -> None:
385453

386454
first_video_path = args.bm_video_paths.split(",")[0]
387455
if args.bm_video_creation:
456+
simple_decoder = SimpleVideoDecoder(first_video_path)
457+
metadata = simple_decoder.metadata
458+
metadata_string = f"{metadata.codec} {metadata.width}x{metadata.height}, {metadata.duration_seconds}s {metadata.average_fps}fps"
388459
creation_result = benchmark.Timer(
389460
stmt="create_torchcodec_decoder_from_file(video_file)",
390461
globals={
391462
"video_file": first_video_path,
392463
"create_torchcodec_decoder_from_file": create_torchcodec_decoder_from_file,
393464
},
394465
label=f"video={first_video_path} {metadata_string}",
395-
sub_label="TorchCodecDecoderNonCompiled",
466+
sub_label="TorchcodecNonCompiled",
396467
description="create()+next()",
397468
)
398-
results.append(creation_result.blocked_autorange(min_run_time=10.0))
469+
results.append(
470+
creation_result.blocked_autorange(
471+
min_run_time=2.0,
472+
)
473+
)
399474
compare = benchmark.Compare(results)
400475
compare.print()
401476

src/torchcodec/decoders/_core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ else()
8585
libavformat
8686
libavcodec
8787
libavutil
88+
libswscale
8889
)
8990

9091
# Split libavcodec's version string by '.' and convert it to a list

src/torchcodec/decoders/_core/FFMPEGCommon.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ extern "C" {
2222
#include <libavutil/opt.h>
2323
#include <libavutil/pixfmt.h>
2424
#include <libavutil/version.h>
25+
#include <libswscale/swscale.h>
2526
}
2627

2728
namespace facebook::torchcodec {
@@ -38,6 +39,15 @@ struct Deleterp {
3839
}
3940
};
4041

42+
// Stateless deleter functor that adapts a free function taking the pointer by
// value (`R Fn(T*)`) for use as a std::unique_ptr deleter. A null pointer is
// ignored; otherwise Fn is called with the pointer and its result is dropped.
template <typename T, typename R, R (*Fn)(T*)>
struct Deleter {
  inline void operator()(T* p) const {
    if (p == nullptr) {
      return;
    }
    Fn(p);
  }
};
50+
4151
// Unique pointers for FFMPEG structures.
4252
using UniqueAVFormatContext = std::unique_ptr<
4353
AVFormatContext,
@@ -57,6 +67,8 @@ using UniqueAVFilterInOut = std::unique_ptr<
5767
Deleterp<AVFilterInOut, void, avfilter_inout_free>>;
5868
using UniqueAVIOContext = std::
5969
unique_ptr<AVIOContext, Deleterp<AVIOContext, void, avio_context_free>>;
70+
using UniqueSwsContext =
71+
std::unique_ptr<SwsContext, Deleter<SwsContext, void, sws_freeContext>>;
6072

6173
// av_find_best_stream is not const-correct before commit:
6274
// https://github.com/FFmpeg/FFmpeg/commit/46dac8cf3d250184ab4247809bc03f60e14f4c0c

0 commit comments

Comments
 (0)