From a8b23d7f24c0ba18e8ec6718f5cab1fea281a9b3 Mon Sep 17 00:00:00 2001 From: tiendung Date: Sun, 20 Nov 2022 01:29:37 +0700 Subject: [PATCH] enhance test code --- run.sh | 2 +- tests/test_conv.py | 6 +++--- tests/test_ndarray.py | 11 +++++++++-- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/run.sh b/run.sh index 6958fd1..31091b4 100755 --- a/run.sh +++ b/run.sh @@ -5,7 +5,7 @@ python3 -m pytest \ tests/test_ndarray.py \ tests/test_nd_backend.py \ - # tests/test_conv.py + tests/test_conv.py # Test for a specific backend KIM_BACKEND=nd KIM_DEVICE=cuda_triton ./fast_tests.sh diff --git a/tests/test_conv.py b/tests/test_conv.py index fc946cc..bada5db 100644 --- a/tests/test_conv.py +++ b/tests/test_conv.py @@ -9,9 +9,9 @@ _DEVICES = [ nd.cpu_numpy(), - nd.cpu(), - pytest.param(nd.cuda(), marks=pytest.mark.skipif(not nd.cuda().enabled(), reason="No GPU")), - pytest.param(nd.cuda_triton(), marks=pytest.mark.skipif(not nd.cuda_triton().enabled(), reason="No GPU")) + # nd.cpu(), + # pytest.param(nd.cuda(), marks=pytest.mark.skipif(not nd.cuda().enabled(), reason="No GPU")), + # pytest.param(nd.cuda_triton(), marks=pytest.mark.skipif(not nd.cuda_triton().enabled(), reason="No GPU")) ] def backward_check(f, *args, **kwargs): diff --git a/tests/test_ndarray.py b/tests/test_ndarray.py index cce395e..4dcf797 100644 --- a/tests/test_ndarray.py +++ b/tests/test_ndarray.py @@ -19,8 +19,10 @@ def compare_strides(a_np, a_nd): def check_same_memory(original, view): - # assert original._handle.ptr() == view._handle.ptr() - assert original._handle == view._handle + if hasattr(original, "ptr"): + assert original._handle.ptr() == view._handle.ptr() + else: + assert original._handle == view._handle # TODO test permute, broadcast_to, reshape, getitem, some combinations thereof @@ -142,6 +144,11 @@ def test_setitem_ewise(params, device): _A[lhs_slices] = _B[rhs_slices] # end_ptr = A._handle.ptr() end_ptr = A._handle + + if hasattr(start_ptr, "ptr"): + start_ptr = start_ptr.ptr() + end_ptr = end_ptr.ptr() + assert start_ptr == end_ptr, "you should modify in-place" compare_strides(_A, A) np.testing.assert_allclose(A.numpy(), _A, atol=1e-5, rtol=1e-5)