-
Couldn't load subscription status.
- Fork 5.9k
Add Inplace strategy (Output reuse Input Varbase) in dygraph #30103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Inplace strategy (Output reuse Input Varbase) in dygraph #30103
Conversation
… add_view_strategy
… add_view_strategy
|
Thanks for your contribution! |
… add_inplace_strategy
da98607 to
32752d4
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
| !shared_node || !grad_node->InplaceGradNameMap().empty(), true, | ||
| platform::errors::PermissionDenied( | ||
| "Cannot set gradient op twice unless using Inplace Strategy.")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better not write combined condition in PADDLE_ENFORCE_EQ , You can refine it in next PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will refine it in next PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…e) in dygraph (PaddlePaddle#30103) * add view strategy on squeeze,unsqueeze,reshape,flatten * add squeeze unittest * add unittests * use View strategy as name rather than Reuse Allacation * fix view api doc * fix format * use core.ops when input of reshape2 is Tensor * fix test_cross_entropy_loss error because of reshape2 * fix test_cross_entropy_loss error because of reshape2 * add inplace strategy * add elementwise_add sub * let backward op not use inplace * grad op do not use inplace * fix memory increase error and add leaf error message * delete selected_rows * change op_function * little change * solve HandleViewBetweenInputAndOutput * add unittest and leaf error message * merge view error * optimize op_function_generator format and support sum inplace op * fix format of basic_engine * fix format for framework * little change of variable wrapper * add reshape, squeeze, unsqueeze, scatter api * add relu elu tanh softmax inplace api * fix test_squeeze_op unittest * fix test_relu_op unittest * fix comment problems * delete sample code of inplace api * add reference of grad_pending_nodes in basic_engine * fix unittest name * add inplace apis into wlist * fix error message * add PADDLE_ENFORCE for set grad op twice * fix head file error
…e) in dygraph (#30103) (#30496) * add view strategy on squeeze,unsqueeze,reshape,flatten * add squeeze unittest * add unittests * use View strategy as name rather than Reuse Allacation * fix view api doc * fix format * use core.ops when input of reshape2 is Tensor * fix test_cross_entropy_loss error because of reshape2 * fix test_cross_entropy_loss error because of reshape2 * add inplace strategy * add elementwise_add sub * let backward op not use inplace * grad op do not use inplace * fix memory increase error and add leaf error message * delete selected_rows * change op_function * little change * solve HandleViewBetweenInputAndOutput * add unittest and leaf error message * merge view error * optimize op_function_generator format and support sum inplace op * fix format of basic_engine * fix format for framework * little change of variable wrapper * add reshape, squeeze, unsqueeze, scatter api * add relu elu tanh softmax inplace api * fix test_squeeze_op unittest * fix test_relu_op unittest * fix comment problems * delete sample code of inplace api * add reference of grad_pending_nodes in basic_engine * fix unittest name * add inplace apis into wlist * fix error message * add PADDLE_ENFORCE for set grad op twice * fix head file error
PR types
New features
PR changes
APIs
Describe
动态图支持Inplace策略:不新建输出VarBase,直接将输入VarBase传递给输出。与静态图已添加Inplace策略的op对齐。为与该策略相关的OP新增python API接口。
支持的OP
与静态图已支持Inplace策略的OP对齐,一共有38个。在这个PR中,为
squeeze2、unsqueeze2、reshape2、scatterelu、relu、softmax、tanh8个OP添加了Inplace策略。其它OP后续会添加Inplace支持。API接口
为需要使用Inplace策略的OP新增python API接口,API名称为现有API名称加上下划线_,用户可选择是否使用这类API组网。
对上述8个OP,对应添加python API接口:
squeeze_、unsqueeze_、reshape_、scatter_elu_、relu_、softmax_、tanh_。生成的op function
在op_function_generator层面添加动态图Inplace机制,以squeeze2 op为例,添加Inplace策略后,生成的op function为
imperative_squeeze2_。主要改变为:
stop_gradient=False的叶子节点,如果是,则会报错。inplace_version增加1。squeeze_()API使用示例Inplace报错
stop_gradient=False的叶子节点实现要点