Skip to content

Commit 8fee167

Browse files
ahmadsharif1 authored and facebook-github-bot committed
[torchcodec] Add support for Nvidia GPU Decoding (#137)
Summary: Pull Request resolved: #137 Pull Request resolved: #58 X-link: #58 1. Add CUDA support to VideoDecoder.cpp. This is done by checking what device is passed into the options and using CUDA if the device type is cuda. 2. Add -DENABLE_CUDA flag in cmake. 3. Check ENABLE_CUDA environment variable in setup.py and pass it down to cmake if it is present. 4. Add a unit test to demonstrate that CUDA decoding does work. This uses a different tensor than the one from CPU decoding because hardware decoding is intrinsically a bit inaccurate. I generated the reference tensor by dumping the tensor from the GPU on my devVM. It is possible that different Nvidia hardware shows different outputs. How to test this in a more robust way is TBD. 5. Added a new parameter for cuda device index for `add_video_stream`. If this is present, we will use it to do hardware decoding on a CUDA device. There is a whole bunch of TODOs: 1. Currently GPU utilization is only 7-8% when decoding the video. We need to get this higher. 2. Speed it up compared to the CPU implementation. Currently this is slower than CPU decoding even for HD videos (probably because we can't hide the CPU to GPU memcpy). However, decode+resize is faster, as the benchmark says. Reviewed By: scotts Differential Revision: D59121006 fbshipit-source-id: da6faa60c8de5d8e6ad90f8897d339c9979005f1
1 parent 927e73c commit 8fee167

18 files changed

+403
-26
lines changed

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
cmake_minimum_required(VERSION 3.18)
22
project(TorchCodec)
33

4+
option(ENABLE_CUDA "Enable CUDA decoding using NVDEC" OFF)
5+
option(ENABLE_NVTX "Enable NVTX annotations for profiling" OFF)
6+
47
add_subdirectory(src/torchcodec/decoders/_core)
58

69

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,7 @@ guide](CONTRIBUTING.md) for more details.
127127
## License
128128

129129
TorchCodec is released under the [BSD 3 license](./LICENSE).
130+
131+
132+
If you are building with ENABLE_CUDA and/or ENABLE_NVTX, please review
133+
[Nvidia licenses](https://docs.nvidia.com/cuda/eula/index.html).

benchmarks/decoders/BenchmarkDecodersMain.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ void runNDecodeIterationsWithCustomOps(
145145
/*height=*/std::nullopt,
146146
/*thread_count=*/std::nullopt,
147147
/*dimension_order=*/std::nullopt,
148-
/*stream_index=*/std::nullopt);
148+
/*stream_index=*/std::nullopt,
149+
/*device_string=*/std::nullopt);
149150

150151
for (double pts : ptsList) {
151152
seekFrameOp.call(decoderTensor, pts);

benchmarks/decoders/gpu_benchmark.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import argparse
2+
import os
3+
import time
4+
5+
import torch.utils.benchmark as benchmark
6+
7+
import torchcodec
8+
from torchvision.transforms import Resize
9+
10+
11+
def transfer_and_resize_frame(frame, device):
12+
# This should be a no-op if the frame is already on the device.
13+
frame = frame.to(device)
14+
frame = Resize((256, 256))(frame)
15+
return frame
16+
17+
18+
def decode_full_video(video_path, decode_device):
19+
decoder = torchcodec.decoders._core.create_from_file(video_path)
20+
num_threads = None
21+
if "cuda" in decode_device:
22+
num_threads = 1
23+
torchcodec.decoders._core.add_video_stream(
24+
decoder, stream_index=0, device_string=decode_device, num_threads=num_threads
25+
)
26+
start_time = time.time()
27+
frame_count = 0
28+
while True:
29+
try:
30+
frame, *_ = torchcodec.decoders._core.get_next_frame(decoder)
31+
# You can do a resize to simulate extra preproc work that happens
32+
# on the GPU by uncommenting the following line:
33+
# frame = transfer_and_resize_frame(frame, decode_device)
34+
35+
frame_count += 1
36+
except Exception as e:
37+
print("EXCEPTION", e)
38+
break
39+
# print(f"current {frame_count=}", flush=True)
40+
end_time = time.time()
41+
elapsed = end_time - start_time
42+
fps = frame_count / (end_time - start_time)
43+
print(
44+
f"****** DECODED full video {decode_device=} {frame_count=} {elapsed=} {fps=}"
45+
)
46+
return frame_count, end_time - start_time
47+
48+
49+
def main():
50+
parser = argparse.ArgumentParser()
51+
parser.add_argument(
52+
"--devices",
53+
default="cuda:0,cpu",
54+
type=str,
55+
help="Comma-separated devices to test decoding on.",
56+
)
57+
parser.add_argument(
58+
"--video",
59+
type=str,
60+
default=os.path.dirname(__file__) + "/../../test/resources/nasa_13013.mp4",
61+
)
62+
parser.add_argument(
63+
"--use_torch_benchmark",
64+
action=argparse.BooleanOptionalAction,
65+
default=True,
66+
help=(
67+
"Use pytorch benchmark to measure decode time with warmup and "
68+
"autorange. Without this we just run one iteration without warmup "
69+
"to measure the cold start time."
70+
),
71+
)
72+
args = parser.parse_args()
73+
video_path = args.video
74+
75+
if not args.use_torch_benchmark:
76+
for device in args.devices.split(","):
77+
print("Testing on", device)
78+
decode_full_video(video_path, device)
79+
return
80+
81+
results = []
82+
for device in args.devices.split(","):
83+
print("device", device)
84+
t = benchmark.Timer(
85+
stmt="decode_full_video(video_path, device)",
86+
globals={
87+
"device": device,
88+
"video_path": video_path,
89+
"decode_full_video": decode_full_video,
90+
},
91+
label="Decode+Resize Time",
92+
sub_label=f"video={os.path.basename(video_path)}",
93+
description=f"decode_device={device}",
94+
).blocked_autorange()
95+
results.append(t)
96+
compare = benchmark.Compare(results)
97+
compare.print()
98+
99+
100+
if __name__ == "__main__":
101+
main()

setup.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,16 @@ def _build_all_extensions_with_cmake(self):
112112
torch_dir = Path(torch.utils.cmake_prefix_path) / "Torch"
113113
cmake_build_type = os.environ.get("CMAKE_BUILD_TYPE", "Release")
114114
python_version = sys.version_info
115+
enable_cuda = os.environ.get("ENABLE_CUDA", "")
116+
enable_nvtx = os.environ.get("ENABLE_NVTX", "")
115117
cmake_args = [
116118
f"-DCMAKE_INSTALL_PREFIX={self._install_prefix}",
117119
f"-DTorch_DIR={torch_dir}",
118120
"-DCMAKE_VERBOSE_MAKEFILE=ON",
119121
f"-DCMAKE_BUILD_TYPE={cmake_build_type}",
120122
f"-DPYTHON_VERSION={python_version.major}.{python_version.minor}",
123+
f"-DENABLE_CUDA={enable_cuda}",
124+
f"-DENABLE_NVTX={enable_nvtx}",
121125
]
122126

123127
Path(self.build_temp).mkdir(parents=True, exist_ok=True)

src/torchcodec/decoders/_core/CMakeLists.txt

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,28 @@ set(CMAKE_CXX_STANDARD 17)
44
set(CMAKE_CXX_STANDARD_REQUIRED ON)
55

66
find_package(Torch REQUIRED)
7+
8+
if(ENABLE_CUDA)
9+
find_package(CUDA REQUIRED)
10+
11+
if(ENABLE_NVTX)
12+
# We only need CPM for NVTX:
13+
# https://github.com/NVIDIA/NVTX#cmake
14+
file(
15+
DOWNLOAD
16+
https://github.com/cpm-cmake/CPM.cmake/releases/download/v0.38.3/CPM.cmake
17+
${CMAKE_CURRENT_BINARY_DIR}/cmake/CPM.cmake
18+
EXPECTED_HASH SHA256=cc155ce02e7945e7b8967ddfaff0b050e958a723ef7aad3766d368940cb15494
19+
)
20+
include(${CMAKE_CURRENT_BINARY_DIR}/cmake/CPM.cmake)
21+
CPMAddPackage(
22+
NAME NVTX
23+
GITHUB_REPOSITORY NVIDIA/NVTX
24+
GIT_TAG v3.1.0-c-cpp
25+
GIT_SHALLOW TRUE)
26+
endif()
27+
endif()
28+
729
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
830
find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development)
931

@@ -19,6 +41,12 @@ function(make_torchcodec_library library_name ffmpeg_target)
1941
)
2042
add_library(${library_name} SHARED ${sources})
2143
set_property(TARGET ${library_name} PROPERTY CXX_STANDARD 17)
44+
if(ENABLE_CUDA)
45+
target_compile_definitions(${library_name} PRIVATE ENABLE_CUDA=1)
46+
endif()
47+
if(ENABLE_NVTX)
48+
target_compile_definitions(${library_name} PRIVATE ENABLE_NVTX=1)
49+
endif()
2250

2351
target_include_directories(
2452
${library_name}
@@ -28,12 +56,17 @@ function(make_torchcodec_library library_name ffmpeg_target)
2856
${Python3_INCLUDE_DIRS}
2957
)
3058

59+
set(NEEDED_LIBRARIES ${ffmpeg_target} ${TORCH_LIBRARIES} ${Python3_LIBRARIES})
60+
if(ENABLE_CUDA)
61+
list(APPEND NEEDED_LIBRARIES ${CUDA_CUDA_LIBRARY})
62+
endif()
63+
if(ENABLE_NVTX)
64+
list(APPEND NEEDED_LIBRARIES nvtx3-cpp)
65+
endif()
3166
target_link_libraries(
3267
${library_name}
3368
PUBLIC
34-
${ffmpeg_target}
35-
${TORCH_LIBRARIES}
36-
${Python3_LIBRARIES}
69+
${NEEDED_LIBRARIES}
3770
)
3871

3972
# We already set the library_name to be libtorchcodecN, so we don't want

src/torchcodec/decoders/_core/FFMPEGCommon.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ using UniqueAVFilterInOut = std::unique_ptr<
5757
Deleterp<AVFilterInOut, void, avfilter_inout_free>>;
5858
using UniqueAVIOContext = std::
5959
unique_ptr<AVIOContext, Deleterp<AVIOContext, void, avio_context_free>>;
60+
using UniqueAVBufferRef =
61+
std::unique_ptr<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>;
6062

6163
// av_find_best_stream is not const-correct before commit:
6264
// https://github.com/FFmpeg/FFmpeg/commit/46dac8cf3d250184ab4247809bc03f60e14f4c0c

0 commit comments

Comments
 (0)