Commit b1dec0c

siju-samuel authored and dpankratz committed
[RELAY][PYTORCH]GroupNorm op support added (apache#5358)
1 parent 457b901 commit b1dec0c

6 files changed: 277 additions and 0 deletions

include/tvm/relay/attrs/nn.h

Lines changed: 24 additions & 0 deletions
@@ -959,6 +959,30 @@ struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
 };  // struct LayerNormAttrs
 
 
+/*! \brief Attributes used in group_norm operator */
+struct GroupNormAttrs : public tvm::AttrsNode<GroupNormAttrs> {
+  int num_groups;
+  int axis;
+  double epsilon;
+  bool center;
+  bool scale;
+
+  TVM_DECLARE_ATTRS(GroupNormAttrs, "relay.attrs.GroupNormAttrs") {
+    TVM_ATTR_FIELD(num_groups).set_default(0)
+        .describe("Specify number of groups to separate the channels into.");
+    TVM_ATTR_FIELD(axis).set_default(1)
+        .describe("Specify which shape axis denotes the channel.");
+    TVM_ATTR_FIELD(epsilon).set_default(1e-5)
+        .describe("Small float added to variance to avoid dividing by zero");
+    TVM_ATTR_FIELD(center).set_default(true)
+        .describe("If true, add offset of beta to normalized tensor; "
+                  "otherwise, beta is ignored.");
+    TVM_ATTR_FIELD(scale).set_default(true)
+        .describe("If true, multiply by gamma; otherwise, gamma is ignored.");
+  }
+};  // struct GroupNormAttrs
+
+
 /*! \brief Attributes for LRN operator */
 struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {
   int size;
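For reference, the new attrs node can be inspected from Python through TVM's reflection machinery once this change is built. A minimal sketch (the attribute values below are illustrative, not taken from the patch):

import tvm

# Construct the registered GroupNormAttrs node via reflection and read its fields back.
attrs = tvm.ir.make_node("relay.attrs.GroupNormAttrs",
                         num_groups=3, axis=1, epsilon=1e-5,
                         center=True, scale=True)
print(attrs.num_groups, attrs.axis, attrs.epsilon)  # 3 1 1e-05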

python/tvm/relay/frontend/pytorch.py

Lines changed: 21 additions & 0 deletions
@@ -831,6 +831,26 @@ def _impl(inputs, input_types):
                            scale=True)
     return _impl
 
+
+def _group_norm():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        gamma = inputs[2]
+        beta = inputs[3]
+        num_groups = inputs[1]
+        epsilon = float(inputs[4])
+
+        return _op.nn.group_norm(data,
+                                 gamma=gamma,
+                                 beta=beta,
+                                 num_groups=num_groups,
+                                 axis=1,
+                                 epsilon=epsilon,
+                                 center=True,
+                                 scale=True)
+    return _impl
+
+
 def _transpose(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1630,6 +1650,7 @@ def _get_convert_map(prelude):
     "aten::batch_norm" : _batch_norm(),
     "aten::instance_norm" : _instance_norm(),
     "aten::layer_norm" : _layer_norm(),
+    "aten::group_norm" : _group_norm(),
     "aten::transpose" : _transpose(prelude),
     "aten::transpose_" : _transpose(prelude),
     "aten::t" : _transpose(prelude),

python/tvm/relay/op/nn/nn.py

Lines changed: 69 additions & 0 deletions
@@ -1708,6 +1708,75 @@ def layer_norm(data,
     return _make.layer_norm(data, gamma, beta, axis, epsilon, center, scale)
 
 
+def group_norm(data,
+               gamma,
+               beta,
+               num_groups,
+               axis=1,
+               epsilon=1e-5,
+               center=True,
+               scale=True):
+    r"""
+    Group normalization normalizes over groups of channels for each training example.
+    Group Norm sits between Instance Norm and Layer Norm: when all the channels are
+    put into a single group, group normalization becomes layer normalization, and
+    when each channel is put into a different group it becomes instance normalization.
+
+    https://arxiv.org/pdf/1803.08494.pdf
+
+    Applies group normalization to the n-dimensional input array by separating the
+    input channels into 'num_groups' groups, each containing 'num_channels / num_groups'
+    channels. The mean and standard deviation are calculated separately over each group.
+    gamma and beta are learnable per-channel affine transform parameter vectors of size
+    num_channels.
+
+    .. math::
+
+        out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}}
+            * gamma + beta
+
+    Unlike batch normalization, the mean and var are computed along a group of channels.
+
+    If the input has size k on axis 1, then both gamma and beta have shape (k,).
+
+    .. note::
+
+        This operator can be optimized away for inference.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        Input to which group_norm will be applied.
+
+    gamma : tvm.relay.Expr
+        The gamma scale factor.
+
+    beta : tvm.relay.Expr
+        The beta offset factor.
+
+    num_groups : int
+        The number of groups to separate the channels into.
+
+    axis : int, optional, default=1
+        The axis of the channels.
+
+    epsilon : double, optional, default=1e-5
+        Small float added to variance to avoid dividing by zero.
+
+    center : boolean, optional, default=True
+        If True, add offset of beta to the normalized tensor; if False,
+        beta is ignored.
+
+    scale : boolean, optional, default=True
+        If True, multiply by gamma. If False, gamma is not used.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The normalized data.
+    """
+    return _make.group_norm(data, gamma, beta, num_groups, axis, epsilon, center, scale)
+
+
 def batch_matmul(x, y):
     r"""
     Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data
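A short usage sketch of the new Python API (shapes are illustrative; gamma and beta carry one value per channel):

from tvm import relay

data = relay.var("data", shape=(1, 6, 5, 5), dtype="float32")
gamma = relay.var("gamma", shape=(6,), dtype="float32")
beta = relay.var("beta", shape=(6,), dtype="float32")

# Normalize the 6 channels in 3 groups of 2 along axis 1.
out = relay.nn.group_norm(data, gamma, beta, num_groups=3)
func = relay.Function([data, gamma, beta], out)
print(func)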

src/relay/op/nn/nn.cc

Lines changed: 74 additions & 0 deletions
@@ -852,6 +852,80 @@ RELAY_REGISTER_OP("nn.layer_norm")
 .set_support_level(1)
 .add_type_rel("LayerNorm", LayerNormRel);
 
+// group_norm
+TVM_REGISTER_NODE_TYPE(GroupNormAttrs);
+
+bool GroupNormRel(const Array<Type>& types,
+                  int num_inputs,
+                  const Attrs& attrs,
+                  const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 4);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  const GroupNormAttrs* param = attrs.as<GroupNormAttrs>();
+  int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
+  CHECK(axis >= 0 && axis < (int)data->shape.size());
+  reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
+  reporter->Assign(types[2], TensorType({data->shape[axis]}, data->dtype));
+  reporter->Assign(types[3], TensorType(data->shape, data->dtype));
+
+  return true;
+}
+
+Expr MakeGroupNorm(Expr data, Expr gamma, Expr beta, int num_groups,
+                   int axis, double epsilon, bool center, bool scale) {
+  auto attrs = make_object<GroupNormAttrs>();
+  attrs->num_groups = num_groups;
+  attrs->axis = axis;
+  attrs->epsilon = epsilon;
+  attrs->center = center;
+  attrs->scale = scale;
+  static const Op& op = Op::Get("nn.group_norm");
+  return Call(op, {data, gamma, beta}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.group_norm")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 8>(MakeGroupNorm, args, rv);
+  });
+
+RELAY_REGISTER_OP("nn.group_norm")
+.describe(R"code(
+Group normalization normalizes over groups of channels for each training example.
+Group Norm sits between Instance Norm and Layer Norm: when all the channels are put
+into a single group, group normalization becomes layer normalization, and when each
+channel is put into a different group it becomes instance normalization.
+
+https://arxiv.org/pdf/1803.08494.pdf
+
+Applies group normalization to the n-dimensional input array by separating the input
+channels into 'num_groups' groups, each containing 'num_channels / num_groups' channels.
+The mean and standard deviation are calculated separately over each group. gamma and
+beta are learnable per-channel affine transform parameter vectors of size num_channels.
+
+.. math::
+
+    out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}}
+        * gamma + beta
+
+Unlike batch normalization, the mean and var are computed along a group of channels.
+
+If the input has size k on axis 1, then both gamma and beta have shape (k,).
+
+.. note::
+
+    This operator can be optimized away for inference.
+
+)code" TVM_ADD_FILELINE)
+.set_attrs_type<GroupNormAttrs>()
+.set_num_inputs(3)
+.add_argument("data", "Tensor", "Input to which group_norm will be applied.")
+.add_argument("gamma", "Tensor", "The gamma scale factor.")
+.add_argument("beta", "Tensor", "The beta offset factor.")
+.set_support_level(1)
+.add_type_rel("GroupNorm", GroupNormRel);
+
+
 // relay.nn.batch_matmul
 bool BatchMatmulRel(const Array<Type>& types,
                     int num_inputs,
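Once registered, the op is discoverable from Python and carries the metadata set above. A small hedged check (assuming TVM is built with this change):

from tvm import relay

op = relay.op.get("nn.group_norm")
print(op.name)           # nn.group_norm
print(op.num_inputs)     # 3 (data, gamma, beta)
print(op.support_level)  # 1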

src/relay/transforms/simplify_inference.cc

Lines changed: 66 additions & 0 deletions
@@ -64,6 +64,66 @@ Expr BatchNormToInferUnpack(const Attrs attrs,
   return out;
 }
 
+
+Expr GroupNormToInferUnpack(const Attrs attrs,
+                            Expr data,
+                            Expr gamma,
+                            Expr beta,
+                            Type tdata) {
+  auto ttype = tdata.as<TensorTypeNode>();
+  CHECK(ttype);
+  const auto param = attrs.as<GroupNormAttrs>();
+  CHECK(param);
+
+  int ndim = ttype->shape.size();
+  int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
+  Array<Integer> reduced_axes;
+  Array<Integer> new_shape;
+  Array<Integer> old_shape;
+
+  int num_groups = param->num_groups;
+  int channel = ttype->shape[axis].as<IntImmNode>()->value;
+
+  // old_shape = N, C, H, W
+  // new_shape = N, num_groups, C/num_groups, H, W
+  // reduced_axes = axes of (C/num_groups, H, W)
+  for (int i = 0; i < ndim; ++i) {
+    auto val = ttype->shape[i].as<IntImmNode>()->value;
+
+    // Save the old shape to reshape later
+    old_shape.push_back(val);
+    if (i == axis) {
+      new_shape.push_back(num_groups);
+      new_shape.push_back(channel / num_groups);
+      reduced_axes.push_back(i + 1);
+      continue;
+    }
+    if (i >= axis) {
+      reduced_axes.push_back(i + 1);
+    }
+    new_shape.push_back(val);
+  }
+
+  data = Reshape(data, new_shape);
+
+  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
+  Expr mean = Mean(data, {reduced_axes}, true, false);
+  Expr var = Variance(data, mean, {reduced_axes}, true, false);
+  Expr denom = Sqrt(Add(var, epsilon));
+  Expr out = Divide(Subtract(data, mean), denom);
+
+  out = Reshape(out, old_shape);
+
+  if (param->scale) {
+    out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis}));
+  }
+  if (param->center) {
+    out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis}));
+  }
+
+  return out;
+}
+
 Expr LayerNormToInferUnpack(const Attrs attrs,
                             Expr data,
                             Expr gamma,
@@ -143,6 +203,7 @@ class InferenceSimplifier : public ExprMutator {
         dropout_op_(Op::Get("nn.dropout")),
         instance_norm_op_(Op::Get("nn.instance_norm")),
         layer_norm_op_(Op::Get("nn.layer_norm")),
+        group_norm_op_(Op::Get("nn.group_norm")),
         l2_norm_op_(Op::Get("nn.l2_normalize")) {}
 
   Expr VisitExpr_(const TupleGetItemNode* n) final {
@@ -170,6 +231,10 @@ class InferenceSimplifier : public ExprMutator {
       const auto* call = new_n.as<CallNode>();
       return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
                                     n->args[0]->checked_type());
+    } else if (n->op == group_norm_op_) {
+      const auto* call = new_n.as<CallNode>();
+      return GroupNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
+                                    n->args[0]->checked_type());
     } else if (n->op == instance_norm_op_) {
       const auto* call = new_n.as<CallNode>();
       return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
@@ -189,6 +254,7 @@ class InferenceSimplifier : public ExprMutator {
   const Op& dropout_op_;
   const Op& instance_norm_op_;
   const Op& layer_norm_op_;
+  const Op& group_norm_op_;
   const Op& l2_norm_op_;
   std::unordered_map<Expr, Type, ObjectHash, ObjectEqual> ty_map_;
 };
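The lowering above is reshape, mean/variance over the group axes, normalize, reshape back, then a per-channel scale and shift, and it is applied to a typed module by relay.transform.SimplifyInference(). A NumPy sketch of the same computation for NCHW input with axis=1 (shapes and names here are illustrative, not part of the patch):

import numpy as np

def group_norm_ref(x, gamma, beta, num_groups, eps=1e-5):
    # Mirrors GroupNormToInferUnpack: N, C, H, W -> N, G, C/G, H, W
    n, c, h, w = x.shape
    xg = x.reshape(n, num_groups, c // num_groups, h, w)
    # Reduce over (C/G, H, W), keeping dims so the result broadcasts back.
    mean = xg.mean(axis=(2, 3, 4), keepdims=True)
    var = xg.var(axis=(2, 3, 4), keepdims=True)
    out = (xg - mean) / np.sqrt(var + eps)
    # Restore the old shape, then apply the per-channel affine transform.
    out = out.reshape(n, c, h, w)
    return out * gamma.reshape(1, c, 1, 1) + beta.reshape(1, c, 1, 1)

x = np.random.rand(10, 6, 5, 5).astype("float32")
out = group_norm_ref(x, np.ones(6, "float32"), np.zeros(6, "float32"), num_groups=3)
print(out.shape)  # (10, 6, 5, 5)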

tests/python/frontend/pytorch/test_forward.py

Lines changed: 23 additions & 0 deletions
@@ -717,6 +717,28 @@ def init_weight(m):
     init_weight(ln.eval())
     verify_model(ln.eval(), input_data=inp)
 
+
+def test_forward_groupnorm():
+    input_shape = [10, 6, 5, 5]
+    input_data = torch.rand(input_shape).float()
+
+    # Separate 6 channels into 3 groups
+    verify_model(torch.nn.GroupNorm(3, 6).eval(), input_data=input_data)
+
+    # Put all 6 channels into a single group (equivalent with LayerNorm)
+    verify_model(torch.nn.GroupNorm(1, 6).eval(), input_data=input_data)
+
+    # Separate 6 channels into 6 groups (equivalent with InstanceNorm)
+    verify_model(torch.nn.GroupNorm(6, 6).eval(), input_data=input_data)
+
+    input_shape = [1, 10, 4, 7]
+    input_data = torch.rand(input_shape).float()
+    verify_model(torch.nn.GroupNorm(1, 10).eval(), input_data=input_data)
+    verify_model(torch.nn.GroupNorm(2, 10).eval(), input_data=input_data)
+    verify_model(torch.nn.GroupNorm(5, 10).eval(), input_data=input_data)
+    verify_model(torch.nn.GroupNorm(10, 10).eval(), input_data=input_data)
+
+
 def test_forward_reshape():
     torch.set_grad_enabled(False)
     input_shape = [2, 1, 10, 1, 10]
@@ -1865,6 +1887,7 @@ def forward(self, *args):
     test_forward_batchnorm()
     test_forward_instancenorm()
    test_forward_layernorm()
+    test_forward_groupnorm()
     test_forward_transpose()
     test_forward_size()
     test_forward_view()
