Add higher order derivative for loss function #9070

Merged (12 commits) on Sep 15, 2022
54 changes: 26 additions & 28 deletions oneflow/core/autograd/gradient_funcs/binary_cross_entropy.cpp
@@ -20,7 +20,9 @@ namespace oneflow {
namespace one {

struct BinaryCrossEntropyCaptureState : public AutoGradCaptureState {
- bool requires_grad = false;
+ bool input_requires_grad = false;
+ bool target_requires_grad = false;
+ bool has_weight = false;
};

class BinaryCrossEntropy : public OpExprGradFunction<BinaryCrossEntropyCaptureState> {
@@ -30,46 +32,42 @@ class BinaryCrossEntropy : public OpExprGradFunction<BinaryCrossEntropyCaptureState>
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const BinaryCrossEntropyCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

- private:
- AttrMap base_attrs_;
};

- Maybe<void> BinaryCrossEntropy::Init(const OpExpr& op) {
- const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
- CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
- base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
- return Maybe<void>::Ok();
- }
+ Maybe<void> BinaryCrossEntropy::Init(const OpExpr& op) { return Maybe<void>::Ok(); }

Maybe<void> BinaryCrossEntropy::Capture(BinaryCrossEntropyCaptureState* ctx,
const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const {
- ctx->requires_grad = inputs.at(0)->requires_grad();
- if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+ CHECK_OR_RETURN(inputs.size() >= 2 && inputs.size() <= 3); // NOLINT(maybe-need-error-msg)
+ ctx->input_requires_grad = inputs[0]->requires_grad();
+ ctx->target_requires_grad = inputs[1]->requires_grad();
+ ctx->has_weight = inputs.size() == 3;

- ComposedAttrMap composed_attrs(attrs, base_attrs_);
- ctx->SaveTensorForBackward(inputs.at(0)); // input
- ctx->SaveTensorForBackward(inputs.at(1)); // target
- if (inputs.size() == 3) {
- ctx->SaveTensorForBackward(inputs.at(2)); // weight
+ ctx->SaveTensorForBackward(inputs[0]); // input
+ ctx->SaveTensorForBackward(inputs[1]); // target
+ if (ctx->has_weight) {
+ ctx->SaveTensorForBackward(inputs[2]); // weight
}
return Maybe<void>::Ok();
}
Maybe<void> BinaryCrossEntropy::Apply(const BinaryCrossEntropyCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
- if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
- const auto& dy = out_grads.at(0);
- const auto& input = ctx->SavedTensors().at(0);
- const auto& target = ctx->SavedTensors().at(1);
- in_grads->resize(ctx->SavedTensors().size());
+ CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(),
+ 2 + ctx->has_weight); // NOLINT(maybe-need-error-msg)
+ in_grads->resize(2 + ctx->has_weight);

+ const auto& dy = out_grads[0];
+ const auto& input = ctx->SavedTensors()[0];
+ const auto& target = ctx->SavedTensors()[1];
+ const auto& weight = ctx->has_weight ? Optional<one::Tensor>(ctx->SavedTensors()[2]) : NullOpt;

- if (ctx->SavedTensors().size() == 3) {
- const auto& weight = ctx->SavedTensors().at(2);
- in_grads->at(0) = JUST(functional::BinaryCrossEntropyLossGrad(dy, input, target, weight));
- } else {
- in_grads->at(0) = JUST(functional::BinaryCrossEntropyLossGrad(dy, input, target, NullOpt));
+ if (ctx->input_requires_grad) {
+ (*in_grads)[0] = JUST(functional::BinaryCrossEntropyLossGrad(dy, input, target, weight));
}
+ if (ctx->target_requires_grad) {
+ (*in_grads)[1] = JUST(functional::BinaryCrossEntropyLossTargetGrad(dy, input, target, weight));
+ }
return Maybe<void>::Ok();
}
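For reference, with elementwise weight w, prediction x in (0,1) and target t, the binary cross entropy loss and its two first-order gradients are, up to the reduction applied by the op (presumably what `BinaryCrossEntropyLossGrad` and the new `BinaryCrossEntropyLossTargetGrad` compute):

```latex
\ell(x,t) = -w\bigl[t\log x + (1-t)\log(1-x)\bigr], \qquad
\frac{\partial \ell}{\partial x} = w\,\frac{x-t}{x(1-x)}, \qquad
\frac{\partial \ell}{\partial t} = w\,\log\frac{1-x}{x}.
```

Both expressions are themselves differentiable, which is what makes the higher-order derivative in the PR title possible once the target gradient is wired into autograd. A minimal standalone sanity check of the target-gradient formula (plain C++ with illustrative values, no OneFlow dependency):

```cpp
// Compares the closed-form d(loss)/d(target) = w * log((1 - x) / x)
// against a central finite difference for a single element.
#include <cmath>
#include <cstdio>

int main() {
  const double x = 0.3, t = 0.7, w = 2.0, eps = 1e-6;
  // Weighted elementwise BCE as a function of the target only.
  auto bce = [&](double target) {
    return -w * (target * std::log(x) + (1.0 - target) * std::log(1.0 - x));
  };
  const double numeric = (bce(t + eps) - bce(t - eps)) / (2.0 * eps);
  const double analytic = w * std::log((1.0 - x) / x);
  std::printf("numeric = %.8f, analytic = %.8f\n", numeric, analytic);
  return 0;
}
```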
oneflow/core/autograd/gradient_funcs/binary_cross_entropy_with_logits.cpp
@@ -20,7 +20,9 @@ namespace oneflow {
namespace one {

struct BinaryCrossEntropyWithLogitsCaptureState : public AutoGradCaptureState {
- bool requires_grad = false;
+ bool input_requires_grad = false;
+ bool target_requires_grad = false;
+ bool has_weight = false;
bool has_pos_weight = false;
};

@@ -47,53 +49,51 @@ Maybe<void> BinaryCrossEntropyWithLogits::Capture(BinaryCrossEntropyWithLogitsCaptureState* ctx,
const TensorTuple& inputs,
const TensorTuple& outputs,
const AttrMap& attrs) const {
- ctx->requires_grad = inputs.at(0)->requires_grad();
- if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+ CHECK_OR_RETURN(inputs.size() >= 2 && inputs.size() <= 4); // NOLINT(maybe-need-error-msg)
+ ctx->input_requires_grad = inputs[0]->requires_grad();
+ ctx->target_requires_grad = inputs[1]->requires_grad();

ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->has_pos_weight = JUST(composed_attrs.GetAttr<bool>("has_pos_weight"));
- ctx->SaveTensorForBackward(inputs.at(0)); // input
- ctx->SaveTensorForBackward(inputs.at(1)); // target
+ ctx->has_weight = inputs.size() == 4 || (inputs.size() == 3 && !ctx->has_pos_weight);
+ ctx->SaveTensorForBackward(inputs[0]); // input
+ ctx->SaveTensorForBackward(inputs[1]); // target

if (inputs.size() == 3) {
- ctx->SaveTensorForBackward(inputs.at(2)); // weight or pos_weight
+ ctx->SaveTensorForBackward(inputs[2]); // weight or pos_weight
}
if (inputs.size() == 4) {
- ctx->SaveTensorForBackward(inputs.at(2)); // weight
- ctx->SaveTensorForBackward(inputs.at(3)); // pos_weight
+ ctx->SaveTensorForBackward(inputs[2]); // weight
+ ctx->SaveTensorForBackward(inputs[3]); // pos_weight
}
return Maybe<void>::Ok();
}
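The saved-tensor layout that `has_weight` and `has_pos_weight` encode (and that `Apply` below indexes into) follows directly from the allowed input counts:

- 2 inputs: (input, target), no weight and no pos_weight;
- 3 inputs with `has_pos_weight == false`: (input, target, weight);
- 3 inputs with `has_pos_weight == true`: (input, target, pos_weight);
- 4 inputs: (input, target, weight, pos_weight).

Hence `has_weight = inputs.size() == 4 || (inputs.size() == 3 && !ctx->has_pos_weight)`, and in `Apply` the pos_weight tensor sits at saved index 3 when a weight is present and at index 2 otherwise.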
Maybe<void> BinaryCrossEntropyWithLogits::Apply(const BinaryCrossEntropyWithLogitsCaptureState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
- if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
- const auto& dy = out_grads.at(0);
- const auto& input = ctx->SavedTensors().at(0);
- const auto& target = ctx->SavedTensors().at(1);
+ CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(),
+ 2 + ctx->has_weight + ctx->has_pos_weight); // NOLINT(maybe-need-error-msg)
+ const auto& dy = out_grads[0];
+ const auto& input = ctx->SavedTensors()[0];
+ const auto& target = ctx->SavedTensors()[1];

in_grads->resize(ctx->SavedTensors().size());

- if (ctx->SavedTensors().size() == 3) {
- if (ctx->has_pos_weight) {
- const auto& pos_weight = ctx->SavedTensors().at(2);
- in_grads->at(0) = JUST(
- functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, NullOpt, pos_weight));
- } else {
- const auto& weight = ctx->SavedTensors().at(2);
- in_grads->at(0) = JUST(
- functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, weight, NullOpt));
- }
- } else if (ctx->SavedTensors().size() == 4) {
- const auto& weight = ctx->SavedTensors().at(2);
- const auto& pos_weight = ctx->SavedTensors().at(3);
- in_grads->at(0) = JUST(
+ size_t pos_weight_index = ctx->has_weight ? 3 : 2;
+ auto weight = ctx->has_weight ? Optional<one::Tensor>(ctx->SavedTensors()[2]) : NullOpt;
+ auto pos_weight =
+ ctx->has_pos_weight ? Optional<one::Tensor>(ctx->SavedTensors()[pos_weight_index]) : NullOpt;

+ if (ctx->input_requires_grad) {
+ (*in_grads)[0] = JUST(
functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, weight, pos_weight));
- } else {
- in_grads->at(0) =
- JUST(functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, NullOpt, NullOpt));
}
+ if (ctx->target_requires_grad) {
+ (*in_grads)[1] = JUST(functional::BinaryCrossEntropyWithLogitsLossTargetGrad(
+ dy, input, target, weight, pos_weight));
+ }

return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy_with_logits", BinaryCrossEntropyWithLogits);
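For reference, with weight w, pos_weight p, logit x, target t and sigma the sigmoid, the PyTorch-style weighted BCE-with-logits loss and its gradients are, up to the reduction (presumably what `BinaryCrossEntropyWithLogitsLossGrad` and the new `BinaryCrossEntropyWithLogitsLossTargetGrad` implement):

```latex
\ell(x,t) = -w\bigl[p\,t\,\log\sigma(x) + (1-t)\,\log(1-\sigma(x))\bigr]
\qquad
\frac{\partial\ell}{\partial x} = w\bigl[(1-t)\,\sigma(x) - p\,t\,(1-\sigma(x))\bigr]
\;\;(= w(\sigma(x)-t)\ \text{when}\ p=1)
\qquad
\frac{\partial\ell}{\partial t} = w\bigl[\log(1-\sigma(x)) - p\,\log\sigma(x)\bigr]
```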
oneflow/core/autograd/gradient_funcs/binary_cross_entropy_with_logits_reduce_mean.cpp
@@ -21,8 +21,8 @@ namespace oneflow {
namespace one {

struct BinaryCrossEntropyWithLogitsReduceMeanCaptureState : public AutoGradCaptureState {
- bool requires_grad = false;
- bool has_pos_weight = false;
+ bool input_requires_grad = false;
+ bool target_requires_grad = false;
};

class BinaryCrossEntropyWithLogitsReduceMean
@@ -34,25 +34,19 @@ class BinaryCrossEntropyWithLogitsReduceMean
const AttrMap& attrs) const override;
Maybe<void> Apply(const BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const override;

- private:
- AttrMap base_attrs_;
};

Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Init(const OpExpr& op) {
- const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
- CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "fw_op_expr should not be null. ";
- base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}

Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Capture(
BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
- ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();
- if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+ CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
+ ctx->input_requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();
+ ctx->target_requires_grad = JUST(VectorAt(inputs, 1))->requires_grad();

- ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); // input
ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1))); // target
return Maybe<void>::Ok();
@@ -61,14 +55,20 @@ Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Capture(
Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Apply(
const BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
- if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
- CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "out_grads size should be equal to 1. ";
+ CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& dy = JUST(VectorAt(out_grads, 0));
const auto& input = JUST(VectorAt(ctx->SavedTensors(), 0));
const auto& target = JUST(VectorAt(ctx->SavedTensors(), 1));
- in_grads->resize(ctx->SavedTensors().size());
- JUST(VectorAt(*in_grads, 0)) =
- JUST(functional::BinaryCrossEntropyWithLogitsReduceMeanLossGrad(dy, input, target));
+ in_grads->resize(2);

+ if (ctx->input_requires_grad) {
+ (*in_grads)[0] =
+ JUST(functional::BinaryCrossEntropyWithLogitsReduceMeanLossGrad(dy, input, target));
+ }
+ if (ctx->target_requires_grad) {
+ (*in_grads)[1] =
+ JUST(functional::BinaryCrossEntropyWithLogitsReduceMeanLossTargetGrad(dy, input, target));
+ }
return Maybe<void>::Ok();
}
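Since this op takes exactly (input, target) and mean-reduces over the N elements (no weight or pos_weight here), the gradients presumably produced by `...ReduceMeanLossGrad` and the new `...ReduceMeanLossTargetGrad` are, assuming the usual unweighted formulation:

```latex
L = \frac{1}{N}\sum_i \bigl[\log(1+e^{x_i}) - t_i x_i\bigr], \qquad
\frac{\partial L}{\partial x_i} = \frac{\sigma(x_i) - t_i}{N}, \qquad
\frac{\partial L}{\partial t_i} = -\frac{x_i}{N}.
```

The target gradient depends only on the saved input and the element count, which is why saving (input, target) is sufficient.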

31 changes: 19 additions & 12 deletions oneflow/core/autograd/gradient_funcs/kl_div.cpp
@@ -20,7 +20,8 @@ namespace oneflow {
namespace one {

struct KLDivLossCaptureState : public AutoGradCaptureState {
- bool requires_grad = false;
+ bool input_requires_grad = false;
+ bool target_requires_grad = false;
bool log_target = false;
};

@@ -44,25 +45,31 @@ Maybe<void> KLDivLoss::Init(const OpExpr& op) {
}
Maybe<void> KLDivLoss::Capture(KLDivLossCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
- ctx->requires_grad = inputs.at(0)->requires_grad();
- if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+ CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
+ ctx->input_requires_grad = inputs[0]->requires_grad();
+ ctx->target_requires_grad = inputs[1]->requires_grad();

ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->log_target = JUST(composed_attrs.GetAttr<bool>("log_target"));
- ctx->SaveTensorForBackward(inputs.at(0)); // input
- ctx->SaveTensorForBackward(inputs.at(1)); // target
+ ctx->SaveTensorForBackward(inputs[0]); // input
+ ctx->SaveTensorForBackward(inputs[1]); // target
return Maybe<void>::Ok();
}
Maybe<void> KLDivLoss::Apply(const KLDivLossCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
- if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+ CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
+ CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(), 2); // NOLINT(maybe-need-error-msg)
+ const auto& dy = out_grads[0];
+ const auto& input = ctx->SavedTensors()[0];
+ const auto& target = ctx->SavedTensors()[1];
+ in_grads->resize(2);

- CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
- const auto& dy = out_grads.at(0);
- const auto& input = ctx->SavedTensors().at(0);
- const auto& target = ctx->SavedTensors().at(1);
- in_grads->resize(ctx->SavedTensors().size());
- in_grads->at(0) = JUST(functional::KLDivLossGrad(dy, input, target, ctx->log_target));
+ if (ctx->input_requires_grad) {
+ (*in_grads)[0] = JUST(functional::KLDivLossGrad(dy, input, target, ctx->log_target));
+ }
+ if (ctx->target_requires_grad) {
+ (*in_grads)[1] = JUST(functional::KLDivLossTargetGrad(dy, input, target, ctx->log_target));
+ }

return Maybe<void>::Ok();
}
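For reference, with input x interpreted as log-probabilities (the PyTorch-style kl_div convention), the elementwise loss and its gradients for the two `log_target` modes are, up to reduction (presumably what `KLDivLossGrad` and the new `KLDivLossTargetGrad` compute):

```latex
\text{log\_target} = \text{false}: \quad \ell = t\,(\log t - x), \quad
\frac{\partial\ell}{\partial x} = -t, \quad
\frac{\partial\ell}{\partial t} = \log t + 1 - x
```
```latex
\text{log\_target} = \text{true}: \quad \ell = e^{t}\,(t - x), \quad
\frac{\partial\ell}{\partial x} = -e^{t}, \quad
\frac{\partial\ell}{\partial t} = e^{t}\,(1 + t - x)
```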
27 changes: 14 additions & 13 deletions oneflow/core/autograd/gradient_funcs/smooth_l1_loss.cpp
@@ -22,7 +22,8 @@ namespace oneflow {
namespace one {

struct SmoothL1LossCaptureState : public AutoGradCaptureState {
- bool requires_grad = false;
+ bool input_requires_grad = false;
+ bool target_requires_grad = false;
float beta = 0.0;
};

@@ -37,13 +38,13 @@ class SmoothL1Loss : public OpExprGradFunction<SmoothL1LossCaptureState> {

Maybe<void> Capture(SmoothL1LossCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
- ctx->requires_grad = inputs.at(0)->requires_grad(); // prediction
- if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)

- ctx->SaveTensorForBackward(inputs.at(0)); // prediction
- ctx->SaveTensorForBackward(inputs.at(1)); // label
+ ctx->input_requires_grad = inputs.at(0)->requires_grad(); // input
+ ctx->target_requires_grad = inputs.at(1)->requires_grad(); // target

+ ctx->SaveTensorForBackward(inputs.at(0)); // input
+ ctx->SaveTensorForBackward(inputs.at(1)); // target

ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->beta = JUST(composed_attrs.GetAttr<float>("beta"));
@@ -52,15 +53,15 @@ class SmoothL1Loss : public OpExprGradFunction<SmoothL1LossCaptureState> {

Maybe<void> Apply(const SmoothL1LossCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
- if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

- CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
+ CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
+ CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(), 2); // NOLINT(maybe-need-error-msg)
in_grads->resize(2);
+ const auto& input = ctx->SavedTensors().at(0);
+ const auto& target = ctx->SavedTensors().at(1);
+ const auto& grad = JUST(functional::SmoothL1LossGrad(out_grads[0], input, target, ctx->beta));

- const auto& prediction = ctx->SavedTensors().at(0);
- const auto& label = ctx->SavedTensors().at(1);
- in_grads->at(0) =
- JUST(functional::SmoothL1LossGrad(out_grads.at(0), prediction, label, ctx->beta));
+ if (ctx->input_requires_grad) { (*in_grads)[0] = grad; }
+ if (ctx->target_requires_grad) { (*in_grads)[1] = JUST(functional::Negative(grad)); }
return Maybe<void>::Ok();
}
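
The single `SmoothL1LossGrad` call followed by `functional::Negative` works because, assuming the usual smooth-L1/Huber definition with threshold beta, the loss depends on input and target only through their difference d = x - t, so the target gradient is exactly the negated input gradient:

```latex
\ell(x,t) =
\begin{cases}
\dfrac{d^{2}}{2\beta}, & |d| < \beta\\[4pt]
|d| - \tfrac{\beta}{2}, & \text{otherwise}
\end{cases}
\qquad
\frac{\partial\ell}{\partial x} =
\begin{cases}
d/\beta, & |d| < \beta\\
\operatorname{sign}(d), & \text{otherwise}
\end{cases}
\qquad
\frac{\partial\ell}{\partial t} = -\frac{\partial\ell}{\partial x},
\quad d = x - t.
```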
