Merged
133 changes: 102 additions & 31 deletions python/tvm/relay/frontend/mxnet.py
@@ -790,14 +790,27 @@ def _mx_dot(inputs, attrs):
def _mx_batch_dot(inputs, attrs):
assert len(inputs) == 2
a, b = inputs
a_shape = _infer_type(a).checked_type.shape
batch_shapes = None
if len(a_shape) > 3:
batch_shapes = a_shape[:-2]
a = _op.reverse_reshape(a, newshape=(-1, 0, 0))
b_shape = _infer_type(b).checked_type.shape
if len(b_shape) > 3:
if batch_shapes is None:
batch_shapes = b_shape[:-2]
b = _op.reverse_reshape(b, newshape=(-1, 0, 0))
transpose_a = attrs.get_bool("transpose_a", False)
transpose_b = attrs.get_bool("transpose_b", False)
if transpose_a is True:
msg = 'Value {} in attribute "transpose_a" of operator batch_dot ' "is not valid."
raise tvm.error.OpAttributeInvalid(msg.format(transpose_a))
if transpose_b is False:
b = _op.transpose(b, axes=[0, 2, 1])
return _op.nn.batch_matmul(a, b)
out = _op.nn.batch_matmul(a, b)
if batch_shapes is not None:
out = _op.reverse_reshape(out, newshape=tuple(batch_shapes) + (0, 0))
return out

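A minimal NumPy sketch of the shape bookkeeping the new multi-dimensional branch performs, assuming plain matmul semantics (`nn.batch_matmul` itself expects its second operand pre-transposed, which the converter handles above): collapse all leading batch axes into one, run a rank-3 batched matmul, then restore the batch axes.

```python
import numpy as np

def batch_dot_nd(a, b):
    # Sketch only: mirror the flatten -> batch_matmul -> unflatten strategy.
    batch_shape = a.shape[:-2]                 # leading batch axes, e.g. (2, 3)
    a3 = a.reshape((-1,) + a.shape[-2:])       # collapse to (B, M, K)
    b3 = b.reshape((-1,) + b.shape[-2:])       # collapse to (B, K, N)
    out = np.matmul(a3, b3)                    # rank-3 batched matmul
    return out.reshape(batch_shape + out.shape[-2:])  # restore batch axes

a = np.random.rand(2, 3, 4, 5)
b = np.random.rand(2, 3, 5, 6)
assert batch_dot_nd(a, b).shape == (2, 3, 4, 6)
```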

def _mx_arange(inputs, attrs):
@@ -2284,18 +2297,16 @@ def _mx_npi_pad(inputs, attrs):
raise tvm.error.OpAttributeRequired('Attribute "mode" not found in operator pad.')
if pad_mode not in ["constant", "edge", "reflect"]:
raise tvm.error.OpAttributeInvalid("Value " + mode + ' in attribute "mode" is not valid')
pad_width = attrs.get_int_tuple("pad_width", None)
if pad_width is None:
if "pad_width" not in attrs.attrs:
raise tvm.error.OpAttributeRequired('Attribute "pad_width" not found in operator pad.')
if None in pad_width:
raise tvm.error.OpAttributeInvalid(
'Value None in attribute "pad_width" of operator Slice is not valid.'
)
# Parse the tuple of tuples manually; get_int_tuple cannot be used here because pad_width is a tuple of tuples.
pad_width = attrs.attrs["pad_width"]
pad_width = pad_width.replace("(", "[")
pad_width = pad_width.replace(")", "]")
pad_width = json.loads(pad_width)
constant_values = attrs.get_float("constant_values", 0.0)
padding = tuple(tuple((b, a)) for b, a in zip(pad_width[::2], pad_width[1::2]))

return _op.nn.pad(
data=inputs[0], pad_width=padding, pad_value=constant_values, pad_mode=pad_mode
data=inputs[0], pad_width=pad_width, pad_value=constant_values, pad_mode=pad_mode
)

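Since `get_int_tuple` cannot parse a tuple of tuples, the converter now round-trips the raw attribute string through JSON. A standalone sketch of that parsing step, assuming MXNet serializes `pad_width` as shown:

```python
import json

raw = "((0, 0), (0, 0), (1, 2), (3, 4))"      # assumed raw attribute string
pad_width = json.loads(raw.replace("(", "[").replace(")", "]"))
assert pad_width == [[0, 0], [0, 0], [1, 2], [3, 4]]
```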

@@ -2311,24 +2322,74 @@ def _mx_npx_reshape(inputs, attrs):
shape = attrs.get_int_tuple("newshape")
reverse = attrs.get_bool("reverse", False)
shape_list = list(shape)
new_shape_list = []
for num in shape_list:
if num > 0 or num == -1:
new_shape_list.append(num)
elif num == -2:
new_shape_list.append(0)
elif num == -4:
new_shape_list.append(-2)
elif num == -5:
new_shape_list.append(-3)
elif num == -6:
new_shape_list.append(-4)
old_shape = get_const_tuple(_infer_type(inputs[0]).checked_type.shape)
new_shape = []
if reverse:
old_shape = old_shape[::-1]
shape_list = shape_list[::-1]
ptr = 0
unknown_axis = None
src_ptr = 0
while src_ptr < len(shape_list):
ele = shape_list[src_ptr]
src_ptr += 1
if ele > 0:
new_shape.append(ele)
ptr += 1
elif ele == -1:
new_shape.append(-1)
if unknown_axis is not None:
raise tvm.error.OpAttributeInvalid("Can only have one -1 in the input shape.")
unknown_axis = len(new_shape)
ptr += 1
elif ele == -2:
new_shape.append(old_shape[ptr])
ptr += 1
elif ele == -3:
if old_shape[ptr] != 1:
raise tvm.error.OpAttributeInvalid(
"Dimension of the original shape "
"that corresponds to -3 must be 1. Received"
" {}".format(old_shape[ptr])
)
ptr += 1
elif ele == -4:
new_shape += old_shape[ptr:]
break
elif ele == -5:
new_shape.append(old_shape[ptr] * old_shape[ptr + 1])
ptr += 2
elif ele == -6:
# Split axis
lhs = shape_list[src_ptr]
rhs = shape_list[src_ptr + 1]
src_ptr += 2
if lhs == -1 and rhs == -1:
raise tvm.error.OpAttributeInvalid("The lhs and rhs can not both be -1.")
if lhs == -1:
if old_shape[ptr] % rhs != 0:
raise tvm.error.OpAttributeInvalid(
"When splitting the axis, "
"the dimension of the split axis must "
"be divisible by the splitted values."
)
lhs = old_shape[ptr] // rhs
if rhs == -1:
if old_shape[ptr] % lhs != 0:
raise tvm.error.OpAttributeInvalid(
"When splitting the axis, "
"the dimension of the split axis must "
"be divisible by the splitted values."
)
rhs = old_shape[ptr] // lhs
new_shape.append(lhs)
new_shape.append(rhs)
ptr += 1
else:
raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % num)
shape = tuple(new_shape_list)
raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % ele)
if reverse:
return _op.reverse_reshape(inputs[0], newshape=shape)
return _op.reshape(inputs[0], newshape=shape)
new_shape = new_shape[::-1]
return _op.reshape(inputs[0], newshape=new_shape)

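The special values follow the `npx.reshape` convention: -1 infers an axis, -2 copies an axis, -3 drops a unit axis, -4 copies all remaining axes, -5 fuses two consecutive axes, and -6 splits an axis using the next two entries. A worked trace of two of the new test cases:

```python
# newshape=(-6, 2, -1, -4) applied to old_shape=(8, 3, 3, 3, 4, 4):
#   -6 splits axis 0 with (2, -1); the -1 side is inferred as 8 // 2 = 4
#   -4 then copies every remaining input axis: (3, 3, 3, 4, 4)
#   => output shape (2, 4, 3, 3, 3, 4, 4)
# With reverse=True both shapes are flipped first, so the codes apply from
# the right: (-4, -5) on (8, 3, 3, 3, 3, 8) fuses the last two axes into 24
# and keeps the rest => output shape (8, 3, 3, 24)
```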

def _mx_split_v2(inputs, attrs):
@@ -2346,12 +2407,21 @@ def _mx_split_v2(inputs, attrs):


def _mx_npi_where_rscalar(inputs, attrs):
cond, dat = inputs
scalar = attrs.get_float("scalar")
dtype = _infer_type(inputs[1]).checked_type.dtype
cond_shape = get_const_tuple(_infer_type(cond).checked_type.shape)
dat_shape = get_const_tuple(_infer_type(dat).checked_type.shape)
dtype = _infer_type(dat).checked_type.dtype
# Check for broadcasting
out_shape = np.broadcast(np.empty(cond_shape), np.empty(dat_shape)).shape
if out_shape != cond_shape:
cond = _op.broadcast_to(cond, out_shape)
if out_shape != dat_shape:
dat = _op.broadcast_to(dat, out_shape)
scalar = _expr.const(scalar, dtype=dtype)
ones = _op.ones_like(inputs[1])
ones = _op.ones_like(dat)
scalar = _op.multiply(ones, scalar)
return _op.where(inputs[0], inputs[1], scalar)
return _op.where(cond, dat, scalar)

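A NumPy sketch of the broadcasting the converter now performs before `where`, using the (2, 7, 2) / (7, 2) pair from the updated tests (variable names hypothetical):

```python
import numpy as np

cond = np.random.uniform(size=(7, 2)) > 0.5      # condition, smaller rank
data = np.random.uniform(size=(2, 7, 2))
out_shape = np.broadcast(np.empty(cond.shape), np.empty(data.shape)).shape
cond_b = np.broadcast_to(cond, out_shape)        # (2, 7, 2)
data_b = np.broadcast_to(data, out_shape)
ref = np.where(cond_b, data_b, np.float32(1.0))  # scalar fills the False slots
assert ref.shape == (2, 7, 2)
```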

# Note: due to attribute conversion constraint
@@ -2372,13 +2442,13 @@ def _mx_npi_where_rscalar(inputs, attrs):
"reshape_like",
"zeros_like",
"ones_like",
"where",
"cos",
"cosh",
"sin",
"sinh",
"tan",
"tanh",
"where",
]

_convert_map = {
@@ -2598,6 +2668,7 @@ def _mx_npi_where_rscalar(inputs, attrs):
"_npi_concatenate": _mx_npi_concatenate,
"_npx_reshape": _mx_npx_reshape,
"_np_copy": _rename(_op.copy),
"_npi_copy": _rename(_op.copy),
"_npi_power": _rename(_op.power),
"_npi_power_scalar": _binop_scalar(_op.power),
"_npi_multiply": _rename(_op.multiply),
@@ -2606,6 +2677,7 @@ def _mx_npi_where_rscalar(inputs, attrs):
"_npi_add_scalar": _binop_scalar(_op.add),
"_npi_where_rscalar": _mx_npi_where_rscalar,
"_npi_less": _rename(_op.less),
"_npi_less_equal": _mx_compare(_op.less_equal, _rename),
"_npi_tanh": _rename(_op.tanh),
"_npi_true_divide_scalar": _binop_scalar(_op.divide),
}
@@ -2717,7 +2789,6 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None):
else:
raise RuntimeError("unexpected type %s" % type(res))
node_map[nid] = res

outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
func = _function.Function(analysis.free_vars(outputs), outputs)
12 changes: 11 additions & 1 deletion python/tvm/topi/x86/batch_matmul.py
@@ -37,6 +37,9 @@ def batch_matmul(cfg, x, y, out_shape=None):
3-D with shape [batch, M, K]
y : tvm.te.Tensor
3-D with shape [batch, N, K]
out_shape : tuple or None
Shape of the output

Returns
-------
output : tvm.te.Tensor
@@ -135,7 +138,7 @@ def _default_batch_matmul_config(cfg, M, N, K):


@autotvm.register_topi_compute("batch_matmul_cblas.x86")
def batch_matmul_cblas(cfg, x, y):
def batch_matmul_cblas(cfg, x, y, out_shape=None):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.

@@ -147,6 +150,9 @@ def batch_matmul_cblas(cfg, x, y):
3-D with shape [batch, M, K]
y : tvm.te.Tensor
3-D with shape [batch, N, K]
out_shape : tuple or None
Shape of the output

Returns
-------
output : tvm.te.Tensor
Expand All @@ -157,6 +163,10 @@ def batch_matmul_cblas(cfg, x, y):
YB, N, YK = get_const_tuple(y.shape)
assert XB == YB, "batch dimension doesn't match"
assert XK == YK, "shapes of x and y are inconsistent"
if out_shape is not None:
assert out_shape[0] == XB, "got invalid output shape"
assert out_shape[1] == M, "got invalid output shape"
assert out_shape[2] == N, "got invalid output shape"
cfg.add_flop(XB * M * N * XK * 2)
return cblas.batch_matmul(x, y, False, True)

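A minimal sketch of the new validation, under the assumption that `out_shape`, when given, must equal the (batch, M, N) inferred from the inputs:

```python
def check_out_shape(x_shape, y_shape, out_shape=None):
    # x: (batch, M, K), y: (batch, N, K) -- the cblas path computes x @ y^T
    XB, M, XK = x_shape
    YB, N, YK = y_shape
    assert XB == YB, "batch dimension doesn't match"
    assert XK == YK, "shapes of x and y are inconsistent"
    if out_shape is not None:
        assert tuple(out_shape) == (XB, M, N), "got invalid output shape"

check_out_shape((4, 32, 64), (4, 16, 64), out_shape=(4, 32, 16))  # passes
```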
30 changes: 20 additions & 10 deletions tests/python/frontend/mxnet/test_forward.py
@@ -1914,7 +1914,10 @@ def verify(data_shape, axis, use_length, length):
@pytest.mark.skipif(not hasattr(mx.sym.np, "pad"), reason="mx.sym.np.pad hasn't been published yet")
@pytest.mark.parametrize(
"data_shape, pad_width",
[((1, 1, 3, 5), (0, 0, 0, 0, 1, 2, 3, 4)), ((1, 1, 3, 5, 7), (0, 0, 0, 0, 1, 2, 3, 4, 5, 6))],
[
((1, 1, 3, 5), ((0, 0), (0, 0), (1, 2), (3, 4))),
((1, 1, 3, 5, 7), ((0, 0), (0, 0), (1, 2), (3, 4), (5, 6))),
],
)
@pytest.mark.parametrize("mode", ["constant", "edge", "reflect"])
@pytest.mark.parametrize("dtype", ["float64", "float32", "int64", "int32"])
@@ -1925,19 +1928,17 @@ def test_forward_npi_pad(data_shape, pad_width, mode, dtype, constant_value, tar
data_np = np.random.uniform(size=data_shape).astype(dtype)
data = mx.sym.var("data")
if mode == "constant":
ref_res = mx.ndarray.pad(
mx.nd.array(data_np), mode=mode, pad_width=pad_width, constant_value=constant_value
)
ref_res = np.pad(data_np, mode=mode, pad_width=pad_width, constant_values=constant_value)
mx_sym = mx.sym.np.pad(
data.as_np_ndarray(), mode=mode, pad_width=pad_width, constant_values=constant_value
)
else:
ref_res = mx.ndarray.pad(mx.nd.array(data_np), mode=mode, pad_width=pad_width)
ref_res = np.pad(data_np, mode=mode, pad_width=pad_width)
mx_sym = mx.sym.np.pad(data.as_np_ndarray(), mode=mode, pad_width=pad_width)
mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype)
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(data_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

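The reference result now comes straight from NumPy, which accepts the nested per-axis `pad_width` used in the new parametrization; a quick check of the expected shape:

```python
import numpy as np

data = np.random.uniform(size=(1, 1, 3, 5)).astype("float32")
pad_width = ((0, 0), (0, 0), (1, 2), (3, 4))
ref = np.pad(data, pad_width=pad_width, mode="constant", constant_values=3.0)
assert ref.shape == (1, 1, 6, 12)  # 3 + 1 + 2 and 5 + 3 + 4
```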

@pytest.mark.skipif(
@@ -2011,8 +2012,12 @@ def test_forward_np_copy(data_shape, dtype, target, ctx, kind):
((2, 3, 8), (-2, -2, 2, -1), False),
((8, 3, 3, 3, 4, 4), (-6, 2, -1, -4), False),
((8, 3, 3, 3, 4, 4), (-5, -4), False),
((1, 8, 3, 3, 3, 4, 4), (-3, -5, -4), False),
((8, 1, 3, 4), (-2, -3, -1), False),
((8, 3, 3, 3, 3, 8), (-4, -5), True),
((8, 3, 2, 4, 8), (-4, -1, 2, -6), True),
((3, 2, 4, 8, 1, 1), (-4, -1, 2, -6, -5, -3), True),
((2, 4, 1, 8), (-4, -3, -1, 2, -6), True),
],
)
def test_forward_npx_reshape(data_shape, out_shape, dtype, target, reverse, ctx, kind):
@@ -2099,16 +2104,21 @@ def test_forward_npi_tanh(data_shape, dtype, target, ctx, kind):


@pytest.mark.skipif(not hasattr(mx.np, "where"), reason="mx.np.where hasn't been published yet")
@pytest.mark.parametrize("data_shape", [(2, 2, 2), (2, 7, 2), (1, 8), (2, 2), (1, 3)])
@pytest.mark.parametrize(
"data_shape,cond_shape",
[[(2, 2, 2), (2, 2, 2)], [(2, 7, 2), (7, 2)], [(2, 2), (1, 2)], [(1, 3), (3, 3)]],
)
@pytest.mark.parametrize("data_dtype", ["float64", "float32", "int64", "int32", "bool"])
@pytest.mark.parametrize("cond_dtype", ["float64", "float32", "int64", "int32", "bool"])
@pytest.mark.parametrize("scalar", [1.0, 2.0])
@tvm.testing.parametrize_targets
@pytest.mark.parametrize("kind", ["graph", "vm", "debug"])
def test_forward_npi_where_rscalar(data_shape, cond_dtype, data_dtype, scalar, target, ctx, kind):
def test_forward_npi_where_rscalar(
data_shape, cond_shape, data_dtype, cond_dtype, scalar, target, ctx, kind
):
if data_dtype == "bool":
scalar = scalar == 0.0
cond_np = np.random.uniform(size=data_shape).astype(cond_dtype)
cond_np = np.random.uniform(size=cond_shape).astype(cond_dtype)
data_np = np.random.uniform(size=data_shape).astype(data_dtype)
cond = mx.sym.var("condition")
data = mx.sym.var("x")
@@ -2118,7 +2128,7 @@ def test_forward_npi_where_rscalar(data_shape, cond_dtype, data_dtype, scalar, t
dtypeDic["condition"] = cond_dtype
dtypeDic["x"] = data_dtype
mod, _ = relay.frontend.from_mxnet(
mx_sym, shape={"condition": data_shape, "x": data_shape}, dtype=dtypeDic
mx_sym, shape={"condition": cond_shape, "x": data_shape}, dtype=dtypeDic
)
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(cond_np, data_np)
2 changes: 1 addition & 1 deletion tests/python/relay/test_op_level1.py
@@ -134,7 +134,7 @@ def check_binary_op(opfunc, ref, dtype):
continue
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01, atol=1e-3)

for opfunc, ref in [
(relay.add, np.add),
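On the tolerance change above: with reference values at or near zero, a pure relative check allows almost no error (the budget is rtol * |ref|), so an absolute floor is needed; a minimal demonstration:

```python
import numpy as np

# Fails with rtol alone (allowed error is 0.01 * |0.0| = 0), passes with atol.
np.testing.assert_allclose(1e-4, 0.0, rtol=0.01, atol=1e-3)
```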