Skip to content

Commit 44ae3d5

Browse files
authored
CPU fallback: do color-conversion on GPU. (#992)
1 parent 28b0346 commit 44ae3d5

File tree

7 files changed

+265
-112
lines changed

7 files changed

+265
-112
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 148 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,12 @@ bool nativeNVDECSupport(const SharedAVCodecContext& codecContext) {
213213
return true;
214214
}
215215

216+
// Callback for freeing CUDA memory associated with an AVFrame; see where it's used
217+
// for more details.
218+
void cudaBufferFreeCallback(void* opaque, [[maybe_unused]] uint8_t* data) {
219+
cudaFree(opaque);
220+
}
221+
216222
} // namespace
217223

218224
BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device)
@@ -668,38 +674,163 @@ void BetaCudaDeviceInterface::flush() {
668674
std::swap(readyFrames_, emptyQueue);
669675
}
670676

677+
UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12(
678+
UniqueAVFrame& cpuFrame) {
679+
// This is called in the context of the CPU fallback: the frame was decoded on
680+
// the CPU, and in this function we convert that frame into NV12 format and
681+
// send it to the GPU.
682+
// We do that in 2 steps:
683+
// - First we convert the input CPU frame into an intermediate NV12 CPU frame
684+
// using sws_scale.
685+
// - Then we allocate GPU memory and copy the NV12 CPU frame to the GPU. This
686+
// is what we return.
687+
688+
TORCH_CHECK(cpuFrame != nullptr, "CPU frame cannot be null");
689+
690+
int width = cpuFrame->width;
691+
int height = cpuFrame->height;
692+
693+
// Intermediate NV12 CPU frame. It's not on the GPU yet.
694+
UniqueAVFrame nv12CpuFrame(av_frame_alloc());
695+
TORCH_CHECK(nv12CpuFrame != nullptr, "Failed to allocate NV12 CPU frame");
696+
697+
nv12CpuFrame->format = AV_PIX_FMT_NV12;
698+
nv12CpuFrame->width = width;
699+
nv12CpuFrame->height = height;
700+
701+
int ret = av_frame_get_buffer(nv12CpuFrame.get(), 0);
702+
TORCH_CHECK(
703+
ret >= 0,
704+
"Failed to allocate NV12 CPU frame buffer: ",
705+
getFFMPEGErrorStringFromErrorCode(ret));
706+
707+
SwsFrameContext swsFrameContext(
708+
width,
709+
height,
710+
static_cast<AVPixelFormat>(cpuFrame->format),
711+
width,
712+
height);
713+
714+
if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
715+
swsContext_ = createSwsContext(
716+
swsFrameContext, cpuFrame->colorspace, AV_PIX_FMT_NV12, SWS_BILINEAR);
717+
prevSwsFrameContext_ = swsFrameContext;
718+
}
719+
720+
int convertedHeight = sws_scale(
721+
swsContext_.get(),
722+
cpuFrame->data,
723+
cpuFrame->linesize,
724+
0,
725+
height,
726+
nv12CpuFrame->data,
727+
nv12CpuFrame->linesize);
728+
TORCH_CHECK(
729+
convertedHeight == height, "sws_scale failed for CPU->NV12 conversion");
730+
731+
int ySize = width * height;
732+
TORCH_CHECK(
733+
ySize % 2 == 0,
734+
"Y plane size must be even. Please report on TorchCodec repo.");
735+
int uvSize = ySize / 2; // NV12: UV plane is half the size of Y plane
736+
size_t totalSize = static_cast<size_t>(ySize + uvSize);
737+
738+
uint8_t* cudaBuffer = nullptr;
739+
cudaError_t err =
740+
cudaMalloc(reinterpret_cast<void**>(&cudaBuffer), totalSize);
741+
TORCH_CHECK(
742+
err == cudaSuccess,
743+
"Failed to allocate CUDA memory: ",
744+
cudaGetErrorString(err));
745+
746+
UniqueAVFrame gpuFrame(av_frame_alloc());
747+
TORCH_CHECK(gpuFrame != nullptr, "Failed to allocate GPU AVFrame");
748+
749+
gpuFrame->format = AV_PIX_FMT_CUDA;
750+
gpuFrame->width = width;
751+
gpuFrame->height = height;
752+
gpuFrame->data[0] = cudaBuffer;
753+
gpuFrame->data[1] = cudaBuffer + ySize;
754+
gpuFrame->linesize[0] = width;
755+
gpuFrame->linesize[1] = width;
756+
757+
// Note that we use cudaMemcpy2D here instead of cudaMemcpy because the
758+
// linesizes (strides) may be different than the widths for the input CPU
759+
// frame. That's precisely what cudaMemcpy2D is for.
760+
err = cudaMemcpy2D(
761+
gpuFrame->data[0],
762+
gpuFrame->linesize[0],
763+
nv12CpuFrame->data[0],
764+
nv12CpuFrame->linesize[0],
765+
width,
766+
height,
767+
cudaMemcpyHostToDevice);
768+
TORCH_CHECK(
769+
err == cudaSuccess,
770+
"Failed to copy Y plane to GPU: ",
771+
cudaGetErrorString(err));
772+
773+
TORCH_CHECK(
774+
height % 2 == 0,
775+
"height must be even. Please report on TorchCodec repo.");
776+
err = cudaMemcpy2D(
777+
gpuFrame->data[1],
778+
gpuFrame->linesize[1],
779+
nv12CpuFrame->data[1],
780+
nv12CpuFrame->linesize[1],
781+
width,
782+
height / 2,
783+
cudaMemcpyHostToDevice);
784+
TORCH_CHECK(
785+
err == cudaSuccess,
786+
"Failed to copy UV plane to GPU: ",
787+
cudaGetErrorString(err));
788+
789+
ret = av_frame_copy_props(gpuFrame.get(), cpuFrame.get());
790+
TORCH_CHECK(
791+
ret >= 0,
792+
"Failed to copy frame properties: ",
793+
getFFMPEGErrorStringFromErrorCode(ret));
794+
795+
// We're almost done, but we need to make sure the CUDA memory is freed
796+
// properly. Usually, AVFrame data is freed when av_frame_free() is called
797+
// (upon UniqueAVFrame destruction), but since we allocated the CUDA memory
798+
// ourselves, FFmpeg doesn't know how to free it. The recommended way to deal
799+
// with this is to associate the opaque_ref field of the AVFrame with a `free`
800+
// callback that will then be called by av_frame_free().
801+
gpuFrame->opaque_ref = av_buffer_create(
802+
nullptr, // data - we don't need any
803+
0, // data size
804+
cudaBufferFreeCallback, // callback triggered by av_frame_free()
805+
cudaBuffer, // parameter to callback
806+
0); // flags
807+
TORCH_CHECK(
808+
gpuFrame->opaque_ref != nullptr,
809+
"Failed to create GPU memory cleanup reference");
810+
811+
return gpuFrame;
812+
}
813+
671814
void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
672815
UniqueAVFrame& avFrame,
673816
FrameOutput& frameOutput,
674817
std::optional<torch::Tensor> preAllocatedOutputTensor) {
675-
if (cpuFallback_) {
676-
// CPU decoded frame - need to do CPU color conversion then transfer to GPU
677-
FrameOutput cpuFrameOutput;
678-
cpuFallback_->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput);
679-
680-
// Transfer CPU frame to GPU
681-
if (preAllocatedOutputTensor.has_value()) {
682-
preAllocatedOutputTensor.value().copy_(cpuFrameOutput.data);
683-
frameOutput.data = preAllocatedOutputTensor.value();
684-
} else {
685-
frameOutput.data = cpuFrameOutput.data.to(device_);
686-
}
687-
return;
688-
}
818+
UniqueAVFrame gpuFrame =
819+
cpuFallback_ ? transferCpuFrameToGpuNV12(avFrame) : std::move(avFrame);
689820

690821
// TODONVDEC P2: we may need to handle 10bit videos the same way the CUDA
691822
// ffmpeg interface does it with maybeConvertAVFrameToNV12OrRGB24().
692823
TORCH_CHECK(
693-
avFrame->format == AV_PIX_FMT_CUDA,
824+
gpuFrame->format == AV_PIX_FMT_CUDA,
694825
"Expected CUDA format frame from BETA CUDA interface");
695826

696-
validatePreAllocatedTensorShape(preAllocatedOutputTensor, avFrame);
827+
validatePreAllocatedTensorShape(preAllocatedOutputTensor, gpuFrame);
697828

698829
at::cuda::CUDAStream nvdecStream =
699830
at::cuda::getCurrentCUDAStream(device_.index());
700831

701832
frameOutput.data = convertNV12FrameToRGB(
702-
avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
833+
gpuFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
703834
}
704835

705836
std::string BetaCudaDeviceInterface::getDetails() {

src/torchcodec/_core/BetaCudaDeviceInterface.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ class BetaCudaDeviceInterface : public DeviceInterface {
8181
unsigned int pitch,
8282
const CUVIDPARSERDISPINFO& dispInfo);
8383

84+
UniqueAVFrame transferCpuFrameToGpuNV12(UniqueAVFrame& cpuFrame);
85+
8486
CUvideoparser videoParser_ = nullptr;
8587
UniqueCUvideodecoder decoder_;
8688
CUVIDEOFORMAT videoFormat_ = {};
@@ -99,6 +101,8 @@ class BetaCudaDeviceInterface : public DeviceInterface {
99101

100102
std::unique_ptr<DeviceInterface> cpuFallback_;
101103
bool nvcuvidAvailable_ = false;
104+
UniqueSwsContext swsContext_;
105+
SwsFrameContext prevSwsFrameContext_;
102106
};
103107

104108
} // namespace facebook::torchcodec

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 2 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -15,30 +15,6 @@ static bool g_cpu = registerDeviceInterface(
1515

1616
} // namespace
1717

18-
CpuDeviceInterface::SwsFrameContext::SwsFrameContext(
19-
int inputWidth,
20-
int inputHeight,
21-
AVPixelFormat inputFormat,
22-
int outputWidth,
23-
int outputHeight)
24-
: inputWidth(inputWidth),
25-
inputHeight(inputHeight),
26-
inputFormat(inputFormat),
27-
outputWidth(outputWidth),
28-
outputHeight(outputHeight) {}
29-
30-
bool CpuDeviceInterface::SwsFrameContext::operator==(
31-
const CpuDeviceInterface::SwsFrameContext& other) const {
32-
return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
33-
inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
34-
outputHeight == other.outputHeight;
35-
}
36-
37-
bool CpuDeviceInterface::SwsFrameContext::operator!=(
38-
const CpuDeviceInterface::SwsFrameContext& other) const {
39-
return !(*this == other);
40-
}
41-
4218
CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
4319
: DeviceInterface(device) {
4420
TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");
@@ -257,7 +233,8 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
257233
outputDims.height);
258234

259235
if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
260-
createSwsContext(swsFrameContext, avFrame->colorspace);
236+
swsContext_ = createSwsContext(
237+
swsFrameContext, avFrame->colorspace, AV_PIX_FMT_RGB24, swsFlags_);
261238
prevSwsFrameContext_ = swsFrameContext;
262239
}
263240

@@ -276,51 +253,6 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
276253
return resultHeight;
277254
}
278255

279-
void CpuDeviceInterface::createSwsContext(
280-
const SwsFrameContext& swsFrameContext,
281-
const enum AVColorSpace colorspace) {
282-
SwsContext* swsContext = sws_getContext(
283-
swsFrameContext.inputWidth,
284-
swsFrameContext.inputHeight,
285-
swsFrameContext.inputFormat,
286-
swsFrameContext.outputWidth,
287-
swsFrameContext.outputHeight,
288-
AV_PIX_FMT_RGB24,
289-
swsFlags_,
290-
nullptr,
291-
nullptr,
292-
nullptr);
293-
TORCH_CHECK(swsContext, "sws_getContext() returned nullptr");
294-
295-
int* invTable = nullptr;
296-
int* table = nullptr;
297-
int srcRange, dstRange, brightness, contrast, saturation;
298-
int ret = sws_getColorspaceDetails(
299-
swsContext,
300-
&invTable,
301-
&srcRange,
302-
&table,
303-
&dstRange,
304-
&brightness,
305-
&contrast,
306-
&saturation);
307-
TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1");
308-
309-
const int* colorspaceTable = sws_getCoefficients(colorspace);
310-
ret = sws_setColorspaceDetails(
311-
swsContext,
312-
colorspaceTable,
313-
srcRange,
314-
colorspaceTable,
315-
dstRange,
316-
brightness,
317-
contrast,
318-
saturation);
319-
TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1");
320-
321-
swsContext_.reset(swsContext);
322-
}
323-
324256
torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
325257
const UniqueAVFrame& avFrame,
326258
const FrameDims& outputDims) {

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -54,28 +54,6 @@ class CpuDeviceInterface : public DeviceInterface {
5454
ColorConversionLibrary getColorConversionLibrary(
5555
const FrameDims& inputFrameDims) const;
5656

57-
struct SwsFrameContext {
58-
int inputWidth = 0;
59-
int inputHeight = 0;
60-
AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
61-
int outputWidth = 0;
62-
int outputHeight = 0;
63-
64-
SwsFrameContext() = default;
65-
SwsFrameContext(
66-
int inputWidth,
67-
int inputHeight,
68-
AVPixelFormat inputFormat,
69-
int outputWidth,
70-
int outputHeight);
71-
bool operator==(const SwsFrameContext&) const;
72-
bool operator!=(const SwsFrameContext&) const;
73-
};
74-
75-
void createSwsContext(
76-
const SwsFrameContext& swsFrameContext,
77-
const enum AVColorSpace colorspace);
78-
7957
VideoStreamOptions videoStreamOptions_;
8058
AVRational timeBase_;
8159

0 commit comments

Comments
 (0)