Add video GPU decoder #5019

Merged
merged 42 commits into from
Dec 30, 2021
e25c169
[WIP] Add video GPU decoder
prabhat00155 Dec 1, 2021
58dd405
Expose use_dev_frame to python class and handle it internally
prabhat00155 Dec 7, 2021
7bdf86b
Fixed invalid argument CUDA error
prabhat00155 Dec 8, 2021
86e373e
Fixed empty and missing frames
prabhat00155 Dec 13, 2021
717ee01
Free remaining frames in the queue
prabhat00155 Dec 14, 2021
75f8f48
Added nv12 to yuv420 conversion support for host frames
prabhat00155 Dec 14, 2021
a94cc75
Merge branch 'master' into prabhat00155/gpu_decoder
prabhat00155 Dec 14, 2021
34b8205
Added unit test and cleaned up code
prabhat00155 Dec 15, 2021
6672a05
Use CUDA_HOME inside if
prabhat00155 Dec 15, 2021
30af8ac
Undo commented out code
prabhat00155 Dec 15, 2021
4b9bab8
Add Readme
prabhat00155 Dec 15, 2021
5afb6dd
Remove output_format and use_device_frame optional arguments from the…
prabhat00155 Dec 16, 2021
e8ae42e
Cleaned up init()
prabhat00155 Dec 17, 2021
6433785
Fix warnings
prabhat00155 Dec 17, 2021
962962a
Fix python linter errors
prabhat00155 Dec 17, 2021
4dd1798
Fix linter issues in setup.py
prabhat00155 Dec 17, 2021
116fd02
clang-format
prabhat00155 Dec 17, 2021
5a08055
Make reformat private
prabhat00155 Dec 19, 2021
fd30a89
Member function naming
prabhat00155 Dec 19, 2021
87ed21e
Add comments
prabhat00155 Dec 19, 2021
f6b6cfe
Variable renaming
prabhat00155 Dec 19, 2021
3e309a5
Code cleanup
prabhat00155 Dec 20, 2021
d116edc
Make return type of decode() void
prabhat00155 Dec 20, 2021
c01b66b
Replace printing errors with throwing runtime_error
prabhat00155 Dec 21, 2021
7e0d884
Replaced runtime_error with TORCH_CHECK in demuxer.h
prabhat00155 Dec 22, 2021
901501c
Use CUDAGuard instead of cudaSetDevice
prabhat00155 Dec 22, 2021
2365906
Remove printf
prabhat00155 Dec 22, 2021
c77b558
Use Tensor instead of uint8* and remove cuMemAlloc/cuMemFree
prabhat00155 Dec 22, 2021
e8c80ed
Use TORCH_CHECK instead of runtime_error
prabhat00155 Dec 23, 2021
9d78ce5
Use TORCHVISION_INCLUDE and TORCHVISION_LIBRARY to pass video codec l…
prabhat00155 Dec 23, 2021
f733a97
Include ffmpeg_include_dir
prabhat00155 Dec 23, 2021
d794cb1
Remove space
prabhat00155 Dec 23, 2021
53a20b2
Removed use of runtime_error
prabhat00155 Dec 23, 2021
7ca13b7
Update Readme
prabhat00155 Dec 24, 2021
83ac2b1
Check for bsf.h
prabhat00155 Dec 24, 2021
caf45fd
Fixed merge conflicts
prabhat00155 Dec 24, 2021
d69f820
Change struct initialisation style
prabhat00155 Dec 24, 2021
5c5162e
Clean-up get_operating_point
prabhat00155 Dec 24, 2021
83d84b0
Make variable naming convention uniform
prabhat00155 Dec 24, 2021
d8d0fb5
Move checking for bsf.h around
prabhat00155 Dec 24, 2021
559639e
Fix linter error
prabhat00155 Dec 24, 2021
de8bfbd
Merge branch 'main' into prabhat00155/gpu_decoder
fmassa Dec 30, 2021
53 changes: 53 additions & 0 deletions setup.py
@@ -427,6 +427,59 @@ def get_extensions():
)
)

    # Locating video codec
    # CUDA_HOME should be set to the cuda root directory.
    # TORCHVISION_INCLUDE and TORCHVISION_LIBRARY should include the location to
    # video codec header files and libraries respectively.
    video_codec_found = (
        extension is CUDAExtension
        and CUDA_HOME is not None
        and any([os.path.exists(os.path.join(folder, "cuviddec.h")) for folder in vision_include])
        and any([os.path.exists(os.path.join(folder, "nvcuvid.h")) for folder in vision_include])
        and any([os.path.exists(os.path.join(folder, "libnvcuvid.so")) for folder in library_dirs])
    )

    print(f"video codec found: {video_codec_found}")

    if (
        video_codec_found
        and has_ffmpeg
        and any([os.path.exists(os.path.join(folder, "libavcodec", "bsf.h")) for folder in ffmpeg_include_dir])
    ):
        gpu_decoder_path = os.path.join(extensions_dir, "io", "decoder", "gpu")
        gpu_decoder_src = glob.glob(os.path.join(gpu_decoder_path, "*.cpp"))
        cuda_libs = os.path.join(CUDA_HOME, "lib64")
        cuda_inc = os.path.join(CUDA_HOME, "include")

        ext_modules.append(
            extension(
                "torchvision.Decoder",
                gpu_decoder_src,
                include_dirs=include_dirs + [gpu_decoder_path] + [cuda_inc] + ffmpeg_include_dir,
                library_dirs=ffmpeg_library_dir + library_dirs + [cuda_libs],
                libraries=[
                    "avcodec",
                    "avformat",
                    "avutil",
                    "swresample",
                    "swscale",
                    "nvcuvid",
                    "cuda",
                    "cudart",
                    "z",
                    "pthread",
                    "dl",
                ],
                extra_compile_args=extra_compile_args,
            )
        )
    else:
        print(
            "The installed version of ffmpeg is missing the header file 'bsf.h' which is "
            "required for GPU video decoding. Please install the latest ffmpeg from conda-forge channel:"
            " `conda install -c conda-forge ffmpeg`."
        )

    return ext_modules
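
Note on the hunk above: the `else` branch prints the `bsf.h` message whenever any prerequisite fails, including a missing `CUDA_HOME` or a missing `libnvcuvid.so`. The following is a minimal sketch, not code from this PR, of how the same checks could report the first unmet prerequisite instead. The helper names `first_dir_containing` and `gpu_decoder_skip_reason` and their parameters are hypothetical; only the file names and the order of the checks are taken from the diff.

```python
import os
import tempfile


def first_dir_containing(folders, *relative_path):
    """Return the first folder that contains relative_path, or None."""
    for folder in folders:
        if os.path.exists(os.path.join(folder, *relative_path)):
            return folder
    return None


def gpu_decoder_skip_reason(
    cuda_home, include_dirs, library_dirs, has_ffmpeg, ffmpeg_include_dirs, is_cuda_build=True
):
    """Hypothetical helper: None if all prerequisites hold, else the first gap."""
    if not is_cuda_build:
        return "extension is not a CUDA build"
    if cuda_home is None:
        return "CUDA_HOME is not set"
    # Same header and library names that the setup.py hunk looks for.
    for header in ("cuviddec.h", "nvcuvid.h"):
        if first_dir_containing(include_dirs, header) is None:
            return f"{header} not found in the include directories"
    if first_dir_containing(library_dirs, "libnvcuvid.so") is None:
        return "libnvcuvid.so not found in the library directories"
    if not has_ffmpeg:
        return "ffmpeg was not found"
    if first_dir_containing(ffmpeg_include_dirs, "libavcodec", "bsf.h") is None:
        return "libavcodec/bsf.h not found in the ffmpeg include directories"
    return None


# Demo against a throwaway directory tree, so no real SDK is needed.
with tempfile.TemporaryDirectory() as root:
    for name in ("cuviddec.h", "nvcuvid.h", "libnvcuvid.so"):
        open(os.path.join(root, name), "w").close()
    # bsf.h is still missing at this point.
    missing_bsf = gpu_decoder_skip_reason("/usr/local/cuda", [root], [root], True, [root])
    os.makedirs(os.path.join(root, "libavcodec"))
    open(os.path.join(root, "libavcodec", "bsf.h"), "w").close()
    all_present = gpu_decoder_skip_reason("/usr/local/cuda", [root], [root], True, [root])
    no_cuda_home = gpu_decoder_skip_reason(None, [root], [root], True, [root])
```

In a build script, a non-`None` result could be printed in place of the fixed `bsf.h` message, so the user sees which README step was missed.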


41 changes: 41 additions & 0 deletions test/test_video_gpu_decoder.py
@@ -0,0 +1,41 @@
import os

import pytest
import torch
from torchvision.io import _HAS_VIDEO_DECODER, VideoReader

try:
    import av
except ImportError:
    av = None

VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")

test_videos = [
    "RATRACE_wave_f_nm_np1_fr_goo_37.avi",
    "TrumanShow_wave_f_nm_np1_fr_med_26.avi",
    "v_SoccerJuggling_g23_c01.avi",
    "v_SoccerJuggling_g24_c01.avi",
    "R6llTwEh07w.mp4",
    "SOX5yA1l24A.mp4",
    "WUzgd7C1pWA.mp4",
]


@pytest.mark.skipif(_HAS_VIDEO_DECODER is False, reason="Didn't compile with support for gpu decoder")
class TestVideoGPUDecoder:
    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
    def test_frame_reading(self):
        for test_video in test_videos:
            full_path = os.path.join(VIDEO_DIR, test_video)
            decoder = VideoReader(full_path, device="cuda:0")
            with av.open(full_path) as container:
                for av_frame in container.decode(container.streams.video[0]):
                    av_frames = torch.tensor(av_frame.to_ndarray().flatten())
                    vision_frames = next(decoder)["data"]
                    mean_delta = torch.mean(torch.abs(av_frames.float() - decoder._reformat(vision_frames).float()))
                    assert mean_delta < 0.1


if __name__ == "__main__":
    pytest.main([__file__])
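
Note on the test above: it compares each flattened PyAV frame with the reformatted GPU frame by mean absolute difference. The sketch below illustrates that comparison in plain Python, under the assumption that `_reformat` performs the NV12 to YUV420 conversion named in commit 75f8f48. The PR's actual conversion code is not shown on this page, so `nv12_to_i420` and `mean_abs_delta` are hypothetical names used only to show the byte layout change.

```python
def nv12_to_i420(frame: bytes, width: int, height: int) -> bytes:
    """Rearrange one NV12 frame into planar YUV 4:2:0 (I420) order."""
    if width % 2 or height % 2:
        raise ValueError("width and height must be even for 4:2:0 subsampling")
    luma_size = width * height
    expected_size = luma_size * 3 // 2
    if len(frame) != expected_size:
        raise ValueError(f"expected {expected_size} bytes, got {len(frame)}")
    luma = frame[:luma_size]
    chroma = frame[luma_size:]
    # NV12 interleaves chroma as U0 V0 U1 V1 ...; I420 stores all U, then all V.
    return luma + chroma[0::2] + chroma[1::2]


def mean_abs_delta(a, b) -> float:
    """Mean absolute difference of two equal-length, non-empty byte sequences."""
    if len(a) != len(b) or len(a) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


# A 4x2 frame: 8 luma bytes followed by two interleaved (U, V) pairs.
nv12_frame = bytes([0] * 8 + [1, 2, 3, 4])
i420_frame = nv12_to_i420(nv12_frame, width=4, height=2)
reference = bytes([0] * 8 + [1, 3, 2, 4])
delta = mean_abs_delta(i420_frame, reference)
```

A threshold on the mean, as in the test's `mean_delta < 0.1`, tolerates small per-pixel differences between two decoders while still catching a wrong plane order.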
21 changes: 21 additions & 0 deletions torchvision/csrc/io/decoder/gpu/README.rst
@@ -0,0 +1,21 @@
GPU Decoder
===========

The GPU decoder depends on ffmpeg for demuxing, uses the NVDECODE APIs from the nvidia-video-codec SDK for decoding, and uses CUDA for processing on the GPU. To use it, follow these steps:

* Download the latest `nvidia-video-codec-sdk <https://developer.nvidia.com/nvidia-video-codec-sdk/download>`_.
* Extract the zipped file.
* Set the TORCHVISION_INCLUDE environment variable to the location of the video codec headers (`nvcuvid.h` and `cuviddec.h`), which are under the `Interface` directory.
* Set the TORCHVISION_LIBRARY environment variable to the location of the video codec library (`libnvcuvid.so`), which is under the `Lib/linux/stubs/x86_64` directory.
* Install the latest ffmpeg from the `conda-forge` channel:

.. code:: bash

     conda install -c conda-forge ffmpeg

* Set the CUDA_HOME environment variable to the CUDA root directory.
* Build torchvision from source:

.. code:: bash

     python setup.py install
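
Once the build succeeds, frames can be read the way the unit test in this PR does it: construct `VideoReader` with `device="cuda:0"` and take the `"data"` entry of each item. The sketch below is a usage illustration, not part of the README. `take_frames` and `read_on_gpu` are hypothetical helper names, and `read_on_gpu` assumes a torchvision build that includes the GPU decoder.

```python
from itertools import islice


def take_frames(reader, max_frames):
    """Collect the "data" entry of up to max_frames items from a reader."""
    return [item["data"] for item in islice(reader, max_frames)]


def read_on_gpu(path, max_frames=5):
    """Decode the first frames of a video on the GPU.

    Assumes torchvision was built with the GPU decoder; the import is kept
    local so the helper above stays usable without such a build.
    """
    from torchvision.io import VideoReader

    return take_frames(VideoReader(path, device="cuda:0"), max_frames)


# take_frames works on any iterator of dicts, which makes it easy to try
# without a GPU.
fake_reader = iter([{"data": "frame0"}, {"data": "frame1"}, {"data": "frame2"}])
first_two = take_frames(fake_reader, 2)
```

Keeping the iteration logic separate from the `VideoReader` construction lets the frame-collection step be exercised on machines without CUDA.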