@@ -156,6 +156,21 @@ const Npp32f bt709FullRangeColorTwist[3][4] = {
156156 {1 .0f , -0 .187324273f , -0 .468124273f , -128 .0f },
157157 {1 .0f , 1 .8556f , 0 .0f , -128 .0f }};
158158
159+ // RGB to NV12 color conversion matrices (inverse of YUV to RGB)
160+ // Note: NPP's ColorTwist function apparently expects "limited range"
161+ // coefficient format even when producing full range output. All matrices below
162+ // use the limited range coefficient format (Y with +16 offset) for NPP
163+ // compatibility.
164+
165+ // BT.601 limited range (matches FFmpeg default behavior)
166+ const Npp32f defaultLimitedRangeRgbToNv12[3 ][4 ] = {
167+ // Y = 16 + 0.859 * (0.299*R + 0.587*G + 0.114*B)
168+ {0 .257f , 0 .504f , 0 .098f , 16 .0f },
169+ // U = -0.148*R - 0.291*G + 0.439*B + 128 (BT.601 coefficients)
170+ {-0 .148f , -0 .291f , 0 .439f , 128 .0f },
171+ // V = 0.439*R - 0.368*G - 0.071*B + 128 (BT.601 coefficients)
172+ {0 .439f , -0 .368f , -0 .071f , 128 .0f }};
173+
159174torch::Tensor convertNV12FrameToRGB (
160175 UniqueAVFrame& avFrame,
161176 const torch::Device& device,
@@ -246,6 +261,68 @@ torch::Tensor convertNV12FrameToRGB(
246261 return dst;
247262}
248263
264+ void convertRGBTensorToNV12Frame (
265+ const torch::Tensor& rgbTensor,
266+ UniqueAVFrame& nv12Frame,
267+ const torch::Device& device,
268+ const UniqueNppContext& nppCtx,
269+ at::cuda::CUDAStream inputStream) {
270+ TORCH_CHECK (rgbTensor.is_cuda (), " RGB tensor must be on CUDA device" );
271+ TORCH_CHECK (
272+ rgbTensor.dim () == 3 && rgbTensor.size (0 ) == 3 ,
273+ " Expected 3D RGB tensor in CHW format, got shape: " ,
274+ rgbTensor.sizes ());
275+ TORCH_CHECK (
276+ nv12Frame != nullptr && nv12Frame->data [0 ] != nullptr ,
277+ " nv12Frame must be pre-allocated with CUDA memory" );
278+
279+ // Convert CHW to HWC for NPP processing
280+ int height = static_cast <int >(rgbTensor.size (1 ));
281+ int width = static_cast <int >(rgbTensor.size (2 ));
282+ torch::Tensor hwcFrame = rgbTensor.permute ({1 , 2 , 0 }).contiguous ();
283+
284+ // Set up stream synchronization - make NPP stream wait for input tensor
285+ // operations
286+ at::cuda::CUDAStream nppStream =
287+ at::cuda::getCurrentCUDAStream (device.index ());
288+ at::cuda::CUDAEvent inputDoneEvent;
289+ inputDoneEvent.record (inputStream);
290+ inputDoneEvent.block (nppStream);
291+
292+ // Setup NPP context
293+ nppCtx->hStream = nppStream.stream ();
294+ cudaError_t cudaErr =
295+ cudaStreamGetFlags (nppCtx->hStream , &nppCtx->nStreamFlags );
296+ TORCH_CHECK (
297+ cudaErr == cudaSuccess,
298+ " cudaStreamGetFlags failed: " ,
299+ cudaGetErrorString (cudaErr));
300+
301+ // Always use FFmpeg's default behavior: BT.601 limited range
302+ NppiSize oSizeROI = {width, height};
303+
304+ NppStatus status = nppiRGBToNV12_8u_ColorTwist32f_C3P2R_Ctx (
305+ static_cast <const Npp8u*>(hwcFrame.data_ptr ()),
306+ hwcFrame.stride (0 ) * hwcFrame.element_size (),
307+ nv12Frame->data ,
308+ nv12Frame->linesize ,
309+ oSizeROI,
310+ defaultLimitedRangeRgbToNv12,
311+ *nppCtx);
312+
313+ TORCH_CHECK (
314+ status == NPP_SUCCESS,
315+ " Failed to convert RGB to NV12: NPP error code " ,
316+ status);
317+
318+ // Validate CUDA operations completed successfully
319+ cudaError_t memCheck = cudaGetLastError ();
320+ TORCH_CHECK (
321+ memCheck == cudaSuccess,
322+ " CUDA error detected: " ,
323+ cudaGetErrorString (memCheck));
324+ }
325+
249326UniqueNppContext getNppStreamContext (const torch::Device& device) {
250327 int deviceIndex = getDeviceIndex (device);
251328
0 commit comments