
op helper function for generating grad var's name #3188


Closed
wants to merge 8 commits into from
14 changes: 6 additions & 8 deletions paddle/framework/backward.cc
@@ -59,19 +59,18 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// If all input gradients of forwarding operator do not need to calculate,
// just return an NOP. Not return null ptr because NOP does not take
// too much time for calculation, but it is useful for simplifying logic.
if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
no_grad_names)) {
if (AllInSet(forwardOp.inputs_, op_helpers::GenGradName(""), no_grad_names)) {
return NOP();
}

// All output gradients of forwarding operator do not need to calculate.
// Then all input gradients cannot be computed at all, and we put them into
// `no_grad_names` set. Return an NOP.
if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(),
if (AllInSet(forwardOp.outputs_, op_helpers::GenGradName(""),
no_grad_names)) {
for (auto& name : forwardOp.inputs_) {
// Mark all input is not need
no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
no_grad_names.insert(op_helpers::GenGradName(name));
}
return NOP();
}
@@ -135,7 +134,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
for (std::string& grad_input : grad_op->inputs_) {
if (no_grad_names.count(grad_input)) {
std::string prefix = grad_input.substr(
0, grad_input.size() - OperatorBase::GRAD_VAR_SUFFIX().size());
0, grad_input.size() - op_helpers::GenGradName("").size());
grad_input = prefix + OperatorBase::ZERO_VAR_SUFFIX();

// If part of input gradient of that operator is not calculated, fill
@@ -168,11 +167,10 @@ std::shared_ptr<OperatorBase> Backward(
std::unordered_set<std::string> no_grad_names;
no_grad_names.reserve(no_grad_vars.size());

no_grad_names.insert(OperatorBase::EMPTY_VAR_NAME() +
OperatorBase::GRAD_VAR_SUFFIX());
no_grad_names.insert(op_helpers::GenGradName(OperatorBase::EMPTY_VAR_NAME()));

for (auto& name : no_grad_vars) {
no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
no_grad_names.insert(op_helpers::GenGradName(name));
}
size_t uid = 0;
return BackwardRecursive(forwardOp, no_grad_names, uid);
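For readers following the hunks above, here is a rough, illustrative sketch (not part of this PR) of the suffix bookkeeping `backward.cc` performs. The `AllInSet` signature and the `GradToZero` helper below are assumptions inferred from the call sites, not the actual framework API:

```cpp
#include <string>
#include <unordered_set>
#include <vector>

// Assumed shape of the AllInSet helper used above: true iff every name in
// `names`, with `suffix` appended, is already present in `set`.
static bool AllInSet(const std::vector<std::string>& names,
                     const std::string& suffix,
                     const std::unordered_set<std::string>& set) {
  for (const auto& name : names) {
    if (set.count(name + suffix) == 0) return false;
  }
  return true;
}

// Hypothetical helper mirroring the later hunk: when a gradient variable is
// listed in `no_grad_names`, its "@GRAD" suffix is replaced with "@ZERO" so
// the operator receives a zero-filled placeholder, e.g. "X@GRAD" -> "X@ZERO".
static std::string GradToZero(const std::string& grad_name) {
  static const std::string kGradSuffix = "@GRAD";
  static const std::string kZeroSuffix = "@ZERO";
  return grad_name.substr(0, grad_name.size() - kGradSuffix.size()) +
         kZeroSuffix;
}
```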
37 changes: 15 additions & 22 deletions paddle/framework/backward_test.cc
@@ -143,6 +143,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
} // namespace paddle

namespace f = paddle::framework;
namespace h = paddle::framework::op_helpers;
using EnforceNotMet = paddle::platform::EnforceNotMet;
REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker);
REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::EmptyOp);
@@ -165,19 +166,18 @@ TEST(Backward, simple_op_grad) {
ASSERT_EQ(4UL, gop->inputs_.size());
ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(), gop->inputs_[0]);
ASSERT_EQ("rowwise_add_grad", gop->type_);
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]);
ASSERT_EQ("b" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[1]);
ASSERT_EQ(h::GenGradName("X"), gop->outputs_[0]);
ASSERT_EQ(h::GenGradName("b"), gop->outputs_[1]);

ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(),
gop->Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()));
ASSERT_EQ(h::GenGradName("X"), gop->Output(h::GenGradName("X")));
}

TEST(Backward, simple_op_not_need_grad) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
ASSERT_NE(fwd, nullptr);
auto gop = f::Backward(*fwd, {"X"});
ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(),
"X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
h::GenGradName("X")),
gop->outputs_.end());

auto no_input_gop = f::Backward(*fwd, {"X", "b"});
@@ -245,21 +245,18 @@ TEST(Backward, net_input_of_network_not_need_grad) {
all_output.erase(f::OperatorBase::EMPTY_VAR_NAME());

for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
ASSERT_NE(all_output.find(out + f::OperatorBase::GRAD_VAR_SUFFIX()),
all_output.end());
ASSERT_NE(all_output.find(h::GenGradName(out)), all_output.end());
}

// Not Generated X
ASSERT_EQ(all_output.find("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
all_output.end());
ASSERT_EQ(all_output.find(h::GenGradName("X")), all_output.end());

ASSERT_EQ(2UL, bwd_net->ops_.size());
ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get());
ASSERT_EQ(3UL, first_fc_grad->ops_.size());
ASSERT_EQ(
f::OperatorBase::EMPTY_VAR_NAME(),
first_fc_grad->ops_[2]->Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()));
ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(),
first_fc_grad->ops_[2]->Output(h::GenGradName("A")));
}

TEST(Backward, net_shared_weight) {
Expand Down Expand Up @@ -317,11 +314,9 @@ TEST(Backward, op_part_of_output_are_not_need) {
ASSERT_EQ("many_output_op_grad", d_many_out.type_);
ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size()); // I/O/OG
ASSERT_EQ("Z" + f::OperatorBase::ZERO_VAR_SUFFIX(),
d_many_out.Input("z" + f::OperatorBase::GRAD_VAR_SUFFIX()));
ASSERT_EQ("Y" + f::OperatorBase::GRAD_VAR_SUFFIX(),
d_many_out.Input("y" + f::OperatorBase::GRAD_VAR_SUFFIX()));
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(),
d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX()));
d_many_out.Input(h::GenGradName("z")));
ASSERT_EQ(h::GenGradName("Y"), d_many_out.Input(h::GenGradName("y")));
ASSERT_EQ(h::GenGradName("X"), d_many_out.Output(h::GenGradName("x")));
}

TEST(Backward, op_part_of_input_are_not_need) {
@@ -331,12 +326,10 @@ TEST(Backward, op_part_of_input_are_not_need) {
ASSERT_EQ(grad_mul.type_, "mul_grad");
ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
ASSERT_EQ(grad_mul.Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()),
ASSERT_EQ(grad_mul.Output(h::GenGradName("A")),
f::OperatorBase::EMPTY_VAR_NAME());
ASSERT_EQ(grad_mul.Output("B" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"b" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_mul.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"out" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_mul.Output(h::GenGradName("B")), h::GenGradName("b"));
ASSERT_EQ(grad_mul.Input(h::GenGradName("Out")), h::GenGradName("out"));
ASSERT_EQ(grad_mul.Input("A"), "a");
ASSERT_EQ(grad_mul.Input("B"), "b");
ASSERT_EQ(grad_mul.Input("Out"), "out");
4 changes: 2 additions & 2 deletions paddle/framework/grad_op_builder.cc
@@ -57,15 +57,15 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
for (const auto& arg : src_arg_list) {
std::string src_name = arg.name();
std::string dst_name =
is_grad ? src_name + OperatorBase::GRAD_VAR_SUFFIX() : src_name;
is_grad ? op_helpers::GenGradName(src_name) : src_name;
(*dst_op->in_out_idxs_)[dst_name] = idx++;
int src_arg_idx = src_op->in_out_idxs_->at(src_name);
int src_begin =
src_format == nullptr ? src_arg_idx : src_format->at(src_arg_idx);
int src_end = src_format == nullptr ? src_arg_idx + 1
: src_format->at(src_arg_idx + 1);
for (int i = src_begin; i < src_end; ++i) {
std::string s = is_grad ? src_inout[i] + OperatorBase::GRAD_VAR_SUFFIX()
std::string s = is_grad ? op_helpers::GenGradName(src_inout[i])
: arg.ignore_gradient()
? OperatorBase::EMPTY_VAR_NAME()
: src_inout[i];
62 changes: 28 additions & 34 deletions paddle/framework/grad_op_builder_test.cc
@@ -45,6 +45,7 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
} // namespace paddle

namespace f = paddle::framework;
namespace op_helpers = paddle::framework::op_helpers;

TEST(GradOpBuilder, AddTwo) {
std::shared_ptr<f::OperatorBase> add_op(
@@ -83,24 +84,21 @@ TEST(GradOpBuilder, MutiInOut) {
EXPECT_EQ(grad_test_op->Input("Out1"), "out1");
EXPECT_EQ(grad_test_op->Inputs("Out2_mult"),
std::vector<std::string>({"out2_1", "out2_2"}));
EXPECT_EQ(grad_test_op->Input("Out1" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"out1" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(
grad_test_op->Inputs("Out2_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()),
std::vector<std::string>(
{"out2_1" + f::OperatorBase::GRAD_VAR_SUFFIX(),
"out2_2" + f::OperatorBase::GRAD_VAR_SUFFIX()}));
EXPECT_EQ(grad_test_op->Input(op_helpers::GenGradName("Out1")),
op_helpers::GenGradName("out1"));
EXPECT_EQ(grad_test_op->Inputs(op_helpers::GenGradName("Out2_mult")),
std::vector<std::string>({op_helpers::GenGradName("out2_1"),
op_helpers::GenGradName("out2_2")}));

ASSERT_EQ(grad_test_op->outputs_.size(), 5UL);
EXPECT_EQ(grad_test_op->Output("In1" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"in1" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(
grad_test_op->Outputs("In2_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()),
std::vector<std::string>({"in2_1" + f::OperatorBase::GRAD_VAR_SUFFIX(),
"in2_2" + f::OperatorBase::GRAD_VAR_SUFFIX(),
"in2_3" + f::OperatorBase::GRAD_VAR_SUFFIX()}));
EXPECT_EQ(grad_test_op->Output("In3" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"in3" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(grad_test_op->Output(op_helpers::GenGradName("In1")),
op_helpers::GenGradName("in1"));
EXPECT_EQ(grad_test_op->Outputs(op_helpers::GenGradName("In2_mult")),
std::vector<std::string>({op_helpers::GenGradName("in2_1"),
op_helpers::GenGradName("in2_2"),
op_helpers::GenGradName("in2_3")}));
EXPECT_EQ(grad_test_op->Output(op_helpers::GenGradName("In3")),
op_helpers::GenGradName("in3"));
}

TEST(GradOpBuilder, IOIgnoredInGradient) {
@@ -122,24 +120,20 @@ TEST(GradOpBuilder, IOIgnoredInGradient) {
std::vector<std::string>({"in3_1", "in3_2"}));
EXPECT_EQ(grad_test_op->Inputs("Out1_mult"),
std::vector<std::string>({"out1_1", "out1_2"}));
EXPECT_EQ(grad_test_op->Input("Out2"), f::OperatorBase::EMPTY_VAR_NAME());
EXPECT_EQ(
grad_test_op->Inputs("Out1_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()),
std::vector<std::string>(
{"out1_1" + f::OperatorBase::GRAD_VAR_SUFFIX(),
"out1_2" + f::OperatorBase::GRAD_VAR_SUFFIX()}));
EXPECT_EQ(grad_test_op->Input("Out2" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"out2" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(grad_test_op->Input("Out2"), op_helpers::GenGradName(""));
EXPECT_EQ(grad_test_op->Inputs(op_helpers::GenGradName("Out1_mult")),
std::vector<std::string>({op_helpers::GenGradName("out1_1"),
op_helpers::GenGradName("out1_2")}));
EXPECT_EQ(grad_test_op->Input(op_helpers::GenGradName("Out2")),
op_helpers::GenGradName("out2"));

ASSERT_EQ(grad_test_op->outputs_.size(), 5UL);
EXPECT_EQ(grad_test_op->Output("In1" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"in1" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(
grad_test_op->Outputs("In2_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()),
std::vector<std::string>({"in2_1" + f::OperatorBase::GRAD_VAR_SUFFIX(),
"in2_2" + f::OperatorBase::GRAD_VAR_SUFFIX()}));
EXPECT_EQ(
grad_test_op->Outputs("In3_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()),
std::vector<std::string>({"in3_1" + f::OperatorBase::GRAD_VAR_SUFFIX(),
"in3_2" + f::OperatorBase::GRAD_VAR_SUFFIX()}));
EXPECT_EQ(grad_test_op->Output(op_helpers::GenGradName("In1")),
op_helpers::GenGradName("in1"));
EXPECT_EQ(grad_test_op->Outputs(op_helpers::GenGradName("In2_mult")),
std::vector<std::string>({op_helpers::GenGradName("in2_1"),
op_helpers::GenGradName("in2_2")}));
EXPECT_EQ(grad_test_op->Outputs(op_helpers::GenGradName("In3_mult")),
std::vector<std::string>({op_helpers::GenGradName("in3_1"),
op_helpers::GenGradName("in3_2")}));
}
37 changes: 37 additions & 0 deletions paddle/framework/op_helpers/op_helpers.h
@@ -0,0 +1,37 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>

namespace paddle {
namespace framework {
namespace op_helpers {

/*
* Generate the gradient variable's name for a forward variable.
*
* If a variable's name has a certain suffix, it means that the
* variable is the gradient of another variable.
* e.g. Variable "x@GRAD" is the gradient of variable "x".
*/
inline std::string GenGradName(const std::string& var) {
Collaborator:

Do we really need to create a namespace for just this one function? Are other functions likely to be added soon? If not, wouldn't putting it directly in the framework namespace be the "simplest" approach?

Contributor Author:

Yes, helper functions for checking op inputs/outputs should be added soon, to simplify some of the code in InferShape.

static const std::string suffix{"@GRAD"};
return var + suffix;
}

} // namespace op_helpers
} // namespace framework
} // namespace paddle
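As a quick illustration of the new helper, here is a minimal, standalone snippet (not part of the diff; the assertions reflect the "@GRAD" suffix defined in the header above):

```cpp
#include <cassert>
#include <iostream>

#include "paddle/framework/op_helpers/op_helpers.h"

namespace h = paddle::framework::op_helpers;

int main() {
  // "X@GRAD" names the gradient of the forward variable "X".
  assert(h::GenGradName("X") == "X@GRAD");
  // With an empty argument the helper returns the bare suffix, which is how
  // backward.cc above passes it to AllInSet as a suffix to append.
  assert(h::GenGradName("") == "@GRAD");
  std::cout << h::GenGradName("Out") << std::endl;  // prints "Out@GRAD"
  return 0;
}
```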
10 changes: 1 addition & 9 deletions paddle/framework/operator.h
@@ -22,6 +22,7 @@ limitations under the License. */

#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_helpers/op_helpers.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
@@ -50,15 +51,6 @@ class OperatorBase {
/// but it will be convert to a unique name in scope after OpCreator.
static std::string TMP_VAR_NAME() { return "@TEMP@"; }

/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another varibale.
/// e.g. Variable "x@GRAD" is the gradient of varibale "x".
static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; }

static std::string GRAD_VAR_NAME(const std::string& name) {
return name + GRAD_VAR_SUFFIX();
}

/// Variables with this suffix are supposed to be filled up with zeros.
static std::string ZERO_VAR_SUFFIX() { return "@ZERO"; }

2 changes: 1 addition & 1 deletion paddle/operators/mean_op.cc
@@ -41,7 +41,7 @@ class MeanOpMaker : public OpProtoAndCheckerMaker {
class MeanGradOp : public OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
ctx.Output<Tensor>("X" + GRAD_VAR_SUFFIX())
ctx.Output<Tensor>(op_helpers::GenGradName("X"))
->Resize(ctx.Input<Tensor>("X")->dims());
}
};
4 changes: 2 additions & 2 deletions paddle/operators/mean_op.h
@@ -39,10 +39,10 @@ template <typename Place, typename T>
class MeanGradKernel : public OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
auto OG = context.Input<Tensor>("Out" + OperatorBase::GRAD_VAR_SUFFIX());
auto OG = context.Input<Tensor>(op_helpers::GenGradName("Out"));
PADDLE_ENFORCE(framework::product(OG->dims()) == 1,
"Mean Gradient should be scalar");
auto IG = context.Output<Tensor>("X" + OperatorBase::GRAD_VAR_SUFFIX());
auto IG = context.Output<Tensor>(op_helpers::GenGradName("X"));
IG->mutable_data<T>(context.GetPlace());

T ig_size = (T)framework::product(IG->dims());
9 changes: 5 additions & 4 deletions paddle/operators/softmax_op.cc
@@ -48,12 +48,13 @@ class SoftmaxOpGrad : public OperatorWithKernel {
PADDLE_ENFORCE(ctx.OutputSize() == 1UL,
"Output of SoftmaxOpGrad should be 1");
PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null");
PADDLE_ENFORCE(ctx.InputVar(GRAD_VAR_NAME("Y")) != nullptr,
"Input(Y@GRAD) should not be null");
PADDLE_ENFORCE(ctx.InputVar(op_helpers::GenGradName("Y")) != nullptr,
"Input(%s) should not be null",
op_helpers::GenGradName("Y"));
PADDLE_ENFORCE(ctx.Input<Tensor>("Y")->dims() ==
ctx.Input<Tensor>(GRAD_VAR_NAME("Y"))->dims(),
ctx.Input<Tensor>(op_helpers::GenGradName("Y"))->dims(),
"the shape of Input(0) and Input(1) should be the same");
ctx.Output<Tensor>(GRAD_VAR_NAME("X"))
ctx.Output<Tensor>(op_helpers::GenGradName("X"))
->Resize(ctx.Input<Tensor>("Y")->dims());
}
};
24 changes: 11 additions & 13 deletions paddle/operators/softmax_op.h
@@ -43,21 +43,19 @@ class SoftmaxKernel : public OpKernel {
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);

auto shifted_logits = (logits -
logits.maximum(along_class)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
auto shifted_logits = (logits - logits.maximum(along_class)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));

softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();

softmax.device(context.GetEigenDevice<Place>()) =
(softmax *
softmax.sum(along_class)
.inverse()
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
(softmax * softmax.sum(along_class)
.inverse()
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
}
};

@@ -68,8 +66,8 @@ class SoftmaxGradKernel : public OpKernel {
std::shared_ptr<Tensor> scale_ = std::make_shared<Tensor>();

auto Y = context.Input<Tensor>("Y");
auto dY = context.Input<Tensor>(OperatorBase::GRAD_VAR_NAME("Y"));
auto dX = context.Output<Tensor>(OperatorBase::GRAD_VAR_NAME("X"));
auto dY = context.Input<Tensor>(op_helpers::GenGradName("Y"));
auto dX = context.Output<Tensor>(op_helpers::GenGradName("X"));
dX->mutable_data<T>(context.GetPlace());

const int batch_size = Y->dims()[0];
2 changes: 2 additions & 0 deletions paddle/operators/type_alias.h
@@ -15,6 +15,7 @@
#pragma once

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_helpers/op_helpers.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"

@@ -58,3 +59,4 @@ using OpRegistry = framework::OpRegistry;
} // namespace paddle

namespace ops = paddle::operators;
namespace op_helpers = paddle::framework::op_helpers;