Fix topi bop overloading #3595

Closed
wants to merge 1 commit
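In effect, the patch makes the topi-installed operator overloads fall back to the generic scalar operator whenever an operand is a zero-rank Tensor, so such expressions yield a tvm.expr.Expr instead of a broadcast Tensor. A minimal sketch of the intended behavior with this patch applied, assuming the 0.6-era tvm/topi API; the placeholder shapes are illustrative and the assertions mirror the updated unit test in the diff below:

import tvm
import topi  # importing topi installs the broadcast operator overloads

A = tvm.placeholder((), name="A")       # zero-rank tensor
B = tvm.placeholder((10, 5), name="B")  # non-zero-rank tensor

# A zero-rank operand now falls back to the generic op and yields an Expr;
# a non-zero-rank Tensor still goes through topi broadcast and yields a Tensor.
assert isinstance(A + 1, tvm.expr.Expr)
assert isinstance(B + 1, tvm.tensor.Tensor)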
@@ -15,11 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# NOTE: We name this test file to start with test_graph_tuner
-# to make it execute after zero_rank tensor test cases. This
-# helps avoid topi arithmetic operator overloading issue:
-# https://github.com/dmlc/tvm/issues/3240.
-# TODO: restore the file name after this issue is resolved.
 import os
 import copy
 import numpy as np
@@ -31,7 +26,7 @@
 from tvm.autotvm.task import ConfigEntity
 from tvm.autotvm.measure import MeasureResult, MeasureInput
 from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner
-from test_graph_tuner_utils import create_workload
+from test_autotvm_graph_tuner_utils import create_workload
 
 
 def _create_data(target, dshape, dtype, layout):

@@ -15,11 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# NOTE: We name this test file to start with test_graph_tuner
-# to make it execute after zero_rank tensor test cases. This
-# helps avoid topi arithmetic operator overloading issue:
-# https://github.com/dmlc/tvm/issues/3240
-# TODO: restore the file name after this issue is resolved.
 import tvm
 
 from tvm import autotvm, relay
14 changes: 7 additions & 7 deletions tests/python/unittest/test_lang_tensor_overload_op.py
@@ -31,11 +31,11 @@ def test_operator_type_and_tags():
 
     assert isinstance(k + n, tvm.expr.Expr)
     assert isinstance(n + n, tvm.expr.Expr)
-    assert isinstance(k + A, tvm.tensor.Tensor)
-    assert isinstance(A + k, tvm.tensor.Tensor)
-    assert isinstance(n + A, tvm.tensor.Tensor)
-    assert isinstance(A + n, tvm.tensor.Tensor)
-    assert isinstance(A + A, tvm.tensor.Tensor)
+    assert isinstance(k + A, tvm.expr.Expr)
+    assert isinstance(A + k, tvm.expr.Expr)
+    assert isinstance(n + A, tvm.expr.Expr)
+    assert isinstance(A + n, tvm.expr.Expr)
+    assert isinstance(A + A, tvm.expr.Expr)
 
     assert isinstance(k + B, tvm.tensor.Tensor)
     assert isinstance(B + k, tvm.tensor.Tensor)
@@ -58,8 +58,8 @@ def test_operator_type_and_tags():
     assert isinstance(n + B2, tvm.expr.Expr)
     assert isinstance(B2 + n, tvm.expr.Expr)
     assert isinstance(B2 + B2, tvm.expr.Expr)
-    assert isinstance(B2 + A, tvm.tensor.Tensor)
-    assert isinstance(A + B2, tvm.tensor.Tensor)
+    assert isinstance(B2 + A, tvm.expr.Expr)
+    assert isinstance(A + B2, tvm.expr.Expr)
     assert isinstance(B2 + B, tvm.tensor.Tensor)
     assert isinstance(B + B2, tvm.tensor.Tensor)
 
15 changes: 14 additions & 1 deletion topi/python/topi/generic_op_impl.py
@@ -22,6 +22,13 @@
 from . import math as _math
 
 
+def _is_non_zero_rank_tensor(tensor):
+    """Check whether input is a non-zero-rank tensor.
+    """
+    if isinstance(tensor, tvm.tensor.Tensor) and tensor.shape:
+        return True
+    return False
+
 def _make_bop(broadcast_bop, orig_bop):
     """Make a specific overloaded binary operator of Tensor when applicable;
     apply the original operator if it is not supposed to be overloaded.
@@ -63,6 +70,11 @@ def _tensor_bop_impl(lhs, rhs):
         scalar like type (e.g., numeric types, Expr, or TensorSlice),
         it performs tensor-scalar {op} operation on an element-wise basis.
 
+        If one operand is a TensorSlice and the other is a zero-rank
+        Tensor, it performs the default generic.{op} operation and returns
+        an Expr. This avoids errors when a reduce op with such a pattern
+        appears inside the body of the lambda function of tvm.compute.
+
         Otherwise, it performs default generic.{op} operation, as defined
         in tvm.generic module.
 
@@ -79,8 +91,9 @@ def _tensor_bop_impl(lhs, rhs):
         tvm.Expr (otherwise)
             The result of {op} operation.
         """
-        if not isinstance(lhs, tvm.tensor.Tensor) and not isinstance(rhs, tvm.tensor.Tensor):
+        if not _is_non_zero_rank_tensor(lhs) and not _is_non_zero_rank_tensor(rhs):
             return orig_bop(lhs, rhs)
+
         return broadcast_bop(lhs, rhs)
     _tensor_bop_impl.__doc__ = _tensor_bop_impl.__doc__.format(op=name)
     return _tensor_bop_impl
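The docstring paragraph added above refers to the pattern that motivated the zero-rank special case: a reduction built inside the lambda of tvm.compute that mixes a TensorSlice with a zero-rank Tensor. A hedged sketch of that pattern, modeled on the zero_rank tensor test cases mentioned in the removed NOTE comments (the names m, X, scale, and k are illustrative, not taken from this PR):

import tvm
import topi  # the overloads patched above are installed by importing topi

m = tvm.var("m")
X = tvm.placeholder((m,), name="X")
scale = tvm.placeholder((), name="scale")  # zero-rank tensor
k = tvm.reduce_axis((0, m), name="k")

# Inside the lambda, X[k] is a TensorSlice and scale is a zero-rank Tensor.
# Under this patch, X[k] * scale routes to the generic multiply and stays an
# Expr, so the tvm.sum reduction can be built; without the patch, the topi
# broadcast overload would produce a Tensor here and the compute would fail.
T = tvm.compute((), lambda: tvm.sum(X[k] * scale, axis=k))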