Skip to content

Commit

Permalink
enhance test code
Browse files Browse the repository at this point in the history
  • Loading branch information
tiendung committed Nov 19, 2022
1 parent 63989d7 commit a8b23d7
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
2 changes: 1 addition & 1 deletion run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
python3 -m pytest \
tests/test_ndarray.py \
tests/test_nd_backend.py \
# tests/test_conv.py
tests/test_conv.py

# Test for a specific backend
KIM_BACKEND=nd KIM_DEVICE=cuda_triton ./fast_tests.sh
Expand Down
6 changes: 3 additions & 3 deletions tests/test_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@


_DEVICES = [ nd.cpu_numpy(),
nd.cpu(),
pytest.param(nd.cuda(), marks=pytest.mark.skipif(not nd.cuda().enabled(), reason="No GPU")),
pytest.param(nd.cuda_triton(), marks=pytest.mark.skipif(not nd.cuda_triton().enabled(), reason="No GPU"))
# nd.cpu(),
# pytest.param(nd.cuda(), marks=pytest.mark.skipif(not nd.cuda().enabled(), reason="No GPU")),
# pytest.param(nd.cuda_triton(), marks=pytest.mark.skipif(not nd.cuda_triton().enabled(), reason="No GPU"))
]

def backward_check(f, *args, **kwargs):
Expand Down
11 changes: 9 additions & 2 deletions tests/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ def compare_strides(a_np, a_nd):


def check_same_memory(original, view):
# assert original._handle.ptr() == view._handle.ptr()
assert original._handle == view._handle
if hasattr(original, "ptr"):
assert original._handle.ptr() == view._handle.ptr()
else:
assert original._handle == view._handle


# TODO test permute, broadcast_to, reshape, getitem, some combinations thereof
Expand Down Expand Up @@ -142,6 +144,11 @@ def test_setitem_ewise(params, device):
_A[lhs_slices] = _B[rhs_slices]
# end_ptr = A._handle.ptr()
end_ptr = A._handle

if hasattr(start_ptr, "ptr"):
start_ptr = start_ptr.ptr()
end_ptr = end_ptr.ptr()

assert start_ptr == end_ptr, "you should modify in-place"
compare_strides(_A, A)
np.testing.assert_allclose(A.numpy(), _A, atol=1e-5, rtol=1e-5)
Expand Down

0 comments on commit a8b23d7

Please sign in to comment.