Support pure fp16 training for AMP API. #29544
Conversation
Thanks for your contribution!
Xreki
left a comment
LGTM. Great work~
zhiqiu
left a comment
LGTM for unused_var_check.cc
swtkiwi
left a comment
LGTM
Please follow up later to complete the example code in the documentation and add the Chinese docs~~
```cpp
    adam, ops::AdamOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::AdamOpKernel<paddle::platform::CPUDeviceContext, double>);

REGISTER_OP_VERSION(adam)
```
adam is a training op, so there is actually no need to set an op version for it; adding one has no effect either.
OK, thanks for the reminder.
Support pure fp16 training for AMP API. (#29544)

* add cast ops before and after unsupported fp16 ops.
* Keep partial net in FP32 pattern.
* Support check_finite_and_unscale and update_loss_scaling for FP16 calculation mode.
* Add fp16 support for adam op.
* add multi precision attr for adam.
* Fix the bug of test_multi_precision_fp16_train UT.
* Code format for CI.
* Fix the redefine error about MPTypeTrait on windows.
* fix bugs of the _create_accumulators func in Momentum.
* fix bug when inserting post cast op.
* Add the update_loss_scaling op in allow_set of UnusedVarCheck.
* Update for ci coverage.
* Add some doc for OptimizerWithMixedPrecision.
* Fix the code style.
* Improve the doc of `amp_init`.
* Change for fp16 testing if users have the infer program defined in separate way.

Remove tensor copy in the update_loss_scaling op. (#29426)

* remove tensor copy in the update_loss_scaling op
* not use thrust.
* fix some cuda memory access error.
PR types
New features
PR changes
Others
Describe
1. Background
In the previous AMP implementation, a black & white list is used to control float16 computation. However, this strategy has some shortcomings: the inserted `cast` ops may lead to some overhead, which can be 5% ~ 10%. So we develop the `pure fp16 training` strategy, which uses float16 kernels as much as possible.

2. API Function
As shown above, in order to integrate `pure fp16 training` into the `decorate` API, we add two new parameters.

2.1 Description of the `use_pure_fp16` parameter

When the parameter `use_pure_fp16` is set to `True`, float16 kernels are used as much as possible. Otherwise, the black & white list based strategy is adopted.

2.2 Description of the `use_fp16_guard` parameter and the `fp16_guard` API

The second new parameter, `use_fp16_guard`, controls which part of the model is computed in float16. When `use_fp16_guard` is set to `False`, all operators used in the user-defined model are transformed to float16, except those in `unsupported_fp16_list`. When `use_fp16_guard` is set to `True`, only the ops created inside the `fp16_guard` context manager are transformed to float16. By default, `use_fp16_guard` is set to `None`, which means its value is equal to `use_pure_fp16`.

2.3 Details about `custom_black_list`

What's more, if users don't want some op types to be transformed to float16, they can add them to `custom_black_list`. If `custom_black_list` is set, the ops in it are kept in float32 regardless of `use_fp16_guard`. A usage sketch covering these parameters is given below.
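The following is a minimal static-graph sketch, not taken from this PR, of how these parameters could be combined. The toy network, the contents of `custom_black_list`, and the `paddle.static.amp` aliases (`decorate`, `fp16_guard`, `CustomOpLists`) are assumptions made for illustration:

```python
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 784], dtype='float32')
    label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')

    # With use_fp16_guard=True, only the ops created inside fp16_guard
    # are transformed to float16.
    with paddle.static.amp.fp16_guard():
        hidden = paddle.static.nn.fc(x, size=512, activation='relu')
    logits = paddle.static.nn.fc(hidden, size=10)
    loss = paddle.nn.functional.cross_entropy(logits, label)

    optimizer = paddle.optimizer.Momentum(learning_rate=0.01,
                                          multi_precision=True)
    optimizer = paddle.static.amp.decorate(
        optimizer,
        # hypothetical custom black list: keep pool2d in float32
        amp_lists=paddle.static.amp.CustomOpLists(
            custom_black_list=['pool2d']),
        init_loss_scaling=128.0,
        use_dynamic_loss_scaling=True,
        use_pure_fp16=True,     # use fp16 kernels as much as possible
        use_fp16_guard=True)    # honor the fp16_guard context above
    optimizer.minimize(loss)
```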
2.4 Description of the `amp_init` API

When users choose pure fp16 training, they should use the `amp_init` API to initialize the float16 parameters, as shown below. The parameters defined in API 3) are described below:

- `place` is used to initialize fp16 parameters with fp32 values.
- `scope` is used to find the fp32 parameters.
- `use_fp16_test` indicates whether to use fp16 testing.

Previously, the `black & white list based strategy` only transformed the training program, not the testing program. But now, `pure fp16 training` also needs to transform the testing program, because there are no float32 parameters left in the training and testing process. So if users have used `pure fp16 training` and want to run the testing process, the testing program should be passed into `amp_init`.

The `use_fp16_test` parameter mainly controls whether the testing program is transformed to float16 in the `black & white list based AMP strategy`; it has no effect on `pure fp16 training`. In other words, if users choose `pure fp16 training` and pass `test_program` into the `amp_init` API, `test_program` will be transformed to float16 regardless of the `use_fp16_test` value.
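A hedged sketch of the initialization step, continuing the snippet above. `test_prog` is assumed to be a testing program cloned from the training program before `minimize()` was called, and the `amp_init` argument names follow the description in this PR:

```python
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup_prog)

# Cast the fp32 parameters produced by startup_prog to fp16 in place.
# For pure fp16 training, test_prog (if given) is also rewritten to
# fp16, no matter what use_fp16_test is set to.
optimizer.amp_init(
    place,
    scope=paddle.static.global_scope(),
    test_program=test_prog,
    use_fp16_test=True)
```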
2.5 Low-level APIs

`cast_model_to_fp16` and `cast_parameters_to_fp16` are two low-level APIs. In most cases, users don't need them and can just use the `decorate` API.

cast_model_to_fp16

The parameter `program` is the program to be cast to fp16. The meaning of `amp_lists` and `use_fp16_guard` is the same as in `decorate`. Users may need this API in the following special case: they have used the `decorate` API to finish pure fp16 training, but they do not use `save_inference_model` and `load_inference_model` for inference; instead, they define a new inference program and load the pre-trained weights. In this case, they should cast the defined inference program to fp16 with the `cast_model_to_fp16` API. Meanwhile, they need to make sure that the values of `amp_lists` and `use_fp16_guard` are the same as in the previous pure fp16 training. If `use_fp16_guard` was set to `True`, they should also apply `fp16_guard` in the same places as in the previous pure fp16 training when building the inference program.
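A sketch of that special case, continuing the earlier snippet and under the same assumptions (the separately defined inference network, the reuse of `CustomOpLists`, and the availability of `cast_model_to_fp16` under `paddle.static.amp` are assumptions):

```python
infer_prog = paddle.static.Program()
infer_startup = paddle.static.Program()
with paddle.static.program_guard(infer_prog, infer_startup):
    x = paddle.static.data(name='x', shape=[None, 784], dtype='float32')
    # Rebuild the network as in training; with use_fp16_guard=True the
    # fp16_guard context must wrap the same ops as before.
    with paddle.static.amp.fp16_guard():
        hidden = paddle.static.nn.fc(x, size=512, activation='relu')
    logits = paddle.static.nn.fc(hidden, size=10)

# Cast the hand-built inference program with the same amp_lists and
# use_fp16_guard values that were used for pure fp16 training.
to_fp16_var_names = paddle.static.amp.cast_model_to_fp16(
    infer_prog,
    amp_lists=paddle.static.amp.CustomOpLists(custom_black_list=['pool2d']),
    use_fp16_guard=True)
```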
cast_parameters_to_fp16

The parameter `program` is the model to be processed. `place` is used to restore the fp16 weight tensors and `scope` is used to get the fp32 weight tensor values. Only the data types of the vars in `to_fp16_var_names` will be set to FP16. Usually, `to_fp16_var_names` is the return value of the `cast_model_to_fp16` API. By now, `cast_parameters_to_fp16` has no use case; we just set it aside for future special use.
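For completeness, a hedged sketch of pairing it with the return value from the snippet above; the argument order `(place, program, scope, to_fp16_var_names)` follows the parameter description in this PR and should be treated as an assumption:

```python
exe = paddle.static.Executor(place)
exe.run(infer_startup)
# ... load the pre-trained fp32 weights into the global scope here ...

# Cast only the parameters listed in to_fp16_var_names to fp16 so that
# they match the fp16 program produced by cast_model_to_fp16.
paddle.static.amp.cast_parameters_to_fp16(
    place,
    infer_prog,
    scope=paddle.static.global_scope(),
    to_fp16_var_names=to_fp16_var_names)
```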
3. Use case

As shown below, the left is the original fp32 computation graph, and the right is the computation graph after applying `pure fp16 training`.

4. Restriction

The `pure fp16 training` strategy requires the optimizer in use to register a float16 kernel. Until now, `Momentum`, `Adam` and `AdamW` support float16 computation. All three of them have the `multi_precision` parameter, which can to some extent avoid poor accuracy or slow convergence. In the future, more optimizers will support float16 computation.
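As a small hedged illustration of that restriction (the hyper-parameters are placeholders; per this PR's description, `Adam` and `AdamW` expose the same flag):

```python
import paddle

# Keep an fp32 master copy of every fp16 parameter inside the optimizer,
# which helps avoid poor accuracy or slow convergence in pure fp16 training.
optimizer = paddle.optimizer.Momentum(
    learning_rate=0.01,
    momentum=0.9,
    multi_precision=True)
```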