[TEST] Add test for default padding_option argument in tl.load() (#4376)
There was no test coverage for this. I discovered this while
implementing the CPU backend.
int3 authored Jul 25, 2024 · 1 parent a94dee0 · commit 8298561
Showing 2 changed files with 7 additions and 4 deletions.
python/test/unit/language/test_block_pointer.py (6 additions, 3 deletions)
@@ -14,7 +14,10 @@ def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option:
                                     block_shape=(BLOCK_SIZE, ), order=(0, ))
     b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
                                     block_shape=(BLOCK_SIZE, ), order=(0, ))
-    a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option)
+    if padding_option is None:
+        a = tl.load(a_block_ptr, boundary_check=(0, ))
+    else:
+        a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option)
     tl.store(b_block_ptr, a, boundary_check=(0, ))


@@ -24,7 +27,7 @@ def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option:
     for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("int32", "int32"), ("float16", "float16"),
                        ("float32", "float32"), ("bfloat16", "bfloat16"))
     for n in (64, 128, 256, 512, 1024)
-    for padding in ("zero", "nan") #
+    for padding in (None, "zero", "nan") #
 ])
 def test_block_copy(dtypes_str, n, padding_option, device):
     src_dtype_str = dtypes_str[0]
@@ -47,7 +50,7 @@ def test_block_copy(dtypes_str, n, padding_option, device):
     assert torch.all(a[0:n // 2] == b[0:n // 2])
     if padding_option == "zero":
         assert torch.all(b[n // 2:n] == 0)
-    else:
+    elif padding_option == "nan":
         assert torch.all(torch.isnan(b[n // 2:n]))


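The branch added above exercises the default form of tl.load on a block pointer, where out-of-bounds elements are left undefined rather than padded. A minimal, self-contained sketch of that call pattern (not part of this commit; the kernel name and shapes are illustrative):

import triton
import triton.language as tl


@triton.jit
def copy_with_default_padding(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
                                    block_shape=(BLOCK_SIZE, ), order=(0, ))
    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
                                    block_shape=(BLOCK_SIZE, ), order=(0, ))
    # padding_option defaults to "", so out-of-bounds lanes hold undefined values;
    # only the in-bounds portion of the copy can be relied upon.
    a = tl.load(a_block_ptr, boundary_check=(0, ))
    tl.store(b_block_ptr, a, boundary_check=(0, ))

This is also why the assertions above only compare the in-bounds halves and make no claim about b[n // 2:n] when padding_option is None.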
python/triton/language/core.py (1 addition, 1 deletion)
@@ -1601,7 +1601,7 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c
     :type other: Block, optional
     :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check
     :type boundary_check: tuple of ints, optional
-    :param padding_option: should be one of {"", "zero", "nan"}, do padding while out of bound
+    :param padding_option: should be one of {"", "zero", "nan"}, the padding value to use while out of bounds. "" means an undefined value.
     :param cache_modifier: changes cache option in NVIDIA PTX
     :type cache_modifier: str, optional, should be one of {"", "ca", "cg"}, where "ca" stands for
         cache at all levels and "cg" stands for cache at global level (cache in L2 and below, not L1), see
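A short usage sketch of the documented padding_option values (illustrative only, not from this commit; the kernel and pointer names are hypothetical):

import triton
import triton.language as tl


@triton.jit
def padded_tail_kernel(x_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    block_ptr = tl.make_block_ptr(base=x_ptr, shape=(N, ), strides=(1, ), offsets=(0, ),
                                  block_shape=(BLOCK_SIZE, ), order=(0, ))
    out_block_ptr = tl.make_block_ptr(base=out_ptr, shape=(N, ), strides=(1, ), offsets=(0, ),
                                      block_shape=(BLOCK_SIZE, ), order=(0, ))
    # "zero" fills out-of-bounds elements with 0; "nan" would fill them with NaN
    # (for floating-point blocks); omitting the argument uses the default "" and
    # leaves out-of-bounds elements undefined.
    x = tl.load(block_ptr, boundary_check=(0, ), padding_option="zero")
    tl.store(out_block_ptr, x, boundary_check=(0, ))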
