Skip to content

Commit e44f6c0

Browse files
authored
[ONNX] Add Einsum converter (#8985)
* einsum * address review * move files around * use generic topi op * TODO comment * jostle ci * jostle ci
1 parent 2aebd33 commit e44f6c0

File tree

12 files changed

+251
-6
lines changed

12 files changed

+251
-6
lines changed

include/tvm/relay/attrs/transform.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ struct ScanopAttrs : public tvm::AttrsNode<ScanopAttrs> {
475475
.describe("The first element is not included")
476476
.set_default(Bool(false));
477477
}
478-
};
478+
}; // struct ScanopAttrs
479479

480480
/*! \brief Attributes used in unique operator */
481481
struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
@@ -489,6 +489,15 @@ struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
489489
}
490490
}; // struct UniqueAttrs
491491

492+
/*! \brief Attributes used in einsum operator */
493+
struct EinsumAttrs : public tvm::AttrsNode<EinsumAttrs> {
494+
String equation;
495+
496+
TVM_DECLARE_ATTRS(EinsumAttrs, "relay.attrs.EinsumAttrs") {
497+
TVM_ATTR_FIELD(equation).describe("The einsum expression string");
498+
}
499+
}; // struct EinsumAttrs
500+
492501
} // namespace relay
493502
} // namespace tvm
494503
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_

python/tvm/relay/frontend/onnx.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3501,6 +3501,15 @@ def _impl_v11(cls, inputs, attr, params):
35013501
return _expr.TupleWrapper(_expr.Tuple([unique_vals, indices, inverse_indices, counts]), 4)
35023502

35033503

3504+
class Einsum(OnnxOpConverter):
3505+
"""Operator converter for Einsum"""
3506+
3507+
@classmethod
3508+
def _impl_v12(cls, inputs, attr, params):
3509+
equation = attr["equation"].decode("utf-8")
3510+
return _op.einsum(inputs, equation)
3511+
3512+
35043513
class RandomUniform(OnnxOpConverter):
35053514
"""Operator converter for random_uniform"""
35063515

@@ -3864,6 +3873,7 @@ def _get_convert_map(opset):
38643873
"Range": Range.get_converter(opset),
38653874
"CumSum": CumSum.get_converter(opset),
38663875
"Unique": Unique.get_converter(opset),
3876+
"Einsum": Einsum.get_converter(opset),
38673877
# defs/control_flow
38683878
"Loop": Loop.get_converter(opset),
38693879
"If": If.get_converter(opset),

python/tvm/relay/op/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from . import _transform
5555
from . import _reduce
5656
from . import _algorithm
57+
from . import _math
5758

5859

5960
def _register_op_make():

python/tvm/relay/op/_math.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Backend compiler related feature registration"""
18+
from . import op as _reg
19+
from . import strategy
20+
21+
# einsum
22+
_reg.register_strategy("einsum", strategy.einsum_strategy)

python/tvm/relay/op/_transform.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def compute_unique(attrs, inputs, output_type):
182182
_reg.register_strategy("invert_permutation", strategy.invert_permutation_strategy)
183183
_reg.register_shape_func("invert_permutation", False, elemwise_shape_func)
184184

185+
185186
#####################
186187
# Shape functions #
187188
#####################

python/tvm/relay/op/strategy/cuda.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,3 +1215,16 @@ def invert_permutation_strategy_cuda(attrs, inputs, out_type, target):
12151215
name="invert_permutation.cuda",
12161216
)
12171217
return strategy
1218+
1219+
1220+
@einsum_strategy.register(["cuda", "gpu"])
1221+
def einsum_strategy_cuda(attrs, inputs, out_type, target):
1222+
"""einsum cuda strategy"""
1223+
strategy = _op.OpStrategy()
1224+
# TODO: Add cuda-specific op implementation for einsum
1225+
strategy.add_implementation(
1226+
wrap_compute_einsum(topi.einsum),
1227+
wrap_topi_schedule(topi.generic.schedule_extern),
1228+
name="einsum.cuda",
1229+
)
1230+
return strategy

python/tvm/relay/op/strategy/generic.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1669,3 +1669,24 @@ def invert_permutation_strategy(attrs, inputs, out_type, target):
16691669
name="invert_permutation.generic",
16701670
)
16711671
return strategy
1672+
1673+
1674+
def wrap_compute_einsum(topi_compute):
1675+
"""Wrap einsum topi compute"""
1676+
1677+
def _compute_einsum(attrs, inputs, _):
1678+
return [topi_compute(attrs.equation, *inputs)]
1679+
1680+
return _compute_einsum
1681+
1682+
1683+
@override_native_generic_func("einsum_strategy")
1684+
def einsum_strategy(attrs, inputs, out_type, target):
1685+
"""einsum generic strategy"""
1686+
strategy = _op.OpStrategy()
1687+
strategy.add_implementation(
1688+
wrap_compute_einsum(topi.einsum),
1689+
wrap_topi_schedule(topi.generic.schedule_einsum),
1690+
name="einsum.generic",
1691+
)
1692+
return strategy

python/tvm/relay/op/tensor.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,6 +1104,29 @@ def concatenate(data, axis):
11041104
return _make.concatenate(Tuple(data), axis)
11051105

11061106

1107+
def einsum(data, equation):
1108+
"""Evaluates the Einstein summation convention on data
1109+
1110+
Parameters
1111+
----------
1112+
data : Union(List[relay.Expr], Tuple[relay.Expr])
1113+
A list of tensors.
1114+
equation : str
1115+
The einsum expression string.
1116+
1117+
Returns
1118+
-------
1119+
result : relay.Expr
1120+
The output tensor from the einsum op.
1121+
"""
1122+
data = list(data)
1123+
if not data:
1124+
raise ValueError("relay.einsum requires data to be non-empty.")
1125+
if not isinstance(equation, str):
1126+
raise ValueError("einsum `equation` must be a str")
1127+
return _make.einsum(Tuple(data), equation)
1128+
1129+
11071130
def stack(data, axis):
11081131
"""Join a sequence of arrays along a new axis.
11091132

python/tvm/topi/generic/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,4 @@
3939
from .sort import *
4040
from .search import *
4141
from .image import *
42+
from .math import *

python/tvm/topi/generic/math.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Generic math operators"""
18+
from .default import default_schedule as _default_schedule
19+
20+
21+
def schedule_einsum(outs):
22+
"""Schedule for einsum operator.
23+
24+
Parameters
25+
----------
26+
outs: Array of Tensor
27+
The computation graph description of einsum.
28+
29+
Returns
30+
-------
31+
s: Schedule
32+
The computation schedule for the op.
33+
"""
34+
return _default_schedule(outs, False)

0 commit comments

Comments
 (0)