Skip to content

Commit

Permalink
Fixed broken tests as per issue #217 (#220)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjwillemsen authored Oct 5, 2023
1 parent 4ff2fb8 commit fe52050
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 22 deletions.
3 changes: 3 additions & 0 deletions kernel_tuner/backends/hip.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,16 @@ def ready_argument_list(self, arguments):
device_ptr = hip.hipMalloc(arg.nbytes)
data_ctypes = arg.ctypes.data_as(ctypes.POINTER(dtype_map[dtype_str]))
hip.hipMemcpy_htod(device_ptr, data_ctypes, arg.nbytes)
# may be part of run_kernel, return allocations here instead
ctype_args.append(device_ptr)
else:
raise TypeError("unknown dtype for ndarray")
# Convert valid non-array arguments to ctypes
elif isinstance(arg, np.generic):
data_ctypes = dtype_map[dtype_str](arg)
ctype_args.append(data_ctypes)
else:
raise ValueError(f"Invalid argument type {type(arg)}, {arg}")

return ctype_args

Expand Down
31 changes: 9 additions & 22 deletions test/test_hip_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ def env():

return ["vector_add", kernel_string, size, args, tune_params]

# @skip_if_no_pyhip
@pytest.mark.skip("Currently broken due to pull request #216, to be fixed in issue #217")
@skip_if_no_pyhip
def test_ready_argument_list():

size = 1000
Expand All @@ -50,27 +49,16 @@ def test_ready_argument_list():

arguments = [d, a, b, c]

class ArgListStructure(ctypes.Structure):
_fields_ = [("field0", ctypes.POINTER(ctypes.c_float)),
("field1", ctypes.c_int),
("field2", ctypes.POINTER(ctypes.c_float)),
("field3", ctypes.c_bool)]
def __getitem__(self, key):
return getattr(self, self._fields_[key][0])

dev = kt_hip.HipFunctions(0)
gpu_args = dev.ready_argument_list(arguments)

argListStructure = ArgListStructure(d.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
ctypes.c_int(a),
b.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
ctypes.c_bool(c))

assert(gpu_args[1] == argListStructure[1])
assert(gpu_args[3] == argListStructure[3])
# ctypes have no equality defined, so indirect comparison for type and value
assert(isinstance(gpu_args[1], ctypes.c_int))
assert(isinstance(gpu_args[3], ctypes.c_bool))
assert(gpu_args[1] == a)
assert(gpu_args[3] == c)

# @skip_if_no_pyhip
@pytest.mark.skip("Currently broken due to pull request #216, to be fixed in issue #217")
@skip_if_no_pyhip
def test_compile():

kernel_string = """
Expand Down Expand Up @@ -119,8 +107,7 @@ def test_memcpy_htod():

assert all(output == x)

# @skip_if_no_pyhip
@pytest.mark.skip("Currently broken due to pull request #216, to be fixed in issue #217")
@skip_if_no_pyhip
def test_copy_constant_memory_args():
kernel_string = """
__constant__ float my_constant_data[100];
Expand Down Expand Up @@ -149,7 +136,7 @@ def test_copy_constant_memory_args():
grid = (1, 1, 1)
dev.run_kernel(kernel, gpu_args, threads, grid)

dev.memcpy_dtoh(output, gpu_args.field0)
dev.memcpy_dtoh(output, gpu_args[0])

assert(my_constant_data == output).all()

Expand Down

0 comments on commit fe52050

Please sign in to comment.