Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nvjpeg hardware acc #5240

Merged
merged 13 commits into from
Jun 29, 2021
199 changes: 150 additions & 49 deletions oneflow/core/kernel/image_decoder_random_crop_resize_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,49 @@ struct Work {
std::shared_ptr<std::atomic<int>> task_counter;
};

// Rectangular region of interest inside an image.
// Every member defaults to zero, so a default-constructed ROI is a well-defined empty region
// at the origin instead of holding indeterminate values.
struct ROI {
  int x = 0;  // horizontal offset of the region
  int y = 0;  // vertical offset of the region
  int w = 0;  // width of the region
  int h = 0;  // height of the region
};

// Abstract strategy that selects a region of interest for an image of a given size.
class ROIGenerator {
 public:
  // Fills `*roi` with the region chosen for an image that is `width` wide and `height` tall.
  virtual void Generate(int width, int height, ROI* roi) const = 0;

  // Virtual so that concrete generators can be destroyed through a base pointer.
  virtual ~ROIGenerator() = default;
};

// ROIGenerator that asks a RandomCropGenerator for a crop window and reports it as an ROI.
class RandomCropROIGenerator : public ROIGenerator {
 public:
  // `crop_generator` is borrowed, never freed by this class, and is only dereferenced inside
  // Generate(). A null pointer is accepted: Generate() then selects the whole image.
  explicit RandomCropROIGenerator(RandomCropGenerator* crop_generator)
      : crop_generator_(crop_generator) {}
  ~RandomCropROIGenerator() override = default;

  // Writes the crop chosen for a `width` x `height` image into `*roi`.
  void Generate(int width, int height, ROI* roi) const override {
    if (crop_generator_ == nullptr) {
      // No generator available: fall back to the full image rather than dereferencing null.
      roi->x = 0;
      roi->y = 0;
      roi->w = width;
      roi->h = height;
      return;
    }
    CropWindow window;
    // The window is requested with a {height, width} shape, so index 0 of the resulting
    // anchor/shape is the vertical axis and index 1 is the horizontal axis.
    crop_generator_->GenerateCropWindow({height, width}, &window);
    roi->x = window.anchor.At(1);
    roi->y = window.anchor.At(0);
    roi->w = window.shape.At(1);
    roi->h = window.shape.At(0);
  }

 private:
  RandomCropGenerator* crop_generator_;  // not owned; may be null
};

class NoChangeROIGenerator : public ROIGenerator {
public:
~NoChangeROIGenerator() override = default;

void Generate(int width, int height, ROI* roi) const override {
roi->x = 0;
roi->y = 0;
roi->w = width;
roi->h = height;
}
};

void GenerateRandomCropRoi(RandomCropGenerator* crop_generator, int width, int height, int* roi_x,
int* roi_y, int* roi_width, int* roi_height) {
CropWindow window;
Expand All @@ -74,7 +117,7 @@ class DecodeHandle {

using DecodeHandleFactory = std::function<std::shared_ptr<DecodeHandle>()>;
template<DeviceType device_type>
DecodeHandleFactory CreateDecodeHandleFactory();
DecodeHandleFactory CreateDecodeHandleFactory(int target_width, int target_height);

class CpuDecodeHandle final : public DecodeHandle {
public:
Expand Down Expand Up @@ -117,7 +160,8 @@ void CpuDecodeHandle::DecodeRandomCropResize(const unsigned char* data, size_t l
}

template<>
DecodeHandleFactory CreateDecodeHandleFactory<DeviceType::kCPU>() {
DecodeHandleFactory CreateDecodeHandleFactory<DeviceType::kCPU>(int target_width,
int target_height) {
return []() -> std::shared_ptr<DecodeHandle> { return std::make_shared<CpuDecodeHandle>(); };
}

Expand Down Expand Up @@ -155,7 +199,7 @@ void InitNppStreamContext(NppStreamContext* ctx, int dev, cudaStream_t stream) {
class GpuDecodeHandle final : public DecodeHandle {
public:
OF_DISALLOW_COPY_AND_MOVE(GpuDecodeHandle);
explicit GpuDecodeHandle(int dev);
explicit GpuDecodeHandle(int dev, int target_width, int target_height);
~GpuDecodeHandle() override;

void DecodeRandomCropResize(const unsigned char* data, size_t length,
Expand All @@ -166,29 +210,35 @@ class GpuDecodeHandle final : public DecodeHandle {
void Synchronize() override;

private:
void DecodeRandomCrop(const unsigned char* data, size_t length,
RandomCropGenerator* crop_generator, unsigned char* dst,
size_t dst_max_length, int* dst_width, int* dst_height);
void DecodeRandomCrop(const unsigned char* data, size_t length, ROIGenerator* roi_generator,
unsigned char* dst, size_t dst_max_length, int* dst_width, int* dst_height);
void Decode(const unsigned char* data, size_t length, unsigned char* dst, size_t dst_max_length,
int* dst_width, int* dst_height);
void Resize(const unsigned char* src, int src_width, int src_height, unsigned char* dst,
int dst_width, int dst_height);
void CropResize(const unsigned char* src, int src_width, int src_height,
ROIGenerator* roi_generator, unsigned char* dst, int dst_width, int dst_height);

cudaStream_t cuda_stream_ = nullptr;
nvjpegHandle_t jpeg_handle_ = nullptr;
nvjpegJpegState_t jpeg_state_ = nullptr;
nvjpegJpegState_t hw_jpeg_state_ = nullptr;
nvjpegBufferPinned_t jpeg_pinned_buffer_ = nullptr;
nvjpegBufferDevice_t jpeg_device_buffer_ = nullptr;
nvjpegDecodeParams_t jpeg_decode_params_ = nullptr;
nvjpegJpegDecoder_t jpeg_decoder_ = nullptr;
nvjpegJpegDecoder_t hw_jpeg_decoder_ = nullptr;
nvjpegJpegStream_t jpeg_stream_ = nullptr;
NppStreamContext npp_stream_ctx_{};
nvjpegDevAllocator_t dev_allocator_{};
nvjpegPinnedAllocator_t pinned_allocator_{};
CpuDecodeHandle fallback_handle_;
unsigned char* fallback_buffer_;
size_t fallback_buffer_size_;
bool warmup_done_;
bool use_hardware_acceleration_;
};

GpuDecodeHandle::GpuDecodeHandle(int dev) : warmup_done_(false) {
GpuDecodeHandle::GpuDecodeHandle(int dev, int target_width, int target_height)
: warmup_done_(false), use_hardware_acceleration_(false) {
OF_CUDA_CHECK(cudaStreamCreateWithFlags(&cuda_stream_, cudaStreamNonBlocking));
dev_allocator_.dev_malloc = &GpuDeviceMalloc;
dev_allocator_.dev_free = &GpuDeviceFree;
Expand All @@ -198,11 +248,23 @@ GpuDecodeHandle::GpuDecodeHandle(int dev) : warmup_done_(false) {
&jpeg_handle_));
OF_NVJPEG_CHECK(nvjpegDecoderCreate(jpeg_handle_, NVJPEG_BACKEND_DEFAULT, &jpeg_decoder_));
OF_NVJPEG_CHECK(nvjpegDecoderStateCreate(jpeg_handle_, jpeg_decoder_, &jpeg_state_));
#if NVJPEG_VER_MAJOR >= 11
if (nvjpegDecoderCreate(jpeg_handle_, NVJPEG_BACKEND_HARDWARE, &hw_jpeg_decoder_)
== NVJPEG_STATUS_SUCCESS) {
OF_NVJPEG_CHECK(nvjpegDecoderStateCreate(jpeg_handle_, hw_jpeg_decoder_, &hw_jpeg_state_));
use_hardware_acceleration_ = true;
} else {
hw_jpeg_decoder_ = nullptr;
hw_jpeg_state_ = nullptr;
}
#endif
OF_NVJPEG_CHECK(nvjpegBufferPinnedCreate(jpeg_handle_, &pinned_allocator_, &jpeg_pinned_buffer_));
OF_NVJPEG_CHECK(nvjpegBufferDeviceCreate(jpeg_handle_, &dev_allocator_, &jpeg_device_buffer_));
OF_NVJPEG_CHECK(nvjpegDecodeParamsCreate(jpeg_handle_, &jpeg_decode_params_));
OF_NVJPEG_CHECK(nvjpegJpegStreamCreate(jpeg_handle_, &jpeg_stream_));
InitNppStreamContext(&npp_stream_ctx_, dev, cuda_stream_);
fallback_buffer_size_ = target_width * target_height * kNumChannels;
OF_CUDA_CHECK(cudaMallocHost(&fallback_buffer_, fallback_buffer_size_));
}

GpuDecodeHandle::~GpuDecodeHandle() {
Expand All @@ -213,66 +275,79 @@ GpuDecodeHandle::~GpuDecodeHandle() {
OF_NVJPEG_CHECK(nvjpegBufferPinnedDestroy(jpeg_pinned_buffer_));
OF_NVJPEG_CHECK(nvjpegJpegStateDestroy(jpeg_state_));
OF_NVJPEG_CHECK(nvjpegDecoderDestroy(jpeg_decoder_));
if (use_hardware_acceleration_) {
OF_NVJPEG_CHECK(nvjpegJpegStateDestroy(hw_jpeg_state_));
OF_NVJPEG_CHECK(nvjpegDecoderDestroy(hw_jpeg_decoder_));
}
OF_NVJPEG_CHECK(nvjpegDestroy(jpeg_handle_));
OF_CUDA_CHECK(cudaStreamDestroy(cuda_stream_));
OF_CUDA_CHECK(cudaFreeHost(fallback_buffer_));
}

void GpuDecodeHandle::DecodeRandomCrop(const unsigned char* data, size_t length,
RandomCropGenerator* crop_generator, unsigned char* dst,
ROIGenerator* roi_generator, unsigned char* dst,
size_t dst_max_length, int* dst_width, int* dst_height) {
// https://docs.nvidia.com/cuda/archive/10.2/nvjpeg/index.html#nvjpeg-decoupled-decode-api
OF_NVJPEG_CHECK(nvjpegJpegStreamParse(jpeg_handle_, data, length, 0, 0, jpeg_stream_));
unsigned int orig_width;
unsigned int orig_height;
OF_NVJPEG_CHECK(nvjpegJpegStreamGetFrameDimensions(jpeg_stream_, &orig_width, &orig_height));
int roi_x;
int roi_y;
int roi_width;
int roi_height;
if (crop_generator) {
GenerateRandomCropRoi(crop_generator, static_cast<int>(orig_width),
static_cast<int>(orig_height), &roi_x, &roi_y, &roi_width, &roi_height);
} else {
roi_x = 0;
roi_y = 0;
roi_width = static_cast<int>(orig_width);
roi_height = static_cast<int>(orig_height);
}
CHECK_LE(roi_width * roi_height * kNumChannels, dst_max_length);
ROI roi;
roi_generator->Generate(static_cast<int>(orig_width), static_cast<int>(orig_height), &roi);
CHECK_LE(roi.w * roi.h * kNumChannels, dst_max_length);
nvjpegImage_t image;
image.channel[0] = dst;
image.pitch[0] = roi_width * kNumChannels;
image.pitch[0] = roi.w * kNumChannels;
OF_NVJPEG_CHECK(nvjpegDecodeParamsSetOutputFormat(jpeg_decode_params_, NVJPEG_OUTPUT_RGBI));
OF_NVJPEG_CHECK(
nvjpegDecodeParamsSetROI(jpeg_decode_params_, roi_x, roi_y, roi_width, roi_height));
OF_NVJPEG_CHECK(nvjpegStateAttachPinnedBuffer(jpeg_state_, jpeg_pinned_buffer_));
OF_NVJPEG_CHECK(nvjpegStateAttachDeviceBuffer(jpeg_state_, jpeg_device_buffer_));
OF_NVJPEG_CHECK(nvjpegDecodeJpegHost(jpeg_handle_, jpeg_decoder_, jpeg_state_,
jpeg_decode_params_, jpeg_stream_));
OF_NVJPEG_CHECK(nvjpegDecodeJpegTransferToDevice(jpeg_handle_, jpeg_decoder_, jpeg_state_,

nvjpegJpegDecoder_t jpeg_decoder;
nvjpegJpegState_t jpeg_state;
int is_hardware_acceleration_supported = -1;
if (use_hardware_acceleration_) {
nvjpegDecoderJpegSupported(hw_jpeg_decoder_, jpeg_stream_, jpeg_decode_params_,
&is_hardware_acceleration_supported);
}
if (is_hardware_acceleration_supported == 0) {
jpeg_decoder = hw_jpeg_decoder_;
jpeg_state = hw_jpeg_state_;
} else {
jpeg_decoder = jpeg_decoder_;
jpeg_state = jpeg_state_;
OF_NVJPEG_CHECK(nvjpegDecodeParamsSetROI(jpeg_decode_params_, roi.x, roi.y, roi.w, roi.h));
}
OF_NVJPEG_CHECK(nvjpegStateAttachPinnedBuffer(jpeg_state, jpeg_pinned_buffer_));
OF_NVJPEG_CHECK(nvjpegStateAttachDeviceBuffer(jpeg_state, jpeg_device_buffer_));
OF_NVJPEG_CHECK(nvjpegDecodeJpegHost(jpeg_handle_, jpeg_decoder, jpeg_state, jpeg_decode_params_,
jpeg_stream_));
OF_NVJPEG_CHECK(nvjpegDecodeJpegTransferToDevice(jpeg_handle_, jpeg_decoder, jpeg_state,
jpeg_stream_, cuda_stream_));
OF_NVJPEG_CHECK(
nvjpegDecodeJpegDevice(jpeg_handle_, jpeg_decoder_, jpeg_state_, &image, cuda_stream_));
*dst_width = roi_width;
*dst_height = roi_height;
nvjpegDecodeJpegDevice(jpeg_handle_, jpeg_decoder, jpeg_state, &image, cuda_stream_));
*dst_width = roi.w;
*dst_height = roi.h;
}

void GpuDecodeHandle::Decode(const unsigned char* data, size_t length, unsigned char* dst,
size_t dst_max_length, int* dst_width, int* dst_height) {
DecodeRandomCrop(data, length, nullptr, dst, dst_max_length, dst_width, dst_height);
NoChangeROIGenerator no_change_roi_generator;
DecodeRandomCrop(data, length, &no_change_roi_generator, dst, dst_max_length, dst_width,
dst_height);
}

void GpuDecodeHandle::Resize(const unsigned char* src, int src_width, int src_height,
unsigned char* dst, int dst_width, int dst_height) {
void GpuDecodeHandle::CropResize(const unsigned char* src, int src_width, int src_height,
ROIGenerator* roi_generator, unsigned char* dst, int dst_width,
int dst_height) {
ROI roi;
roi_generator->Generate(static_cast<int>(src_width), static_cast<int>(src_height), &roi);
const NppiSize src_size{
.width = src_width,
.height = src_height,
};
const NppiRect src_rect{
.x = 0,
.y = 0,
.width = src_width,
.height = src_height,
.x = roi.x,
.y = roi.y,
.width = roi.w,
.height = roi.h,
};
const NppiSize dst_size{
.width = dst_width,
Expand All @@ -297,8 +372,32 @@ void GpuDecodeHandle::DecodeRandomCropResize(const unsigned char* data, size_t l
int target_height) {
int width;
int height;
DecodeRandomCrop(data, length, crop_generator, workspace, workspace_size, &width, &height);
Resize(workspace, width, height, dst, target_width, target_height);
nvjpegChromaSubsampling_t subsampling;
int num_components;
nvjpegStatus_t status = nvjpegGetImageInfo(jpeg_handle_, data, length, &num_components,
&subsampling, &width, &height);
if (status != NVJPEG_STATUS_SUCCESS) {
CHECK_LE(target_width * target_height * kNumChannels, fallback_buffer_size_);
fallback_handle_.DecodeRandomCropResize(data, length, crop_generator, nullptr, 0,
fallback_buffer_, target_width, target_height);
OF_CUDA_CHECK(cudaMemcpyAsync(dst, fallback_buffer_,
target_width * target_height * kNumChannels, cudaMemcpyDefault,
cuda_stream_));
return;
}
NoChangeROIGenerator no_change_roi_generator;
RandomCropROIGenerator random_crop_roi_generator(crop_generator);
if (use_hardware_acceleration_) {
DecodeRandomCrop(data, length, &no_change_roi_generator, workspace, workspace_size, &width,
&height);
CropResize(workspace, width, height, &random_crop_roi_generator, dst, target_width,
target_height);
} else {
DecodeRandomCrop(data, length, &random_crop_roi_generator, workspace, workspace_size, &width,
&height);
CropResize(workspace, width, height, &no_change_roi_generator, dst, target_width,
target_height);
}
}

void GpuDecodeHandle::WarmupOnce(int warmup_size, unsigned char* workspace, size_t workspace_size) {
Expand All @@ -318,13 +417,14 @@ void GpuDecodeHandle::WarmupOnce(int warmup_size, unsigned char* workspace, size
// Waits on this handle's CUDA stream (cuda_stream_) by calling cudaStreamSynchronize; the
// status it returns is handed to OF_CUDA_CHECK.
void GpuDecodeHandle::Synchronize() { OF_CUDA_CHECK(cudaStreamSynchronize(cuda_stream_)); }

template<>
DecodeHandleFactory CreateDecodeHandleFactory<DeviceType::kGPU>() {
DecodeHandleFactory CreateDecodeHandleFactory<DeviceType::kGPU>(int target_width,
int target_height) {
int dev;
OF_CUDA_CHECK(cudaGetDevice(&dev));
return [dev]() -> std::shared_ptr<DecodeHandle> {
return [dev, target_width, target_height]() -> std::shared_ptr<DecodeHandle> {
OF_CUDA_CHECK(cudaSetDevice(dev));
CudaDeviceSetCpuAffinity(dev);
return std::make_shared<GpuDecodeHandle>(dev);
return std::make_shared<GpuDecodeHandle>(dev, target_width, target_height);
};
}

Expand Down Expand Up @@ -413,8 +513,9 @@ void ImageDecoderRandomCropResizeKernel<device_type>::VirtualKernelInit() {
}
workers_.resize(conf.num_workers());
for (int64_t i = 0; i < conf.num_workers(); ++i) {
workers_.at(i).reset(new Worker(CreateDecodeHandleFactory<device_type>(), conf.target_width(),
conf.target_height(), conf.warmup_size()));
workers_.at(i).reset(new Worker(
CreateDecodeHandleFactory<device_type>(conf.target_width(), conf.target_height()),
conf.target_width(), conf.target_height(), conf.warmup_size()));
}
}

Expand Down