
Adapt NPP calls for CUDA >= 12.9 #757

Open · wants to merge 4 commits into base: main
4 changes: 3 additions & 1 deletion .github/workflows/linux_cuda_wheel.yaml
@@ -67,7 +67,9 @@ jobs:
  # For the actual release we should add that label and change this to
  # include more python versions.
  python-version: ['3.9']
- cuda-version: ['12.6', '12.8']
+ # We test against 12.6 and 12.9 to avoid having too big of a CI matrix,
+ # but for releases we should add 12.8.
+ cuda-version: ['12.6', '12.9']
  # TODO: put back ffmpeg 5 https://github.com/pytorch/torchcodec/issues/325
  ffmpeg-version-for-tests: ['4.4.2', '6', '7']

57 changes: 41 additions & 16 deletions src/torchcodec/_core/CudaDeviceInterface.cpp
@@ -224,41 +224,66 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
// Use the user-requested GPU for running the NPP kernel.
c10::cuda::CUDAGuard deviceGuard(device_);

cudaStream_t rawStream = at::cuda::getCurrentCUDAStream().stream();

// Build an NppStreamContext, either via the old helper or by hand on
// CUDA 12.9+
NppStreamContext nppCtx{};
#if CUDA_VERSION < 12090
NppStatus ctxStat = nppGetStreamContext(&nppCtx);
TORCH_CHECK(ctxStat == NPP_SUCCESS, "nppGetStreamContext failed");
// override if you want to force a particular stream
nppCtx.hStream = rawStream;
#else
// CUDA 12.9+: helper was removed, we need to build it manually
int dev = 0;
cudaError_t err = cudaGetDevice(&dev);
Comment on lines +239 to +240 (Member):
Hi @Kh4L, I have some questions before moving forward:

  1. Should we just rely on the existing device_ attribute instead of calling cudaGetDevice(&dev), or are they actually equivalent?
  2. Would it make sense to cache the nppCtx across calls? In this PR it looks like we're creating the context over and over for every single frame that needs to be decoded. I wonder if it might be beneficial to cache it in the class and re-use it?

Thanks for your help so far, I'm still trying to build familiarity with that part of the code base.
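For context on question 2, a lazy per-device cache could look roughly like the sketch below. `CtxCache`, `StreamCtx`, and `initCount` are hypothetical stand-ins introduced for illustration only — the real `NppStreamContext` and its `cudaGetDeviceProperties`-derived fields are not assumed available here:

```cpp
#include <optional>

// Stand-in for NppStreamContext: only the fields needed to show the idea.
struct StreamCtx {
  int deviceId = -1;
  void* stream = nullptr;  // real code: cudaStream_t
};

// Hypothetical cache: the expensive device-property lookup happens once per
// device, while the (cheap) stream handle is rebound on every call.
class CtxCache {
 public:
  StreamCtx& get(int deviceId, void* stream) {
    if (!ctx_ || ctx_->deviceId != deviceId) {
      StreamCtx fresh;
      fresh.deviceId = deviceId;  // real code: cudaGetDeviceProperties(...)
      ctx_ = fresh;
      ++initCount;  // exposed only so the sketch is easy to check
    }
    ctx_->stream = stream;  // streams can change per call; rebinding is cheap
    return *ctx_;
  }
  int initCount = 0;

 private:
  std::optional<StreamCtx> ctx_;  // empty until the first frame is decoded
};
```

The design point is that the property query is per-device state while the stream is per-call state, so only the former needs caching.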

TORCH_CHECK(err == cudaSuccess, "cudaGetDevice failed");
cudaDeviceProp prop{};
err = cudaGetDeviceProperties(&prop, dev);
TORCH_CHECK(err == cudaSuccess, "cudaGetDeviceProperties failed");

nppCtx.nCudaDeviceId = dev;
nppCtx.nMultiProcessorCount = prop.multiProcessorCount;
nppCtx.nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
nppCtx.nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
nppCtx.nSharedMemPerBlock = prop.sharedMemPerBlock;
nppCtx.nCudaDevAttrComputeCapabilityMajor = prop.major;
nppCtx.nCudaDevAttrComputeCapabilityMinor = prop.minor;
nppCtx.nStreamFlags = 0;
nppCtx.hStream = rawStream;
#endif

// Prepare ROI + pointers
NppiSize oSizeROI = {width, height};
Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};

auto start = std::chrono::high_resolution_clock::now();
NppStatus status;

if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
- status = nppiNV12ToRGB_709CSC_8u_P2C3R(
+ status = nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx(
      input,
      avFrame->linesize[0],
      static_cast<Npp8u*>(dst.data_ptr()),
      dst.stride(0),
-     oSizeROI);
+     oSizeROI,
+     nppCtx);
} else {
- status = nppiNV12ToRGB_8u_P2C3R(
+ status = nppiNV12ToRGB_8u_P2C3R_Ctx(
      input,
      avFrame->linesize[0],
      static_cast<Npp8u*>(dst.data_ptr()),
      dst.stride(0),
-     oSizeROI);
+     oSizeROI,
+     nppCtx);
}
TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");

// Make the pytorch stream wait for the npp kernel to finish before using the
// output.
at::cuda::CUDAEvent nppDoneEvent;
at::cuda::CUDAStream nppStreamWrapper =
c10::cuda::getStreamFromExternal(nppGetStream(), device_.index());
nppDoneEvent.record(nppStreamWrapper);
nppDoneEvent.block(at::cuda::getCurrentCUDAStream());
Member:

@Kh4L Can you confirm my understanding that the nppiNV12ToRGB_8u_P2C3R_Ctx call will properly wait on the stream before returning?

Author:

Thank you!

> Do I understand correctly that in 12.9 we have to use the context-based API, while at the same time the context-creation helper was removed?! This sounds error-prone; is there any way we could avoid manually building and setting the context attributes?

Unfortunately, there is currently no alternative way to build this. Since I'm not part of the NPP team, I can't comment on their design choices.

> @Kh4L Can you confirm my understanding that the nppiNV12ToRGB_8u_P2C3R_Ctx call will properly wait on the stream before returning?

That's correct. We bind the NPP context to the active CUDA stream so we can leverage CUDA stream management rather than performing a blocking sync.

https://github.com/pytorch/torchcodec/pull/757/files#diff-37d8a09669d3f009b6850f6e66888b6875d805064933148fce3a637cc7694712R254
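The ordering behind that event-based pattern (record on the producer stream, then make the consumer stream wait) can be sketched with stand-in types. `Stream`, `Event`, and `syncOrder` below are illustrative only — real code uses `at::cuda::CUDAEvent`, `c10::cuda::getStreamFromExternal`, and `at::cuda::getCurrentCUDAStream`, none of which are assumed available here:

```cpp
#include <string>
#include <vector>

// Stand-ins for CUDA streams and events, logging actions so the ordering is
// visible without a GPU.
struct Stream {
  std::string name;
  std::vector<std::string>* log;
};

struct Event {
  // record() marks a point on the producer stream (here: after the NPP kernel).
  void record(Stream& s) { s.log->push_back("record:" + s.name); }
  // block() makes the consumer stream wait for that point without a host-side
  // blocking sync (the cudaStreamWaitEvent idea).
  void block(Stream& s) { s.log->push_back("block:" + s.name); }
};

// Mirrors the sequence in the diff: record on the NPP stream, then make the
// current PyTorch stream wait before it consumes the converted frame.
inline std::vector<std::string> syncOrder() {
  std::vector<std::string> log;
  Stream npp{"npp", &log};
  Stream torchStream{"torch", &log};
  Event done;
  done.record(npp);
  done.block(torchStream);
  return log;
}
```

With the context-based API bound to the current stream, this cross-stream handshake is no longer needed, which is the point the author makes above.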


auto end = std::chrono::high_resolution_clock::now();

- std::chrono::duration<double, std::micro> duration = end - start;
- VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width
-         << " took: " << duration.count() << "us" << std::endl;
+ auto duration = std::chrono::duration<double, std::micro>(end - start);
+ VLOG(9) << "NPP Conversion of frame h=" << height << " w=" << width
+         << " took: " << duration.count() << "us";
}

// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9