
Performance improvement in Normalize GPU Kernel #14139

Merged
New GPU kernel for Normalize
sandeep-krishnamurthy committed Feb 14, 2019
commit 64d2b472c1ea7901e2c795b640f2020be8b8c731
68 changes: 59 additions & 9 deletions src/operator/image/image_random-inl.h
@@ -66,6 +66,15 @@ void NormalizeImplCUDA(mshadow::Stream<gpu> *s,
const float std_d0,
const float std_d1,
const float std_d2);

template<typename DType, typename T>
void NormalizeBackwardImplCUDA(mshadow::Stream<gpu> *s,
const T out_grad,
const T in_grad,
const int req,
const float std_d0,
const float std_d1,
const float std_d2);
#endif // MXNET_USE_CUDA

// Shape and Type inference for image to tensor operator
@@ -281,11 +290,9 @@ inline void Normalize(DType* out_data,
#pragma omp parallel for collapse(2)
#endif // _MSC_VER
for (uint32_t c = 0; c < channels; ++c) {
- float mean_c = mean[mean.size() > c ? c : 0];
- float std_c = std[std.size() > c ? c : 0];
for (int i = 0; i < length; ++i) {
KERNEL_ASSIGN(out_data[step + c*length + i], req,
- (in_data[step + c*length + i] - mean_c) / std_c);
(in_data[step + c*length + i] - mean[c]) / std[c]);
}
}
}
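As an aside, the forward math above is plain per-channel standardization, out = (in - mean[c]) / std[c], with mean and std expanded to one value per channel by the caller. A minimal standalone sketch of that computation (the shapes and statistics below are hypothetical, not taken from the diff):

#include <cstdio>

int main() {
  const int channels = 3, length = 4;                 // e.g. a 3 x 2 x 2 image, flattened per channel
  const float mean[3]   = {0.485f, 0.456f, 0.406f};   // hypothetical per-channel mean
  const float stddev[3] = {0.229f, 0.224f, 0.225f};   // hypothetical per-channel std
  float in_data[channels * length];
  float out_data[channels * length];
  for (int i = 0; i < channels * length; ++i) in_data[i] = 0.5f;  // dummy pixel values
  for (int c = 0; c < channels; ++c) {
    for (int i = 0; i < length; ++i) {
      // Same formula as the kernel: subtract the channel mean, divide by the channel std.
      out_data[c * length + i] = (in_data[c * length + i] - mean[c]) / stddev[c];
    }
  }
  printf("out_data[0] = %f\n", out_data[0]);          // (0.5 - 0.485) / 0.229 ≈ 0.0655
  return 0;
}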
@@ -339,8 +346,31 @@ void NormalizeOpForward(const nnvm::NodeAttrs &attrs,
std[2] = param.std[2];
}

- // 3D input (c, h, w)
- if (inputs[0].ndim() == 3) {
if (std::is_same<xpu, gpu>::value) {
#if MXNET_USE_CUDA
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
if (inputs[0].ndim() == 3) {
Tensor<gpu, 3, DType> input = inputs[0].get<gpu, 3, DType>(s);
Tensor<gpu, 3, DType> output = outputs[0].get<gpu, 3, DType>(s);
NormalizeImplCUDA<DType, Tensor<gpu, 3, DType>>
(s, input, output, req_type, mean[0], mean[1], mean[2],
std[0], std[1], std[2]);
} else {
Tensor<gpu, 4, DType> input = inputs[0].get<gpu, 4, DType>(s);
Tensor<gpu, 4, DType> output = outputs[0].get<gpu, 4, DType>(s);
NormalizeImplCUDA<DType, Tensor<gpu, 4, DType>>
(s, input, output, req_type, mean[0], mean[1], mean[2],
std[0], std[1], std[2]);
}
});
});
#else
LOG(FATAL) << "Compile with USE_CUDA=1 to use Normalize operator on GPU.";
#endif // MXNET_USE_CUDA
} else if (inputs[0].ndim() == 3) {
// 3D input (c, h, w)
const int length = inputs[0].shape_[1] * inputs[0].shape_[2];
const uint32_t channel = inputs[0].shape_[0];
const int step = 0;
@@ -374,10 +404,9 @@ inline void NormalizeBackward(const DType* out_grad,
#pragma omp parallel for collapse(2)
#endif // _MSC_VER
for (uint32_t c = 0; c < channels; ++c) {
- float std_c = std[std.size() > c ? c : 0];
for (int i = 0; i < length; ++i) {
KERNEL_ASSIGN(in_grad[step + c*length + i], req,
- out_grad[step + c*length + i] * (1.0 / std_c));
out_grad[step + c*length + i] * (1.0 / std[c]));
}
}
}
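For reference, since the forward pass computes out = (in - mean[c]) / std[c], the derivative with respect to the input is simply 1 / std[c], which is why the backward kernel above only rescales the incoming gradient. A tiny sketch of that scaling (values are hypothetical):

#include <cstdio>

int main() {
  const float stddev[3] = {0.229f, 0.224f, 0.225f};  // hypothetical per-channel std
  const float out_grad = 1.0f;                       // dummy upstream gradient
  for (int c = 0; c < 3; ++c) {
    // in_grad = out_grad * d(out)/d(in) = out_grad * (1 / std[c])
    printf("channel %d: in_grad = %f\n", c, out_grad * (1.0f / stddev[c]));
  }
  return 0;
}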
@@ -424,8 +453,29 @@ void NormalizeOpBackward(const nnvm::NodeAttrs &attrs,
// Note: inputs[0] is out_grad
const TBlob& in_data = inputs[1];

- // 3D input (c, h, w)
- if (in_data.ndim() == 3) {
if (std::is_same<xpu, gpu>::value) {
#if MXNET_USE_CUDA
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
if (in_data.ndim() == 3) {
Tensor<gpu, 3, DType> out_grad = inputs[0].get<gpu, 3, DType>(s);
Tensor<gpu, 3, DType> in_grad = outputs[0].get<gpu, 3, DType>(s);
NormalizeBackwardImplCUDA<DType, Tensor<gpu, 3, DType>>
(s, out_grad, in_grad, req_type, std[0], std[1], std[2]);
} else {
Tensor<gpu, 4, DType> out_grad = inputs[0].get<gpu, 4, DType>(s);
Tensor<gpu, 4, DType> in_grad = outputs[0].get<gpu, 4, DType>(s);
NormalizeBackwardImplCUDA<DType, Tensor<gpu, 4, DType>>
(s, out_grad, in_grad, req_type, std[0], std[1], std[2]);
}
});
});
#else
LOG(FATAL) << "Compile with USE_CUDA=1 to use Normalize backward operator on GPU.";
#endif // MXNET_USE_CUDA
} else if (in_data.ndim() == 3) {
// 3D input (c, h, w)
const int length = in_data.shape_[1] * in_data.shape_[2];
const uint32_t channel = in_data.shape_[0];
const int step = 0;
227 changes: 226 additions & 1 deletion src/operator/image/image_random.cu
@@ -99,7 +99,6 @@ void ToTensorImplCUDA(mshadow::Stream<gpu> *s,
W = input.size(2);
C = input.size(3);
blocks = N > 0 ? N : 1;
- blocks = N;
}
// One block per image.
// Number of threads = (32, 32) is optimal, because
@@ -111,6 +110,232 @@ void ToTensorImplCUDA(mshadow::Stream<gpu> *s,
MSHADOW_CUDA_POST_KERNEL_CHECK(ToTensorCudaKernel);
}

// Normalize Kernel for 3D input
template<typename xpu, typename DType>
__global__ void NormalizeCudaKernel(const Tensor<xpu, 3, DType> input,
const Tensor<xpu, 3, DType> output,
const int req,
const int N,
const int H,
const int W,
const int C,
const float mean_d0,
const float mean_d1,
const float mean_d2,
const float std_d0,
const float std_d1,
const float std_d2) {
// We process one image per thread block.
// In the 3D case there is only one block, so blockIdx.x is not used.

float mean = mean_d0;
float std = std_d0;
for (int c = 0; c < C; ++c) {
switch (c) {
case 0 : mean = mean_d0;
std = std_d0;
break;
case 1 : mean = mean_d1;
std = std_d1;
break;
case 2 : mean = mean_d2;
std = std_d2;
break;
}
for (int h = threadIdx.y; h < H; h += blockDim.y) {
for (int w = threadIdx.x; w < W; w += blockDim.x) {
KERNEL_ASSIGN(output[c][h][w], req,
(input[c][h][w] - mean) / std);
}
}
}
}

// Normalize Kernel for 4D input
template<typename xpu, typename DType>
__global__ void NormalizeCudaKernel(const Tensor<xpu, 4, DType> input,
const Tensor<xpu, 4, DType> output,
const int req,
const int N,
const int H,
const int W,
const int C,
const float mean_d0,
const float mean_d1,
const float mean_d2,
const float std_d0,
const float std_d1,
const float std_d2) {
// We process one image per thread block.
const int n = blockIdx.x;

float mean = mean_d0;
float std = std_d0;
for (int c = 0; c < C; ++c) {
switch (c) {
case 0 : mean = mean_d0;
std = std_d0;
break;
case 1 : mean = mean_d1;
std = std_d1;
break;
case 2 : mean = mean_d2;
std = std_d2;
break;
}
for (int h = threadIdx.y; h < H; h += blockDim.y) {
for (int w = threadIdx.x; w < W; w += blockDim.x) {
KERNEL_ASSIGN(output[n][c][h][w], req,
(input[n][c][h][w] - mean) / std);
}
}
}
}

template<typename DType, typename T>
void NormalizeImplCUDA(mshadow::Stream<gpu> *s,
const T input,
const T output,
const int req,
const float mean_d0,
const float mean_d1,
const float mean_d2,
const float std_d0,
const float std_d1,
const float std_d2) {
int blocks, H, W, C, N;
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
if (std::is_same<T, Tensor<gpu, 3, DType>>::value) {
// 3D Input - (C, H, W)
N = 0;
C = input.size(0);
H = input.size(1);
W = input.size(2);
blocks = 1;
} else {
// 4D Input - (N, C, H, W)
N = input.size(0);
C = input.size(1);
H = input.size(2);
W = input.size(3);
blocks = N > 0 ? N : 1;
}
// One block per image.
// Number of threads = (16, 16) is optimal, because
// computation is minimal and overhead of CUDA preparing
// all threads is minimal.
NormalizeCudaKernel<gpu, DType>
<<<blocks, dim3(16, 16), 0, stream>>>(input, output,
req, N, H, W, C, mean_d0, mean_d1, mean_d2,
std_d0, std_d1, std_d2);
MSHADOW_CUDA_POST_KERNEL_CHECK(NormalizeCudaKernel);
}
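To put the launch configuration above in concrete terms: assuming a hypothetical 3 x 224 x 224 input, the 3D path launches a single block of 16 x 16 = 256 threads, and each thread strides over ceil(224 / 16) x ceil(224 / 16) = 14 x 14 = 196 (h, w) positions per channel, so the three channel planes are covered by 256 * 196 * 3 element writes without any extra grid dimensions. For 4D input the same per-image work is repeated across N blocks, one block per image.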

// Normalize Backward Kernel for 3D input
template<typename xpu, typename DType>
__global__ void NormalizeBackwardCudaKernel(const Tensor<xpu, 3, DType> out_grad,
const Tensor<xpu, 3, DType> in_grad,
const int req,
const int N,
const int H,
const int W,
const int C,
const float std_d0,
const float std_d1,
const float std_d2) {
// We process one image per thread block.
// In the 3D case there is only one block, so blockIdx.x is not used.

float std = std_d0;
for (int c = 0; c < C; ++c) {
switch (c) {
case 0 : std = std_d0;
break;
case 1 : std = std_d1;
break;
case 2 : std = std_d2;
break;
}
for (int h = threadIdx.y; h < H; h += blockDim.y) {
for (int w = threadIdx.x; w < W; w += blockDim.x) {
KERNEL_ASSIGN(in_grad[c][h][w], req,
out_grad[c][h][w] * (1 / std));
}
}
}
}

// Normalize Backward Kernel for 4D input
template<typename xpu, typename DType>
__global__ void NormalizeBackwardCudaKernel(const Tensor<xpu, 4, DType> out_grad,
const Tensor<xpu, 4, DType> in_grad,
const int req,
const int N,
const int H,
const int W,
const int C,
const float std_d0,
const float std_d1,
const float std_d2) {
// We process one image per thread block.
const int n = blockIdx.x;

float std = std_d0;
for (int c = 0; c < C; ++c) {
switch (c) {
case 0 : std = std_d0;
break;
case 1 : std = std_d1;
break;
case 2 : std = std_d2;
break;
}
for (int h = threadIdx.y; h < H; h += blockDim.y) {
for (int w = threadIdx.x; w < W; w += blockDim.x) {
KERNEL_ASSIGN(in_grad[n][c][h][w], req,
out_grad[n][c][h][w] * (1 / std));
}
}
}
}

template<typename DType, typename T>
void NormalizeBackwardImplCUDA(mshadow::Stream<gpu> *s,
const T out_grad,
const T in_grad,
const int req,
const float std_d0,
const float std_d1,
const float std_d2) {
int blocks, H, W, C, N;
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
if (std::is_same<T, Tensor<gpu, 3, DType>>::value) {
// 3D Input - (C, H, W)
N = 0;
C = out_grad.size(0);
H = out_grad.size(1);
W = out_grad.size(2);
blocks = 1;
} else {
// 4D Input - (N, C, H, W)
N = out_grad.size(0);
C = out_grad.size(1);
H = out_grad.size(2);
W = out_grad.size(3);
blocks = N > 0 ? N : 1;
}
// One block per image.
// Number of threads = (16, 16) is optimal, because
// computation is minimal and overhead of CUDA preparing
// all threads is minimal.
NormalizeBackwardCudaKernel<gpu, DType>
<<<blocks, dim3(16, 16), 0, stream>>>(out_grad, in_grad,
req, N, H, W, C, std_d0, std_d1, std_d2);
MSHADOW_CUDA_POST_KERNEL_CHECK(NormalizeBackwardCudaKernel);
}

NNVM_REGISTER_OP(_image_to_tensor)
.set_attr<FCompute>("FCompute<gpu>", ToTensorOpForward<gpu>);
