[RELAY-OP][PYTORCH]GroupNorm op support added #5358

Merged — 1 commit, merged on Apr 17, 2020
include/tvm/relay/attrs/nn.h — 24 additions, 0 deletions
@@ -959,6 +959,30 @@ struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
}; // struct LayerNormAttrs


/*! \brief Attributes used in group_norm operator */
struct GroupNormAttrs : public tvm::AttrsNode<GroupNormAttrs> {
int num_groups;
int axis;
double epsilon;
bool center;
bool scale;

TVM_DECLARE_ATTRS(GroupNormAttrs, "relay.attrs.GroupNormAttrs") {
TVM_ATTR_FIELD(num_groups).set_default(0)
.describe("Specify number of groups to separate the channels into.");
TVM_ATTR_FIELD(axis).set_default(1)
.describe("Specify which shape axis denotes the channel.");
TVM_ATTR_FIELD(epsilon).set_default(1e-5)
.describe("Small float added to variance to avoid dividing by zero");
TVM_ATTR_FIELD(center).set_default(true)
.describe("If true, add offset of beta to normalized tensor; "
"otherwise, beta is ignored.");
TVM_ATTR_FIELD(scale).set_default(true)
.describe("If true, multiply by gamma; otherwise, gamma is ignored.");
}
}; // struct GroupNormAttrs
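
The defaults declared here surface directly on the Relay call node. A minimal illustrative sketch (not part of the patch) that constructs a group_norm call from Python and reads the attributes back:

    from tvm import relay

    # Build a group_norm call over an NCHW tensor with 6 channels split into 3 groups.
    data = relay.var("data", shape=(1, 6, 8, 8), dtype="float32")
    gamma = relay.var("gamma", shape=(6,), dtype="float32")
    beta = relay.var("beta", shape=(6,), dtype="float32")
    call = relay.nn.group_norm(data, gamma, beta, num_groups=3)

    # axis=1, epsilon=1e-5, center=True, scale=True match the defaults declared above.
    print(call.attrs.num_groups, call.attrs.axis, call.attrs.epsilon)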


/*! \brief Attributes for LRN operator */
struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {
int size;
python/tvm/relay/frontend/pytorch.py — 21 additions, 0 deletions
@@ -831,6 +831,26 @@ def _impl(inputs, input_types):
scale=True)
return _impl


def _group_norm():
def _impl(inputs, input_types):
data = inputs[0]
gamma = inputs[2]
beta = inputs[3]
num_groups = inputs[1]
epsilon = float(inputs[4])

return _op.nn.group_norm(data,
gamma=gamma,
beta=beta,
num_groups=num_groups,
axis=1,
epsilon=epsilon,
center=True,
scale=True)
return _impl
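
For context, aten::group_norm carries its arguments as (input, num_groups, weight, bias, eps, cudnn_enabled), which is why the converter reads inputs[1] for num_groups and inputs[2]/inputs[3] for gamma/beta; the cudnn flag is ignored. A rough end-to-end sketch of exercising the new mapping (the (name, shape) input list passed to from_pytorch is assumed from the conventions of the test harness):

    import torch
    from tvm import relay

    model = torch.nn.GroupNorm(3, 6).eval()   # 6 channels split into 3 groups
    inp = torch.rand(1, 6, 8, 8)
    script_module = torch.jit.trace(model, inp)

    # aten::group_norm nodes in the traced graph are mapped to relay.nn.group_norm.
    mod, params = relay.frontend.from_pytorch(script_module, [("input0", (1, 6, 8, 8))])
    print(mod["main"])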


def _transpose(prelude):
def _impl(inputs, input_types):
data = inputs[0]
@@ -1630,6 +1650,7 @@ def _get_convert_map(prelude):
"aten::batch_norm" : _batch_norm(),
"aten::instance_norm" : _instance_norm(),
"aten::layer_norm" : _layer_norm(),
"aten::group_norm" : _group_norm(),
"aten::transpose" : _transpose(prelude),
"aten::transpose_" : _transpose(prelude),
"aten::t" : _transpose(prelude),
python/tvm/relay/op/nn/nn.py — 69 additions, 0 deletions
@@ -1708,6 +1708,75 @@ def layer_norm(data,
return _make.layer_norm(data, gamma, beta, axis, epsilon, center, scale)


def group_norm(data,
gamma,
beta,
num_groups,
axis=1,
epsilon=1e-5,
center=True,
scale=True):
r"""
Group normalization normalizes over groups of channels for each training example.
Group Norm sits between Instance Norm and Layer Norm: putting all the channels into
a single group makes group normalization equivalent to layer normalization, while
putting each channel into its own group makes it equivalent to instance normalization.

https://arxiv.org/pdf/1803.08494.pdf

Applies group normalization to the n-dimensional input array by separating the input
channels into 'num_groups' groups, each containing 'num_channels / num_groups' channels.
The mean and standard-deviation are calculated separately over each group. gamma and
beta are learnable per-channel affine transform parameter vectors of size num_channels.

.. math::

out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}}
* gamma + beta

Unlike batch normalization, the mean and var are computed along a group of channels.

If the input has size k on axis 1, then both gamma and beta have shape (k,).

.. note::

This operator can be optimized away for inference.

Parameters
----------
data : tvm.relay.Expr
Input to which group_norm will be applied.

gamma : tvm.relay.Expr
The gamma scale factor.

beta : tvm.relay.Expr
The beta offset factor.

num_groups : int
The number of groups to separate the channels into.

axis : int, optional, default=1
The axis of the channels.

epsilon : double, optional, default=1e-5
Small float added to variance to avoid dividing by zero.

center : boolean, optional, default=True
If True, add offset of beta to normalized tensor. If False,
beta is ignored.

scale : boolean, optional, default=True
If True, multiply by gamma. If False, gamma is not used.

Returns
-------
result : tvm.relay.Expr
The normalized data.
"""
return _make.group_norm(data, gamma, beta, num_groups, axis, epsilon, center, scale)
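
As a cross-check of the formula above, a plain NumPy reference for the NCHW case is sketched below; the helper name and layout assumptions are illustrative only, not part of the patch:

    import numpy as np

    def group_norm_ref(x, gamma, beta, num_groups, epsilon=1e-5):
        # x: (N, C, H, W); gamma, beta: (C,)
        n, c, h, w = x.shape
        xg = x.reshape(n, num_groups, c // num_groups, h, w)
        mean = xg.mean(axis=(2, 3, 4), keepdims=True)
        var = xg.var(axis=(2, 3, 4), keepdims=True)
        out = ((xg - mean) / np.sqrt(var + epsilon)).reshape(n, c, h, w)
        return out * gamma.reshape(1, c, 1, 1) + beta.reshape(1, c, 1, 1)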


def batch_matmul(x, y):
r"""
Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data
src/relay/op/nn/nn.cc — 74 additions, 0 deletions
@@ -852,6 +852,80 @@ RELAY_REGISTER_OP("nn.layer_norm")
.set_support_level(1)
.add_type_rel("LayerNorm", LayerNormRel);

// group_norm
TVM_REGISTER_NODE_TYPE(GroupNormAttrs);

bool GroupNormRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
const GroupNormAttrs* param = attrs.as<GroupNormAttrs>();
int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
CHECK(axis >= 0 && axis < (int)data->shape.size());
reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
reporter->Assign(types[2], TensorType({data->shape[axis]}, data->dtype));
reporter->Assign(types[3], TensorType(data->shape, data->dtype));

return true;
}

Expr MakeGroupNorm(Expr data, Expr gamma, Expr beta, int num_groups,
int axis, double epsilon, bool center, bool scale) {
auto attrs = make_object<GroupNormAttrs>();
attrs->num_groups = num_groups;
attrs->axis = axis;
attrs->epsilon = epsilon;
attrs->center = center;
attrs->scale = scale;
static const Op& op = Op::Get("nn.group_norm");
return Call(op, {data, gamma, beta}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.group_norm")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 8>(MakeGroupNorm, args, rv);
});

RELAY_REGISTER_OP("nn.group_norm")
.describe(R"code(
Group normalization normalizes over groups of channels for each training example.
Group Norm sits between Instance Norm and Layer Norm: putting all the channels into
a single group makes group normalization equivalent to layer normalization, while
putting each channel into its own group makes it equivalent to instance normalization.

https://arxiv.org/pdf/1803.08494.pdf

Applies group normalization to the n-dimensional input array by separating the input
channels into 'num_groups' groups, each containing 'num_channels / num_groups' channels.
The mean and standard-deviation are calculated separately over each group. gamma and
beta are learnable per-channel affine transform parameter vectors of size num_channels.

.. math::

out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}}
* gamma + beta

Unlike batch normalization, the mean and var are computed along a group of channels.

If the input has size k on axis 1, then both gamma and beta have shape (k,).

.. note::

This operator can be optimized away for inference.

)code" TVM_ADD_FILELINE)
.set_attrs_type<GroupNormAttrs>()
.set_num_inputs(3)
.add_argument("data", "Tensor", "Input to which group_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.set_support_level(1)
.add_type_rel("GroupNorm", GroupNormRel);


// relay.nn.batch_matmul
bool BatchMatmulRel(const Array<Type>& types,
int num_inputs,
src/relay/transforms/simplify_inference.cc — 66 additions, 0 deletions
@@ -64,6 +64,66 @@ Expr BatchNormToInferUnpack(const Attrs attrs,
return out;
}


Expr GroupNormToInferUnpack(const Attrs attrs,
Expr data,
Expr gamma,
Expr beta,
Type tdata) {
auto ttype = tdata.as<TensorTypeNode>();
CHECK(ttype);
const auto param = attrs.as<GroupNormAttrs>();
CHECK(param);

int ndim = ttype->shape.size();
int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
Array<Integer> reduced_axes;
Array<Integer> new_shape;
Array<Integer> old_shape;

int num_groups = param->num_groups;
int channel = ttype->shape[axis].as<IntImmNode>()->value;

// old_shape = (N, C, H, W)
// new_shape = (N, num_groups, C/num_groups, H, W)
// reduced_axes = axes of (C/num_groups, H, W) in the reshaped tensor
for (int i = 0; i < ndim; ++i) {
auto val = ttype->shape[i].as<IntImmNode>()->value;

// Save the old shape to reshape later
old_shape.push_back(val);
if (i == axis) {
new_shape.push_back(num_groups);
new_shape.push_back(channel / num_groups);
reduced_axes.push_back(i + 1);
continue;
}
if (i >= axis) {
reduced_axes.push_back(i + 1);
}
new_shape.push_back(val);
}

data = Reshape(data, new_shape);

Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
Expr mean = Mean(data, {reduced_axes}, true, false);
Expr var = Variance(data, mean, {reduced_axes}, true, false);
Expr denom = Sqrt(Add(var, epsilon));
Expr out = Divide(Subtract(data, mean), denom);

out = Reshape(out, old_shape);

if (param->scale) {
out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis}));
}
if (param->center) {
out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis}));
}

return out;
}
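
This expansion is only reachable through the SimplifyInference pass. A minimal sketch of triggering it (InferType is run first because the unpacking reads checked types):

    import tvm
    from tvm import relay

    data = relay.var("data", shape=(1, 6, 8, 8), dtype="float32")
    gamma = relay.var("gamma", shape=(6,), dtype="float32")
    beta = relay.var("beta", shape=(6,), dtype="float32")
    out = relay.nn.group_norm(data, gamma, beta, num_groups=3)

    mod = tvm.IRModule.from_expr(relay.Function([data, gamma, beta], out))
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.SimplifyInference()(mod)
    # nn.group_norm is gone; the body is the reshape/mean/variance/sqrt/divide
    # sequence built by GroupNormToInferUnpack, plus the gamma multiply and beta add.
    print(mod["main"])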

Expr LayerNormToInferUnpack(const Attrs attrs,
Expr data,
Expr gamma,
@@ -143,6 +203,7 @@ class InferenceSimplifier : public ExprMutator {
dropout_op_(Op::Get("nn.dropout")),
instance_norm_op_(Op::Get("nn.instance_norm")),
layer_norm_op_(Op::Get("nn.layer_norm")),
group_norm_op_(Op::Get("nn.group_norm")),
l2_norm_op_(Op::Get("nn.l2_normalize")) {}

Expr VisitExpr_(const TupleGetItemNode* n) final {
@@ -170,6 +231,10 @@
const auto* call = new_n.as<CallNode>();
return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
n->args[0]->checked_type());
} else if (n->op == group_norm_op_) {
const auto* call = new_n.as<CallNode>();
return GroupNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
n->args[0]->checked_type());
} else if (n->op == instance_norm_op_) {
const auto* call = new_n.as<CallNode>();
return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
@@ -189,6 +254,7 @@
const Op& dropout_op_;
const Op& instance_norm_op_;
const Op& layer_norm_op_;
const Op& group_norm_op_;
const Op& l2_norm_op_;
std::unordered_map<Expr, Type, ObjectHash, ObjectEqual> ty_map_;
};
tests/python/frontend/pytorch/test_forward.py — 23 additions, 0 deletions
@@ -717,6 +717,28 @@ def init_weight(m):
init_weight(ln.eval())
verify_model(ln.eval(), input_data=inp)


def test_forward_groupnorm():
input_shape = [10, 6, 5, 5]
input_data = torch.rand(input_shape).float()

# Separate 6 channels into 3 groups
verify_model(torch.nn.GroupNorm(3, 6).eval(), input_data=input_data)

# Put all 6 channels into a single group (equivalent to LayerNorm)
verify_model(torch.nn.GroupNorm(1, 6).eval(), input_data=input_data)

# Separate 6 channels into 6 groups (equivalent to InstanceNorm)
verify_model(torch.nn.GroupNorm(6, 6).eval(), input_data=input_data)

input_shape = [1, 10, 4, 7]
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.GroupNorm(1, 10).eval(), input_data=input_data)
verify_model(torch.nn.GroupNorm(2, 10).eval(), input_data=input_data)
verify_model(torch.nn.GroupNorm(5, 10).eval(), input_data=input_data)
verify_model(torch.nn.GroupNorm(10, 10).eval(), input_data=input_data)
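
A quick way to exercise just this test locally (assuming a TVM build with the PyTorch frontend and its dependencies available) is python -m pytest tests/python/frontend/pytorch/test_forward.py::test_forward_groupnorm; the verify_model helper traces each module, converts it through the frontend, and compares the compiled Relay output against PyTorch.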


def test_forward_reshape():
torch.set_grad_enabled(False)
input_shape = [2, 1, 10, 1, 10]
@@ -1865,6 +1887,7 @@ def forward(self, *args):
test_forward_batchnorm()
test_forward_instancenorm()
test_forward_layernorm()
test_forward_groupnorm()
test_forward_transpose()
test_forward_size()
test_forward_view()