
Support bw invoke fw #50260

Merged
merged 7 commits into PaddlePaddle:develop on Feb 21, 2023

Conversation

heavyrain-lzy
Contributor

@heavyrain-lzy commented Feb 6, 2023

PR types

Others

PR changes

OPs

Describe

This PR addresses the following:
It supports invoking the forward op from a backward_op, auto-generates static-graph code for ops whose forward signature contains int_array or Scalar arguments, and updates the scale and sign operators accordingly.
Take the yaml configuration of scale as an example:

- op : scale
  args : (Tensor x, Scalar scale=1.0, float bias=0.0, bool bias_after_scale=true)
  output : Tensor(out)
  infer_meta :
    func : UnchangedInferMeta
    param : [x]
  kernel :
    func : scale {dense -> dense},
           scale_sr {selected_rows -> selected_rows}
    data_type : x
  inplace : (x -> out)
  backward : scale_grad

The scale_grad definition:

- backward_op : scale_grad
  forward : scale (Tensor x, Scalar scale, float bias, bool bias_after_scale) -> Tensor(out)
  args : (Tensor out_grad, Scalar scale=1.0)
  output : Tensor(x_grad)
  invoke : scale(out_grad, scale, 0.0f, true)

As shown, scale_grad invokes scale, and the scale operator takes a Scalar-typed argument.
The corresponding auto-generated code for scale_grad is:

template <typename T>
class ScaleGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("scale");

    grad_op->SetInput("X", this->OutputGrad("Out"));

    grad_op->SetOutput("Out", this->InputGrad("X"));

    if (this->HasInput("ScaleTensor")) {
      grad_op->SetInput("ScaleTensor", this->Input("ScaleTensor"));
    }

    grad_op->SetAttr("scale", this->GetAttr("scale"));
    grad_op->SetAttr("bias", 0.0f);
    grad_op->SetAttr("bias_after_scale", true);
  }
};

This matches the handwritten scale_grad implementation.
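The generation step can be illustrated with a short sketch (plain Python, not PaddlePaddle's actual generator; the function names and the two-way classification are assumptions for illustration): constant literals in the invoke clause become hard-coded attributes, like `grad_op->SetAttr("bias", 0.0f)` above, while identifiers are forwarded from the forward op's inputs and attributes.

```python
# Hypothetical sketch of how an `invoke` clause such as
# `scale(out_grad, scale, 0.0f, true)` could be split into forwarded
# arguments (emitted as SetInput/SetAttr taken from the forward op)
# and literal constants (emitted as fixed SetAttr values).
# Not Paddle's real code generator; names are illustrative.

def parse_invoke(invoke: str):
    """Split an invoke clause into the target op name and its argument list."""
    name, _, rest = invoke.partition("(")
    args = [a.strip() for a in rest.rstrip(")").split(",")]
    return name.strip(), args

def classify_arg(arg: str):
    """Literals (numbers, booleans) become hard-coded attributes;
    identifiers are forwarded from the forward op's inputs/attrs."""
    if arg in ("true", "false") or arg[0].isdigit() or arg[0] == "-":
        return ("literal", arg)
    return ("forwarded", arg)

name, args = parse_invoke("scale(out_grad, scale, 0.0f, true)")
assert name == "scale"
assert [kind for kind, _ in map(classify_arg, args)] == [
    "forwarded", "forwarded", "literal", "literal"
]
```

Under this classification, `out_grad` and `scale` map to the `SetInput`/`SetAttr` forwarding calls in the generated maker, while `0.0f` and `true` map to the hard-coded `SetAttr("bias", 0.0f)` and `SetAttr("bias_after_scale", true)`.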

@paddle-bot

paddle-bot bot commented Feb 6, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

.pre-commit-config.yaml (outdated review thread, resolved)
Comment on lines 1 to 12
- op : scale
  args : (Tensor x, Scalar scale=1.0, float bias=0.0, bool bias_after_scale=true)
  output : Tensor(out)
  infer_meta :
    func : UnchangedInferMeta
    param : [x]
  kernel :
    func : scale {dense -> dense},
           scale_sr {selected_rows -> selected_rows}
    data_type : x
  inplace : (x -> out)
  backward : scale_grad
Contributor

Is there any difference from the yaml configuration under the dynamic graph?

Contributor Author

The invoke clause of scale_grad differs slightly. Original: `invoke : scale(out_grad, scale, 0.0f, bias_after_scale)`; now: `invoke : scale(out_grad, scale, 0.0f, true)`.
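For context, hard-coding `true` here is safe because the derivative of scale with respect to x is just `scale`, whether the bias is applied before or after scaling. A minimal numeric check in plain Python (illustrative only, not Paddle code):

```python
# Check that d/dx of scale(x) is `scale` for both settings of
# bias_after_scale, so the backward can always use
# scale(out_grad, scale, 0.0, True).

def scale_fwd(x, scale, bias, bias_after_scale):
    # Mirrors the scale op's two bias placements.
    return scale * x + bias if bias_after_scale else scale * (x + bias)

def scale_grad(out_grad, scale):
    # invoke : scale(out_grad, scale, 0.0f, true)
    return scale_fwd(out_grad, scale, 0.0, True)

# Finite-difference check of the forward derivative for both flags.
eps = 1e-6
x, s, b = 2.0, 3.0, 1.5
for baf in (True, False):
    num = (scale_fwd(x + eps, s, b, baf) - scale_fwd(x - eps, s, b, baf)) / (2 * eps)
    assert abs(num - scale_grad(1.0, s)) < 1e-4
```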

Contributor Author

The yaml files for scale have been updated so the dynamic-graph and static-graph configurations are consistent.

@@ -0,0 +1 @@
# This file is to support static ops that differ from their dynamic counterparts.
Contributor

Why is there no op in this new file? Is it a reserved placeholder?

Contributor Author

It is a reserved placeholder file; a follow-up PR will use it right away.

scale_op
generated_static_op
Contributor

Why use generated_static_op here while generated_op is used elsewhere?

Contributor Author

Changed to generated_op.

Contributor

@jiahy0825 left a comment

Two small questions.

@@ -29,7 +29,7 @@ cc_test_old(
prim_utils
operator
elementwise_mul_op
scale_op
generated_static_op
Contributor

Is this dependency unneeded here, or should it also be generated_op?

Contributor Author

Yes, generated_static_op has been removed.

heter_listen_and_serv_op
${RPC_DEPS}
${DISTRIBUTE_DEPS}
eigen_function)

#set_source_files_properties(heter_cloud_comm_cpu_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#cc_test(heter_cloud_comm_cpu_test SRCS heter_cloud_comm_cpu_test.cc DEPS executor scope proto_desc scale_op heter_listen_and_serv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function)
#cc_test(heter_cloud_comm_cpu_test SRCS heter_cloud_comm_cpu_test.cc DEPS executor scope proto_desc generated_static_op heter_listen_and_serv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function)
Contributor

By the same logic, this should also be generated_op.

Contributor Author

This comment will be corrected in the next PR.

Contributor

@jiahy0825 left a comment

LGTM

Contributor

@XiaoguangHu01 left a comment

LGTM

@zyfncg zyfncg merged commit d884573 into PaddlePaddle:develop Feb 21, 2023
5 participants