Export fused_scale_tril op #5933

Merged: 24 commits, Aug 19, 2021

Commits (24)
e5fb316  FusedScaleTrilFunctor (leaves-zwx, Aug 18, 2021)
3ec2dc9  grad func (leaves-zwx, Aug 18, 2021)
e8abcf9  add test (leaves-zwx, Aug 18, 2021)
d359a2b  fix bugs (leaves-zwx, Aug 18, 2021)
4934083  Merge branch 'master' into f_fused_scale_tril (leaves-zwx, Aug 18, 2021)
809dd6b  rm comment (leaves-zwx, Aug 18, 2021)
591ba65  Merge branch 'master' into f_fused_scale_tril (leaves-zwx, Aug 18, 2021)
e6909bc  Merge branch 'master' into f_fused_scale_tril (oneflow-ci-bot, Aug 18, 2021)
1f5cebf  fix test (leaves-zwx, Aug 18, 2021)
c01ee3c  Merge branch 'f_fused_scale_tril' of https://github.com/Oneflow-Inc/o… (leaves-zwx, Aug 18, 2021)
a1ee72a  Merge branch 'master' into f_fused_scale_tril (leaves-zwx, Aug 18, 2021)
5a79c87  Merge branch 'master' into f_fused_scale_tril (oneflow-ci-bot, Aug 18, 2021)
f7ea097  Merge branch 'master' into f_fused_scale_tril (oneflow-ci-bot, Aug 18, 2021)
75371f1  Merge branch 'master' into f_fused_scale_tril (oneflow-ci-bot, Aug 18, 2021)
482fd7e  Merge branch 'master' into f_fused_scale_tril (oneflow-ci-bot, Aug 18, 2021)
a3984ab  Merge branch 'master' into f_fused_scale_tril (oneflow-ci-bot, Aug 18, 2021)
b83396b  Merge branch 'master' into f_fused_scale_tril (oneflow-ci-bot, Aug 18, 2021)
cb0cc0e  Merge branch 'master' into f_fused_scale_tril (oneflow-ci-bot, Aug 18, 2021)
fab3eaf  Merge branch 'master' into f_fused_scale_tril (oneflow-ci-bot, Aug 18, 2021)
402aed6  Merge branch 'master' into f_fused_scale_tril (oneflow-ci-bot, Aug 19, 2021)
095929e  Merge branch 'master' into f_fused_scale_tril (oneflow-ci-bot, Aug 19, 2021)
77ea57b  Merge branch 'master' into f_fused_scale_tril (oneflow-ci-bot, Aug 19, 2021)
3383424  fix conflict (leaves-zwx, Aug 19, 2021)
9b58939  Merge branch 'master' into f_fused_scale_tril (oneflow-ci-bot, Aug 19, 2021)

81 changes: 81 additions & 0 deletions oneflow/core/autograd/gradient_funcs/fused_scale_tril.cpp
@@ -0,0 +1,81 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

struct FusedScaleTrilState : public AutoGradCaptureState {
bool requires_grad;
int64_t diagonal;
double floating_scale_value;
int64_t integer_scale_value;
bool is_floating_scale_value;
};

class FusedScaleTril : public OpExprGradFunction<FusedScaleTrilState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(FusedScaleTrilState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const FusedScaleTrilState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

private:
AttrMap base_attrs_;
};

Maybe<void> FusedScaleTril::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}

Maybe<void> FusedScaleTril::Capture(FusedScaleTrilState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->diagonal = JUST(composed_attrs.GetAttr<int64_t>("diagonal"));
ctx->floating_scale_value = JUST(composed_attrs.GetAttr<double>("floating_scale_value"));
ctx->integer_scale_value = JUST(composed_attrs.GetAttr<int64_t>("integer_scale_value"));
ctx->is_floating_scale_value = JUST(composed_attrs.GetAttr<bool>("is_floating_scale_value"));
return Maybe<void>::Ok();
}

Maybe<void> FusedScaleTril::Apply(const FusedScaleTrilState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
functional::Scalar scale;
if (ctx->is_floating_scale_value) {
scale = ctx->floating_scale_value;
} else {
scale = ctx->integer_scale_value;
}
(*in_grads)[0] = JUST(functional::FusedScaleTril(out_grads[0], ctx->diagonal, 0, scale));
Contributor:

I don't quite understand the third argument here (the 0). Looking at the gradient registered for the old lazy op, that fill_value is simply the forward op's fill_value attribute (covering both the float and int variants), yet here it is hardcoded to 0. Is that okay?

Contributor Author:

The forward pass keeps the lower-triangular part of the matrix, and during backpropagation the gradient likewise keeps only the lower-triangular part. tril_op.cpp is written the same way: the backward pass does not pass fill_value (so it defaults to 0).
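To make that concrete, here is a small NumPy sketch (an illustration only, not code from this PR) of why hardcoding fill_value = 0 in the backward pass is correct: the entries the forward pass overwrites with fill_value do not depend on x, so they contribute zero gradient, and only the scaled lower triangle propagates.

import numpy as np

diagonal, fill_value, scale = 0, 7.0, 2.5
x = np.random.rand(4, 4)

# Forward: y = scale * tril(x, diagonal), with fill_value on the strict upper triangle.
y = scale * np.tril(x, diagonal) + np.triu(np.full_like(x, fill_value), diagonal + 1)

# Backward: dy/dx is scale on the lower triangle and 0 above it, regardless of
# fill_value, so x_grad = scale * tril(y_grad, diagonal), i.e. the Apply() call
# above with fill_value = 0.
y_grad = np.random.rand(4, 4)
x_grad = scale * np.tril(y_grad, diagonal)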

return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("fused_scale_tril", FusedScaleTril);

} // namespace one
} // namespace oneflow
4 changes: 4 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -1064,3 +1064,7 @@
- name: "consistent_randperm"
signature: "Tensor ConsistentRandperm(Int32 n,*, Placement placement, SbpList sbp_tuple, Generator generator=None)"
bind_python: True

- name: "fused_scale_tril"
signature: "Tensor FusedScaleTril(Tensor x, *, Int64 diagonal=0, Scalar fill_value=0, Scalar scale=1)"
bind_python: True
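
As a quick sanity check of the exported signature, a minimal usage sketch follows (mirroring the calls in the new test further down; it assumes a CUDA device, since the test only exercises the op on GPU):

import numpy as np
import oneflow as flow

x = flow.Tensor(
    np.random.rand(4, 4), device=flow.device("cuda"), dtype=flow.float32, requires_grad=True
)
# Keep the lower triangle scaled by 2.0 and fill the strict upper triangle with -1.
y = flow.F.fused_scale_tril(x, 0, -1, 2.0)
y.backward(flow.Tensor(np.ones((4, 4)), device=flow.device("cuda"), dtype=flow.float32))
# x.grad is 2.0 on and below the diagonal and 0 above it.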
39 changes: 39 additions & 0 deletions oneflow/core/functional/impl/nn_functor.cpp
@@ -697,6 +697,44 @@ class L2NormalizeGradFunctor {
std::shared_ptr<OpExpr> op_;
};

class FusedScaleTrilFunctor {
public:
FusedScaleTrilFunctor() {
op_ = CHECK_JUST(one::OpBuilder("fused_scale_tril").Input("in").Output("out").Build());
}

Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const int64_t& diagonal,
const Scalar& fill_value, const Scalar& scale) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int64_t>("diagonal", diagonal));
bool is_fill_value_double = fill_value.IsFloatingPoint();
bool is_scale_double = scale.IsFloatingPoint();
if (is_fill_value_double) {
JUST(attrs.SetAttr<double>("floating_fill_value", JUST(fill_value.As<double>())));
JUST(attrs.SetAttr<int64_t>("integer_fill_value", 0));
JUST(attrs.SetAttr<bool>("is_floating_fill_value", true));
} else {
JUST(attrs.SetAttr<double>("floating_fill_value", 0));
JUST(attrs.SetAttr<int64_t>("integer_fill_value", JUST(fill_value.As<int64_t>())));
JUST(attrs.SetAttr<bool>("is_floating_fill_value", false));
}

if (is_scale_double) {
JUST(attrs.SetAttr<double>("floating_scale_value", JUST(scale.As<double>())));
JUST(attrs.SetAttr<int64_t>("integer_scale_value", 0));
JUST(attrs.SetAttr<bool>("is_floating_scale_value", true));
} else {
JUST(attrs.SetAttr<double>("floating_scale_value", 0));
JUST(attrs.SetAttr<int64_t>("integer_scale_value", JUST(scale.As<int64_t>())));
JUST(attrs.SetAttr<bool>("is_floating_scale_value", false));
}
return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};

} // namespace impl

ONEFLOW_FUNCTION_LIBRARY(m) {
@@ -728,6 +766,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::OneHotFunctor>("OneHot");
m.add_functor<impl::L2NormalizeFunctor>("L2Normalize");
m.add_functor<impl::L2NormalizeGradFunctor>("L2NormalizeGrad");
m.add_functor<impl::FusedScaleTrilFunctor>("FusedScaleTril");
};

} // namespace functional
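
The paired floating_*/integer_* attributes set by FusedScaleTrilFunctor above follow the op's convention of carrying a Scalar in whichever branch matches its type and zeroing the other. A rough Python rendering of that dispatch (the helper name is made up for illustration):

def scalar_to_attrs(name, value):
    # Mirrors the functor's attr splitting: one branch carries the value, the
    # other is zeroed, and the is_floating_* flag records which branch is valid.
    if isinstance(value, float):
        return {
            "floating_" + name: float(value),
            "integer_" + name: 0,
            "is_floating_" + name: True,
        }
    return {
        "floating_" + name: 0.0,
        "integer_" + name: int(value),
        "is_floating_" + name: False,
    }

# e.g. scalar_to_attrs("fill_value", -1) selects the integer branch,
#      scalar_to_attrs("scale", 2.5) selects the floating branch.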
90 changes: 90 additions & 0 deletions python/oneflow/test/modules/test_fused_scale_tril.py
@@ -0,0 +1,90 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import os
import numpy as np
from collections import OrderedDict

from test_util import GenArgDict

import oneflow as flow


def _np_tril(x, diagonal, fill_value, scale):
if int(fill_value) == 0:
return np.tril(x, diagonal) * scale

upper = np.empty(x.shape)
upper.fill(fill_value)
upper = np.triu(upper, diagonal + 1)

return np.tril(x, diagonal) * scale + upper


def _test_fused_scale_tril(
test_case,
shape,
diagonal=0,
fill_value=0,
scale=1,
dtype=flow.float32,
device_type="cuda",
):
if dtype is flow.int32 and not isinstance(scale, int):
return

if dtype is flow.int32:
x = np.random.randint(0, 10, shape)
y_grad = np.random.randint(0, 10, shape)
else:
x = np.random.rand(*shape)
y_grad = np.random.rand(*shape)

y = _np_tril(x, diagonal, fill_value, scale)
x_grad = _np_tril(y_grad, diagonal, 0, scale)

flow_x = flow.Tensor(
x, device=flow.device(device_type), dtype=dtype, requires_grad=True
)
flow_y = flow.F.fused_scale_tril(flow_x, diagonal, fill_value, scale)
flow_y_grad = flow.Tensor(y_grad, device=flow.device(device_type), dtype=dtype)
flow_y.backward(flow_y_grad)

flow_y_np = flow_y.numpy()
test_case.assertTrue(np.allclose(flow_y_np, y.astype(flow_y_np.dtype)))

flow_x_grad_np = flow_x.grad.numpy()
test_case.assertTrue(
np.allclose(flow_x_grad_np, x_grad.astype(flow_x_grad_np.dtype))
)


@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@flow.unittest.skip_unless_1n1d()
class FusedScaleTrilTestCase(flow.unittest.TestCase):
def test_fused_scale_tril(test_case):
arg_dict = OrderedDict()
arg_dict["shape"] = [(5, 5), (4, 6)]
arg_dict["diagonal"] = [-1, 0, 1]
arg_dict["fill_value"] = [-1, 0, 1]
arg_dict["scale"] = [-2.3, 0.7, 2]
arg_dict["dtype"] = [flow.int32, flow.float32]
for kwargs in GenArgDict(arg_dict):
_test_fused_scale_tril(test_case, **kwargs)


if __name__ == "__main__":
unittest.main()