
【AMP OP&Test】instance_norm fp16 and bf16 support. #52241

Merged: 24 commits, Apr 10, 2023
Commits (24)
da9fbbd
add fp16 and bf16 support for instance_norm
qizhaoaoe Mar 28, 2023
2b7111c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
qizhaoaoe Mar 28, 2023
4b414d8
fix /= operator, which does not support bf16
qizhaoaoe Mar 29, 2023
e491361
fix instance_norm_grad kernel and unittests.
qizhaoaoe Mar 30, 2023
134fbcc
fix fp32 unittests.
qizhaoaoe Mar 30, 2023
ecd7ae1
fix instance_norm_kernel and unittests.
qizhaoaoe Mar 31, 2023
0006187
fix instance_norm_grad_kernel and unittest threshold.
qizhaoaoe Apr 1, 2023
e0af6d2
add fp16/bf16 for instance_norm_grad_grad op.
qizhaoaoe Apr 1, 2023
f62fd45
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
qizhaoaoe Apr 1, 2023
a48f6a0
add bf16 dtype check.
qizhaoaoe Apr 1, 2023
9f5a3a9
fix conflicts.
qizhaoaoe Apr 3, 2023
db1703d
fix cpu support for fp32 op and fix type in instance_norm_grad_kernel.
qizhaoaoe Apr 3, 2023
0e25977
fix type in instance_norm_kernel.
qizhaoaoe Apr 3, 2023
40ccc84
fix bf16 outputs in unittests and refine codes.
qizhaoaoe Apr 3, 2023
89947c1
fix dx computation.
qizhaoaoe Apr 4, 2023
fdb4f4a
delete unused params and header includes.
qizhaoaoe Apr 4, 2023
248e9c3
add fp16/bf16 for static graph.
qizhaoaoe Apr 4, 2023
b012cd6
fix device condition for instance_norm op.
qizhaoaoe Apr 4, 2023
6d9dd8d
fix instance_norm_grad_grad and bf16 op tests.
qizhaoaoe Apr 4, 2023
1a674b5
fix op_test so that bf16 grads can be compared with fp32.
qizhaoaoe Apr 6, 2023
ae01635
Merge branch 'develop' into instance_norm_amp
qizhaoaoe Apr 6, 2023
a4d9453
remove updates.
qizhaoaoe Apr 7, 2023
c62e9b6
add self-defined grad.
qizhaoaoe Apr 7, 2023
c497cac
Merge remote-tracking branch 'upstream/develop' into instance_norm_amp
qizhaoaoe Apr 7, 2023
Changes from 1 commit: b012cd6b23d2523a9f82b05fbf603f60b9947c46
fix device condition for instance_norm op.
qizhaoaoe committed Apr 4, 2023
21 changes: 17 additions & 4 deletions paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu
@@ -636,13 +636,28 @@ PD_REGISTER_KERNEL(instance_norm_grad,
                    ALL_LAYOUT,
                    phi::InstanceNormGradKernel,
                    float,
                    phi::dtype::float16) {}
 PD_REGISTER_KERNEL(instance_norm_double_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::InstanceNormDoubleGradKernel,
                    float,
                    phi::dtype::float16) {}
+#elif CUDNN_VERSION_MIN(8, 1, 0)
+PD_REGISTER_KERNEL(instance_norm_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::InstanceNormGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+PD_REGISTER_KERNEL(instance_norm_double_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::InstanceNormDoubleGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 #else
@@ -652,14 +667,12 @@ PD_REGISTER_KERNEL(instance_norm_grad,
                    phi::InstanceNormGradKernel,
                    float,
                    double,
-                   phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::float16) {}
 PD_REGISTER_KERNEL(instance_norm_double_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::InstanceNormDoubleGradKernel,
                    float,
                    double,
-                   phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::float16) {}
 #endif
11 changes: 9 additions & 2 deletions paddle/phi/kernels/gpu/instance_norm_kernel.cu
@@ -231,6 +231,14 @@ PD_REGISTER_KERNEL(instance_norm,
                    ALL_LAYOUT,
                    phi::InstanceNormKernel,
                    float,
                    phi::dtype::float16) {}
+#elif CUDNN_VERSION_MIN(8, 1, 0)
+PD_REGISTER_KERNEL(instance_norm,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::InstanceNormKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 #else
@@ -240,6 +248,5 @@ PD_REGISTER_KERNEL(instance_norm,
                    phi::InstanceNormKernel,
                    float,
                    double,
-                   phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::float16) {}
 #endif
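One commit in the list above reworks op_test so that bf16 gradients can be checked against an fp32 reference. A hedged sketch of the underlying idea (the helper name round_to_bf16 is hypothetical, not Paddle's actual test utility): bf16 keeps the sign bit, the full 8-bit exponent, and the top 7 mantissa bits of an fp32 value, so an fp32 reference can be rounded to bf16 precision on its raw bit pattern before comparing with a tolerance.

```python
import numpy as np

def round_to_bf16(x):
    """Round fp32 values to the nearest bf16-representable value.

    Works on the raw uint32 bit pattern: round-to-nearest-even on the
    discarded low 16 bits, then zero them out.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    rounded = (bits + np.uint32(0x7FFF) + ((bits >> np.uint32(16)) & np.uint32(1))) \
        & np.uint32(0xFFFF0000)
    return rounded.view(np.float32)

# An fp32 reference gradient can then be compared against a bf16 kernel
# output with an appropriate tolerance, e.g.:
ref = np.array([1.0, 1.1, -0.333], dtype=np.float32)
ref_bf16 = round_to_bf16(ref)
```

Rounding the reference first means the comparison threshold only has to absorb kernel-order-of-operations noise, not the full fp32-to-bf16 quantization error.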