Skip to content

Commit 63bc166

Browse files
committed
[Enhancement] Add stride index validation in CythonKernelWrapper
* Introduced an assertion to ensure that the stride index is within the valid range of tensor dimensions in `cython_wrapper.pyx`. * This change prevents potential out-of-bounds errors when accessing tensor dimensions, enhancing the robustness of the code.
1 parent 72be490 commit 63bc166

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

tilelang/jit/adapter/cython/cython_wrapper.pyx

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,11 @@ cdef class CythonKernelWrapper:
123123
# otherwise, maybe torch.data_ptr() for T.ptr inputs
124124
continue
125125
for stride_idx, expected_stride in strides_list:
126+
# Ensure the stride index is within the valid range of tensor dimensions
127+
# (stride_idx should be less than the number of dimensions of the tensor)
128+
assert stride_idx < tensor.dim(), f"Stride index {stride_idx} out of bounds for tensor with {tensor.dim()} dimensions"
129+
if tensor.shape[stride_idx] == 1:
130+
continue
126131
actual_stride = tensor.stride(stride_idx)
127132
if actual_stride != expected_stride:
128133
raise ValueError(

0 commit comments

Comments
 (0)