Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,14 +1583,19 @@ def _transform_mask(stride_dim, ellipsis_mask):

# Create final output shape.
final_output = []
final_len = len(fshape_indices)
for gather_index in fshape_indices:
if gather_index == -1:
final_output.append(1)
final_len += 1
elif gather_index == -2:
pass
final_len -= 1
else:
final_output.append(out_shape[gather_index])

if final_len == 0:
return _op.squeeze(out, axis=tuple(range(len(fshape_indices))))

if not final_output:
return out
return _op.reshape(out, newshape=tuple(final_output))
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
compares the `abs(actual-desired)` with `atol+rtol*abs(desired)`. Since we
often allow `desired` to be close to zero, we generally want non-zero `atol`.
"""
actual = np.asanyarray(actual)
desired = np.asanyarray(desired)
np.testing.assert_allclose(actual.shape, desired.shape)
np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, verbose=True)


Expand Down
18 changes: 18 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,24 @@ def _test_stridedslice(
def test_forward_stridedslice():
"""test StridedSlice"""
for quantized in [False, True]:
_test_stridedslice(
(1, 3, 3),
[0, 0, 0],
[3, 3, 3],
[1, 1, 1],
"float32",
shrink_axis_mask=7,
quantized=quantized,
)
_test_stridedslice(
(1, 3, 3),
[0, 0, 0],
[3, 3, 3],
[1, 1, 1],
"float32",
shrink_axis_mask=5,
quantized=quantized,
)
_test_stridedslice((2), [1], [1], [1], "float32", shrink_axis_mask=1, quantized=quantized)
_test_stridedslice(
(3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], "float32", quantized=quantized
Expand Down
4 changes: 2 additions & 2 deletions tests/python/integration/test_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_dot():
A = te.placeholder((n,), name="A")
B = te.placeholder((n,), name="B")
k = te.reduce_axis((0, n), "k")
C = te.compute((1,), lambda _: te.sum(A[k] * B[k], axis=k), name="C")
C = te.compute((), lambda: te.sum(A[k] * B[k], axis=k), name="C")
s = te.create_schedule(C.op)

def verify(target):
Expand All @@ -36,7 +36,7 @@ def verify(target):
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(nn,)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(nn,)).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((1,), dtype=C.dtype), ctx)
c = tvm.nd.array(np.zeros((), dtype=C.dtype), ctx)
f(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-4)

Expand Down
14 changes: 7 additions & 7 deletions tests/python/integration/test_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_init_imm():
n = tvm.runtime.convert(1027)
A = te.placeholder((n,), name="A")
k = te.reduce_axis((0, n))
B = te.compute((1,), lambda i: te.sum(A[k], axis=k, init=10.0), name="B")
B = te.compute((), lambda: te.sum(A[k], axis=k, init=10.0), name="B")
# schedule
s = te.create_schedule(B.op)
# one line to build the function.
Expand All @@ -86,7 +86,7 @@ def check_target(target="llvm"):
# launch the kernel.
n = 1027
a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
b = tvm.nd.array(np.zeros((), dtype=B.dtype), ctx)
fsum(a, b)
res = 10.0 + np.sum(a.asnumpy(), axis=0)
tvm.testing.assert_allclose(b.asnumpy(), res, rtol=1e-4)
Expand Down Expand Up @@ -129,7 +129,7 @@ def test_rfactor():
n = tvm.runtime.convert(1027)
A = te.placeholder((n,), name="A")
k = te.reduce_axis((0, n))
B = te.compute((1,), lambda i: te.sum(A[k], axis=k), name="B")
B = te.compute((), lambda: te.sum(A[k], axis=k), name="B")
# schedule
s = te.create_schedule(B.op)
kf, ki = s[B].split(k, nparts=4)
Expand All @@ -145,7 +145,7 @@ def check_target(target="llvm"):
# launch the kernel.
n = 1027
a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
b = tvm.nd.array(np.zeros((), dtype=B.dtype), ctx)
fsum(a, b)
res = np.sum(a.asnumpy(), axis=0)
tvm.testing.assert_allclose(b.asnumpy(), res, rtol=1e-4)
Expand Down Expand Up @@ -191,11 +191,11 @@ def test_rfactor_factor_axis():
n = tvm.runtime.convert(1027)
A = te.placeholder((n,), name="A")
k = te.reduce_axis((0, n))
B = te.compute((1,), lambda i: te.sum(A[k], axis=k), name="B")
B = te.compute((), lambda: te.sum(A[k], axis=k), name="B")
# schedule
s = te.create_schedule(B.op)
kf, ki = s[B].split(k, nparts=4)
BF = s.rfactor(B, kf, 1)
BF = s.rfactor(B, kf, 0)
s[BF].parallel(BF.op.axis[0])
# one line to build the function.
def check_target(target="llvm"):
Expand All @@ -207,7 +207,7 @@ def check_target(target="llvm"):
# launch the kernel.
n = 1027
a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
b = tvm.nd.array(np.zeros((), dtype=B.dtype), ctx)
fsum(a, b)
res = np.sum(a.asnumpy(), axis=0)
tvm.testing.assert_allclose(b.asnumpy(), res, rtol=1e-4)
Expand Down