Skip to content

Commit

Permalink
[LANG] Comparison operators support for Imm expressions (apache#3283)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZihengJiang authored and tqchen committed Jun 4, 2019
1 parent 072f8cc commit befd8c1
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 12 deletions.
10 changes: 10 additions & 0 deletions python/tvm/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,16 @@ def __init__(self, value):
self.__init_handle_by_constructor__(
_make.StringImm, value)

def __eq__(self, other):
if isinstance(other, ConstExpr):
return self.value == other.value
return self.value == other

def __ne__(self, other):
if isinstance(other, ConstExpr):
return self.value != other.value
return self.value != other


@register_node
class Cast(Expr):
Expand Down
13 changes: 5 additions & 8 deletions python/tvm/relay/quantize/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from . import _quantize
from .. import expr as _expr
from .. import ir_pass as _ir_pass
from .. import transform as _transform
from .. import op as _op
from ... import make as _make
from ..base import NodeBase, register_relay_node
Expand Down Expand Up @@ -301,8 +300,6 @@ def optimize(func, params=None):
"FoldConstant",
"CanonicalizeOps"]

cfg = _transform.build_config(required_pass=opt_passes)

if params:
name_dict = {}
for arg in func.params:
Expand All @@ -321,25 +318,25 @@ def optimize(func, params=None):
bind_dict[arg] = _expr.const(v)
func = _expr.bind(func, bind_dict)

if "SimplifyInference" in cfg.required_pass:
if "SimplifyInference" in opt_passes:
func = _ir_pass.infer_type(func)
func = _ir_pass.simplify_inference(func)

if "FoldConstant" in cfg.required_pass:
if "FoldConstant" in opt_passes:
func = _ir_pass.fold_constant(func)

if "FoldScaleAxis" in cfg.required_pass:
if "FoldScaleAxis" in opt_passes:
func = _ir_pass.infer_type(func)
func = _ir_pass.backward_fold_scale_axis(func)
func = _ir_pass.infer_type(func)
func = _ir_pass.forward_fold_scale_axis(func)
func = _ir_pass.fold_constant(func)

if "CanonicalizeOps" in cfg.required_pass:
if "CanonicalizeOps" in opt_passes:
func = _ir_pass.infer_type(func)
func = _ir_pass.canonicalize_ops(func)

if "FoldConstant" in cfg.required_pass:
if "FoldConstant" in opt_passes:
func = _ir_pass.fold_constant(func)

return func
Expand Down
8 changes: 4 additions & 4 deletions src/relay/op/type_relations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ bool BroadcastRel(const Array<Type>& types,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
<< ",Out:" << types[2] << std::endl;
// DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
// << ",Out:" << types[2] << std::endl;
if (auto t0 = ToTensorType(types[0])) {
if (auto t1 = ToTensorType(types[1])) {
CHECK_EQ(t0->dtype, t1->dtype);
Expand All @@ -126,8 +126,8 @@ bool BroadcastCompRel(const Array<Type>& types,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
<< ",Out:" << types[2] << std::endl;
// DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
// << ",Out:" << types[2] << std::endl;
if (auto t0 = ToTensorType(types[0])) {
if (auto t1 = ToTensorType(types[1])) {
CHECK_EQ(t0->dtype, t1->dtype);
Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_lang_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,14 @@ def test_equality():
d = (c != c)
assert not d


def test_equality_string_imm():
x = 'a'
y = tvm.make.StringImm(x)
x == y.value
x == y


if __name__ == "__main__":
test_cast()
test_attr()
Expand All @@ -178,3 +186,4 @@ def test_equality():
test_all()
test_bitwise()
test_equality()
test_equality_string_imm()
7 changes: 7 additions & 0 deletions tests/python/unittest/test_lang_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,16 @@ def test_map_save_load_json():
assert(dd == {"a": 2, "b": 3})


def test_in_container():
arr = tvm.convert(['a', 'b', 'c'])
assert 'a' in arr
assert tvm.make.StringImm('a') in arr
assert 'd' not in arr

if __name__ == "__main__":
test_str_map()
test_array()
test_map()
test_array_save_load_json()
test_map_save_load_json()
test_in_container()

0 comments on commit befd8c1

Please sign in to comment.