【AMP OP&Test】instance_norm fp16 and bf16 support. #52241

Status: Merged (24 commits, merged Apr 10, 2023)

Changes shown from 1 commit of the 24.

Commits:
da9fbbd
add fp16 and bf16 support for instance_norm
qizhaoaoe Mar 28, 2023
2b7111c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
qizhaoaoe Mar 28, 2023
4b414d8
fix /= operator which does not support bf16
qizhaoaoe Mar 29, 2023
e491361
fix instance_norm_grad kernel and unittests.
qizhaoaoe Mar 30, 2023
134fbcc
fix fp32 unittests.
qizhaoaoe Mar 30, 2023
ecd7ae1
fix instance_norm_kernel and unittests.
qizhaoaoe Mar 31, 2023
0006187
fix instance_norm_grad_kernel and unittest threshold.
qizhaoaoe Apr 1, 2023
e0af6d2
add fp16/bf16 for instance_norm_grad_grad op.
qizhaoaoe Apr 1, 2023
f62fd45
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
qizhaoaoe Apr 1, 2023
a48f6a0
add bf16 dtype check.
qizhaoaoe Apr 1, 2023
9f5a3a9
fix conflicts.
qizhaoaoe Apr 3, 2023
db1703d
fix cpu support for fp32 op and fix type in instance_norm_grad_kernel.
qizhaoaoe Apr 3, 2023
0e25977
fix type in instance_norm_kernel.
qizhaoaoe Apr 3, 2023
40ccc84
fix bf16 outputs in unittests and refine codes.
qizhaoaoe Apr 3, 2023
89947c1
fix dx computation.
qizhaoaoe Apr 4, 2023
fdb4f4a
delete unused params and header includes.
qizhaoaoe Apr 4, 2023
248e9c3
add fp16/bf16 for static graph.
qizhaoaoe Apr 4, 2023
b012cd6
fix device condition for instance_norm op.
qizhaoaoe Apr 4, 2023
6d9dd8d
fix instance_norm_grad_grad and bf16 op tests.
qizhaoaoe Apr 4, 2023
1a674b5
fix op_test so that bf16 grads can be compared with fp32.
qizhaoaoe Apr 6, 2023
ae01635
Merge branch 'develop' into instance_norm_amp
qizhaoaoe Apr 6, 2023
a4d9453
remove updates.
qizhaoaoe Apr 7, 2023
c62e9b6
add self-defined grad.
qizhaoaoe Apr 7, 2023
c497cac
Merge remote-tracking branch 'upstream/develop' into instance_norm_amp
qizhaoaoe Apr 7, 2023
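Note on the commit "fix /= operator which does not support bf16" above: compound assignment can be missing on a low-precision wrapper type even when plain arithmetic works through conversion to float. A standalone C++ illustration of the pitfall and the usual rewrite (MiniBF16 is a hypothetical stand-in for illustration, not Paddle's phi::dtype::bfloat16):

#include <cstring>

// MiniBF16: a toy bfloat16 that converts to/from float but defines no
// compound assignment operators, mimicking the failure mode.
struct MiniBF16 {
  unsigned short bits;
  explicit MiniBF16(float f) {
    unsigned int u;
    std::memcpy(&u, &f, sizeof(u));
    bits = static_cast<unsigned short>(u >> 16);  // keep sign/exponent/top mantissa bits
  }
  operator float() const {
    unsigned int u = static_cast<unsigned int>(bits) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
  }
};

int main() {
  MiniBF16 v(3.0f);
  float d = 2.0f;
  // v /= d;  // would not compile: MiniBF16 has no operator/=
  v = MiniBF16(static_cast<float>(v) / d);  // compute in float, then narrow once
  return static_cast<float>(v) == 1.5f ? 0 : 1;
}

Computing in float and narrowing once at the end is also numerically preferable to repeated low-precision rounding.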
fix instance_norm_grad_kernel and unittest threshold.
qizhaoaoe committed Apr 1, 2023
commit 0006187fe7f91687fe2350a83dcde2b5a1e71b28
87 changes: 49 additions & 38 deletions paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu
@@ -30,21 +30,19 @@
namespace phi {
template <typename T, typename AccT, int BlockDim>
static __global__ void GradComputeDX(const T *dy,
- const BatchNormParamType<T> *scale,
- const BatchNormParamType<T> *mean,
+ const BatchNormParamType<AccT> *scale,
+ const BatchNormParamType<AccT> *mean,
const T *x,
- const BatchNormParamType<T> *variance,
+ const BatchNormParamType<AccT> *variance,
Reviewer (Contributor) commented: Take a look at the definition of BatchNormParamType. There is no need to use AccT here; using T directly is enough, since BatchNormParamType<T> with T == fp16 or bf16 is already float.

Author (qizhaoaoe) replied: fix it.
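For context on this thread, a minimal sketch of the trait being discussed (an illustrative paraphrase assuming the usual Paddle layout, not the exact source; the real definition lives in Paddle's cuDNN data-type helpers):

namespace phi { namespace dtype { struct float16; struct bfloat16; } }  // fwd decls only

// Accumulation/parameter type for the norm kernels: low-precision input
// dtypes map to float, full-precision dtypes map to themselves.
template <typename T>
struct CudnnDataTypeSketch { using BatchNormParamType = T; };

template <>
struct CudnnDataTypeSketch<phi::dtype::float16> { using BatchNormParamType = float; };

template <>
struct CudnnDataTypeSketch<phi::dtype::bfloat16> { using BatchNormParamType = float; };

template <typename T>
using BatchNormParamType = typename CudnnDataTypeSketch<T>::BatchNormParamType;

// Hence BatchNormParamType<T> is already float when T is fp16/bf16, which is
// the reviewer's point: spelling it BatchNormParamType<AccT> is redundant
// (though harmless, since AccT is float there as well).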

const int C,
const int sample_size,
T *dx) {
int beg_idx = blockIdx.x * sample_size + threadIdx.x;
int end_idx = (blockIdx.x + 1) * sample_size;
int ncid = blockIdx.x;
int c = ncid % C;
- BatchNormParamType<AccT> mean_val =
-     static_cast<BatchNormParamType<AccT>>(mean[ncid]);
- BatchNormParamType<AccT> inv_var_val =
-     static_cast<BatchNormParamType<AccT>>(variance[ncid]);
+ BatchNormParamType<AccT> mean_val = mean[ncid];
+ BatchNormParamType<AccT> inv_var_val = variance[ncid];
typedef cub::BlockReduce<BatchNormParamType<AccT>, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage dy_storage;
__shared__ typename BlockReduce::TempStorage dy_x_sub_mean_storage;
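For orientation, GradComputeDX implements the standard instance-norm input gradient, one CUDA block per (n, c) slice, with the block reductions computing the two means below (this is the textbook formula, stated here for reference rather than quoted from the patch):

$$\hat{x} = (x - \mu)\,\sigma^{-1}, \qquad dx = \gamma\,\sigma^{-1}\left(dy - \overline{dy} - \hat{x}\cdot\overline{dy\,\hat{x}}\right),$$

where $\overline{\,\cdot\,}$ averages over the sample_size spatial positions of the slice, $\mu$ is the saved mean, $\sigma^{-1}$ the saved inverse standard deviation, and $\gamma$ the scale. All of these quantities are held in BatchNormParamType (float) even when x and dy are fp16/bf16, which is what the AccT changes above express.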
@@ -323,8 +321,8 @@ void InstanceNormGradKernel(const Context &dev_ctx,

dev_ctx.template Alloc<T>(d_x);
if (d_scale && d_bias) {
- dev_ctx.template Alloc<T>(d_scale);
- dev_ctx.template Alloc<T>(d_bias);
+ dev_ctx.template Alloc<AccT>(d_scale);
+ dev_ctx.template Alloc<AccT>(d_bias);
}
if (scale_ptr) {
PADDLE_ENFORCE_EQ(
@@ -349,7 +347,7 @@ void InstanceNormGradKernel(const Context &dev_ctx,
scale_ptr->dims()));
}

- phi::funcs::SetConstant<GPUContext, T> set_constant;
+ phi::funcs::SetConstant<GPUContext, AccT> set_constant;

const int n = x.numel();
const int block = 512;
@@ -360,35 +358,36 @@ void InstanceNormGradKernel(const Context &dev_ctx,

DenseTensor scale_tmp;
scale_tmp.Resize({NxC});
- dev_ctx.template Alloc<T>(&scale_tmp);
+ dev_ctx.template Alloc<AccT>(&scale_tmp);

DenseTensor d_scale_tmp;
d_scale_tmp.Resize({NxC});
- dev_ctx.template Alloc<T>(&d_scale_tmp);
+ dev_ctx.template Alloc<AccT>(&d_scale_tmp);

DenseTensor d_bias_tmp;
d_bias_tmp.Resize({NxC});
- dev_ctx.template Alloc<T>(&d_bias_tmp);
-
+ dev_ctx.template Alloc<AccT>(&d_bias_tmp);
+ VLOG(0) << "break";
if (scale_ptr) {
- repeat_param<T><<<grid, block, 0, dev_ctx.stream()>>>(
-     scale_ptr->data<T>(), scale_tmp.data<T>(), N, C);
+ repeat_param<AccT><<<grid, block, 0, dev_ctx.stream()>>>(
+     scale_ptr->data<AccT>(), scale_tmp.data<AccT>(), N, C);
} else {
- set_constant(dev_ctx, &scale_tmp, static_cast<T>(1));
+ set_constant(dev_ctx, &scale_tmp, static_cast<AccT>(1));
}

+ VLOG(0) << "break";
std::vector<int> dims;
std::vector<int> strides;
dims = {1, NxC, H, W, D};
strides = {NxC * H * W * D, H * W * D, W * D, D, 1};

if ((H * W * D) == 1) {
phi::Copy(dev_ctx, d_y, dev_ctx.GetPlace(), false, d_x);
- phi::funcs::SetConstant<GPUContext, BatchNormParamType<T>> functor;
- functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
- functor(dev_ctx, d_bias, static_cast<BatchNormParamType<T>>(0));
+ phi::funcs::SetConstant<GPUContext, BatchNormParamType<AccT>> functor;
+ functor(dev_ctx, d_scale, static_cast<BatchNormParamType<AccT>>(0));
+ functor(dev_ctx, d_bias, static_cast<BatchNormParamType<AccT>>(0));
return;
}
+ VLOG(0) << "break";

#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t data_desc_;
@@ ... @@
<< "CUDNN_BN_MIN_EPSILON instead.";
}
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
+ VLOG(0) << "break";

#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenSetTensorDescriptor(
@@ ... @@
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnDeriveBNTensorDescriptor(
in_param_desc_, data_desc_, CUDNN_BATCHNORM_SPATIAL));
#endif

+ VLOG(0) << "break";
const auto *saved_mean_data =
- saved_mean.template data<BatchNormParamType<T>>();
+ saved_mean.template data<BatchNormParamType<AccT>>();
const auto *saved_var_data =
- saved_variance.template data<BatchNormParamType<T>>();
+ saved_variance.template data<BatchNormParamType<AccT>>();
+ VLOG(0) << "break";

if (d_scale && d_bias) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenBatchNormalizationBackward(
@@ ... @@
data_desc_,
d_x->template data<T>(),
in_param_desc_,
- scale_tmp.template data<BatchNormParamType<T>>(),
- d_scale_tmp.template data<BatchNormParamType<T>>(),
- d_bias_tmp.template data<BatchNormParamType<T>>(),
+ scale_tmp.template data<BatchNormParamType<AccT>>(),
+ d_scale_tmp.template data<BatchNormParamType<AccT>>(),
+ d_bias_tmp.template data<BatchNormParamType<AccT>>(),
epsilon,
saved_mean_data,
saved_var_data));
@@ ... @@
data_desc_,
d_x->template data<T>(),
in_param_desc_,
- scale_tmp.template data<BatchNormParamType<T>>(),
- d_scale_tmp.template data<BatchNormParamType<T>>(),
- d_bias_tmp.template data<BatchNormParamType<T>>(),
+ scale_tmp.template data<BatchNormParamType<AccT>>(),
+ d_scale_tmp.template data<BatchNormParamType<AccT>>(),
+ d_bias_tmp.template data<BatchNormParamType<AccT>>(),
epsilon,
saved_mean_data,
saved_var_data));
#endif
} else {
if (d_x) {
+ VLOG(0) << "gradComputeDx";
GradComputeDX<T, AccT, block><<<NxC, block, 0, dev_ctx.stream()>>>(
d_y.data<T>(),
- scale_tmp.data<BatchNormParamType<T>>(),
+ scale_tmp.data<BatchNormParamType<AccT>>(),
saved_mean_data,
x.data<T>(),
saved_var_data,
@@ ... @@
d_x->data<T>());
}
}

+ VLOG(0) << "add d_scale and d_bias";
if (d_scale && d_bias) {
- add_param<T, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
-     d_scale_tmp.data<T>(), d_scale->data<T>(), N, C);
- add_param<T, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
-     d_bias_tmp.data<T>(), d_bias->data<T>(), N, C);
+ add_param<AccT, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
+     d_scale_tmp.data<AccT>(), d_scale->data<AccT>(), N, C);
+ add_param<AccT, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
+     d_bias_tmp.data<AccT>(), d_bias->data<AccT>(), N, C);
}
+ VLOG(0) << "after add";

#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
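The repeat_param/add_param pair used above bridges per-channel parameters [C] and per-instance buffers [N*C]. A minimal sketch of what the two kernels do (simplified grid-stride versions; the real add_param uses a block reduction):

// Tile a per-channel parameter [C] into per-(n, c) form [N*C].
template <typename T>
__global__ void RepeatParamSketch(const T *src, T *dst, int N, int C) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N * C;
       i += gridDim.x * blockDim.x) {
    dst[i] = src[i % C];
  }
}

// Fold per-(n, c) gradients [N*C] back to per-channel form [C],
// summing over the batch dimension n.
// Launch e.g.: AddParamSketch<float><<<grid, block>>>(d_src, d_dst, N, C);
template <typename T>
__global__ void AddParamSketch(const T *src, T *dst, int N, int C) {
  for (int c = blockIdx.x * blockDim.x + threadIdx.x; c < C;
       c += gridDim.x * blockDim.x) {
    T sum = static_cast<T>(0);
    for (int n = 0; n < N; ++n) {
      sum += src[n * C + c];
    }
    dst[c] = sum;
  }
}

With the AccT changes, both kernels now run on float buffers for fp16/bf16 inputs, so the sums accumulate at full precision.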
@@ -580,6 +584,7 @@ void InstanceNormDoubleGradKernel(const Context &dev_ctx,
epsilon,
dx_data);
}
+ VLOG(0) << "double kernel";
if (dscale) {
DenseTensor dscale_tmp;
dscale_tmp.Resize({NxC});
@@ -618,13 +623,19 @@ void InstanceNormDoubleGradKernel(const Context &dev_ctx,
epsilon,
ddy_data);
}
+ VLOG(0) << "double finished";
}
} // namespace phi

#ifdef PADDLE_WITH_HIP
// MIOPEN do not support double
- PD_REGISTER_KERNEL(
-     instance_norm_grad, GPU, ALL_LAYOUT, phi::InstanceNormGradKernel, float) {}
+ PD_REGISTER_KERNEL(instance_norm_grad,
+                    GPU,
+                    ALL_LAYOUT,
+                    phi::InstanceNormGradKernel,
+                    float,
+                    phi::dtype::float16,
+                    phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(instance_norm_double_grad,
GPU,
ALL_LAYOUT,
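Why the d_scale/d_bias buffers moved to AccT = float: fp16 has a 10-bit mantissa, so the spacing between representable values at magnitude $2^{11}$ is already $2^{11-10} = 2$, and

$$\mathrm{fp16}(2048 + 1) = 2048.$$

A gradient reduction that sums many small per-position terms (here over spatial positions and, in add_param, over the batch) would therefore silently drop contributions once the running sum grows; accumulating in float sidesteps this.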
47 changes: 26 additions & 21 deletions paddle/phi/kernels/gpu/instance_norm_kernel.cu
@@ -131,45 +131,45 @@ void InstanceNormKernel(const Context &dev_ctx,
const int max_blocks = std::max(max_threads / block, 1);
const int grid = std::min((NxC + block - 1) / block, max_blocks);

- phi::funcs::SetConstant<GPUContext, T> set_constant;
+ phi::funcs::SetConstant<GPUContext, AccT> set_constant;
if (scale_ptr) {
repeat_param<AccT><<<grid, block, 0, dev_ctx.stream()>>>(
scale_ptr->data<AccT>(), scale_tmp.data<AccT>(), N, C);
} else {
- set_constant(dev_ctx, &scale_tmp, static_cast<T>(1));
+ set_constant(dev_ctx, &scale_tmp, static_cast<AccT>(1));
}
if (bias_ptr) {
repeat_param<AccT><<<grid, block, 0, dev_ctx.stream()>>>(
bias_ptr->data<AccT>(), bias_tmp.data<AccT>(), N, C);
} else {
- set_constant(dev_ctx, &bias_tmp, static_cast<T>(0));
+ set_constant(dev_ctx, &bias_tmp, static_cast<AccT>(0));
}

auto handle = dev_ctx.cudnn_handle();

DenseTensor saved_mean_tmp, saved_variance_tmp;
- phi::funcs::SetConstant<GPUContext, BatchNormParamType<T>> functor;
+ phi::funcs::SetConstant<GPUContext, BatchNormParamType<AccT>> functor;

if (saved_mean) {
- dev_ctx.template Alloc<BatchNormParamType<T>>(saved_mean);
- functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
+ dev_ctx.template Alloc<BatchNormParamType<AccT>>(saved_mean);
+ functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<AccT>>(0));
} else {
- saved_mean_tmp = phi::Full<BatchNormParamType<T>>(
-     dev_ctx, {NxC}, static_cast<BatchNormParamType<T>>(0));
+ saved_mean_tmp = phi::Full<BatchNormParamType<AccT>>(
+     dev_ctx, {NxC}, static_cast<BatchNormParamType<AccT>>(0));
}
if (saved_variance) {
- dev_ctx.template Alloc<BatchNormParamType<T>>(saved_variance);
- functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
+ dev_ctx.template Alloc<BatchNormParamType<AccT>>(saved_variance);
+ functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<AccT>>(0));
} else {
- saved_variance_tmp = phi::Full<BatchNormParamType<T>>(
-     dev_ctx, {NxC}, static_cast<BatchNormParamType<T>>(0));
+ saved_variance_tmp = phi::Full<BatchNormParamType<AccT>>(
+     dev_ctx, {NxC}, static_cast<BatchNormParamType<AccT>>(0));
}

auto *saved_mean_data = saved_mean
-                               ? saved_mean->data<BatchNormParamType<T>>()
-                               : saved_mean_tmp.data<BatchNormParamType<T>>();
+                               ? saved_mean->data<BatchNormParamType<AccT>>()
+                               : saved_mean_tmp.data<BatchNormParamType<AccT>>();
auto *saved_variance_data =
-     saved_variance ? saved_variance->data<BatchNormParamType<T>>()
-                    : saved_variance_tmp.data<BatchNormParamType<T>>();
+     saved_variance ? saved_variance->data<BatchNormParamType<AccT>>()
+                    : saved_variance_tmp.data<BatchNormParamType<AccT>>();

#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
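The saved statistics allocated above feed both the cuDNN/MIOpen call and the backward kernels. For reference, the forward transform per (n, c) slice is the usual one (stated for orientation, not quoted from the patch):

$$\mu = \tfrac{1}{HWD}\textstyle\sum x, \qquad \sigma^2 = \tfrac{1}{HWD}\textstyle\sum (x-\mu)^2, \qquad y = \gamma\,\frac{x-\mu}{\sqrt{\sigma^2+\varepsilon}} + \beta,$$

with saved_mean holding $\mu$ and saved_variance holding the inverse standard deviation $1/\sqrt{\sigma^2+\varepsilon}$ (as cuDNN's batch-norm training API saves it), both in float regardless of the input dtype.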
@@ -186,9 +186,9 @@
static_cast<void *>(y->template data<T>()),
in_param_desc_,
const_cast<void *>(static_cast<const void *>(
- scale_tmp.template data<BatchNormParamType<T>>())),
+ scale_tmp.template data<BatchNormParamType<AccT>>())),
const_cast<void *>(static_cast<const void *>(
- bias_tmp.template data<BatchNormParamType<T>>())),
+ bias_tmp.template data<BatchNormParamType<AccT>>())),
0,
nullptr,
nullptr,
@@ -232,8 +232,13 @@

#ifdef PADDLE_WITH_HIP
// MIOPEN do not support double
- PD_REGISTER_KERNEL(
-     instance_norm, GPU, ALL_LAYOUT, phi::InstanceNormKernel, float) {}
+ PD_REGISTER_KERNEL(instance_norm,
+                    GPU,
+                    ALL_LAYOUT,
+                    phi::InstanceNormKernel,
+                    float,
+                    phi::dtype::float16,
+                    phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(instance_norm,
GPU,
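When low-precision dtypes are added to a kernel registration, Paddle typically also pins float-only outputs in the registration body. A sketch of that pattern for the forward op (the output indices for saved_mean/saved_variance are an assumption here, shown for illustration):

PD_REGISTER_KERNEL(instance_norm,
                   GPU,
                   ALL_LAYOUT,
                   phi::InstanceNormKernel,
                   float,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {
  if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
      kernel_key.dtype() == phi::DataType::BFLOAT16) {
    // saved_mean / saved_variance stay float even for fp16/bf16 inputs.
    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
  }
}

The grad registration would get the analogous treatment for d_scale/d_bias.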
(Python unit test file; file name not shown in this view)
@@ -167,7 +167,9 @@ def test_check_output(self):
self.check_output_with_place(place, atol=self.atol)

def test_check_grad(self):
- self.check_grad(
+ place = core.CUDAPlace(0)
+ self.check_grad_with_place(
+     place,
['X', 'Scale', 'Bias'],
'Y',
max_relative_error=self.max_relative_error,
@@ -201,7 +203,7 @@ def init_dtype(self):

def set_err_thre(self):
self.atol = 0.03125
- self.max_relative_error = 5e-3
+ self.max_relative_error = 8e-3
Reviewer (Contributor) commented: Does this test fail even with the default threshold?

Author (qizhaoaoe) replied: Yes.
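For reference, the bound being relaxed here is the OpTest relative-error criterion, roughly (the exact reduction op_test applies is an assumption):

$$\max_i \frac{\lvert g_i^{\text{low}} - g_i^{\text{ref}}\rvert}{\max\big(\lvert g_i^{\text{ref}}\rvert,\ \delta\big)} \le \texttt{max\_relative\_error},$$

where $g^{\text{low}}$ are the low-precision kernel's gradients, $g^{\text{ref}}$ the fp32/numeric reference (per the op_test fix commit above), and $\delta$ a small guard against division by zero; hence the threshold moves from 5e-3 to 8e-3.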



@unittest.skipIf(