Commit 1e92d11

[Refactor] Improve assertion handling in CodeGenCHost and ArgBinder (#1352)
* [Refactor] Improve assertion handling in CodeGenCHost and ArgBinder — refines the assertion message generation in CodeGenCHost by optimizing the handling of equality checks and reducing buffer size for error messages, and enhances the ArgBinder by introducing a nullable guard mechanism for assertions, allowing more precise error handling when binding arguments.
* [Enhancement] Update matmul kernel and optimize argument binding — introduces additional tensor parameters, refines the pipeline stages for improved performance, updates the argument binding mechanism to include a flag indicating whether buffers are used, and adds a simplification step to the engine's optimization phase.
* lint fix
* [Enhancement] Add tensor checks documentation and improve argument binding assertions — introduces a new documentation page for host-side tensor checks detailing the automatic validations performed by TileLang on kernel arguments, and adds assertions in ArgBinder for non-null pointers when arguments are used.
* [Enhancement] Update .gitignore and refine matmul kernel for improved performance — adds host checks logs to .gitignore, adjusts pipeline stages, updates tensor parameters, enhances argument handling, and improves error messages in the argument binding process.
* lint fix
* lint fix
* [Refactor] Simplify tensor_null_test function and remove ptr_null_test — adds a with_bias parameter to tensor_null_test, removes the unused ptr_null_test function, and updates run_test accordingly.
* lint fix
* fix

1 parent b8240b7 commit 1e92d11

27 files changed: +1100 / -283 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions

@@ -108,3 +108,6 @@ cmake-build-*/
 
 # pre-commit cache
 .pre-commit-cache/*
+
+# host checks logs
+maint/host_checks/logs/*

Lines changed: 387 additions & 0 deletions
@@ -0,0 +1,387 @@
# Tensor Checks (Host-Side Auto-Validation)

This page explains the host-side checks that TileLang automatically inserts into the generated host stub for kernels. When you pass a `torch.Tensor` or any DLPack-compatible object to a TileLang kernel, the host stub validates argument count, pointer kinds, dtype, shape, strides, device, and more — so you don’t need to handwrite Python checks. This keeps the ABI stable and significantly reduces Python overhead compared to doing equivalent checks in Python or via pybind.

## Why Host-Side Checks

- ABI stability: the entry point is based on TVM FFI + DLPack and consistently accepts tensors and scalars.
- Lower overhead: shifting checks from Python into C reduces interpreter/property-access costs; the call overhead is lower than with pybind-based approaches.
- Focused error reporting: assertions are raised close to the call site with precise “which field failed” messages.

## How To Inspect Host Source

You can inspect the auto-generated host source (with all checks and the final device-kernel call) for debugging:

```python
print(matmul_relu_kernel.get_host_source())
```

---

## What The Host Checks

### 1) Argument count and pointer kind
- `num_args` must match the number of formal parameters; otherwise the kernel returns `-1` with an error message.
- Each argument’s FFI type must be a pointer kind (for DLTensor/handle) or a valid scalar type; otherwise you’ll see errors like `Expect arg[i] to be pointer` or a scalar type error.

### 2) Tensor checks (per tensor, after nullability decision)
- Nullability
  - If the tensor is “statically reachable/used” by the function body, the handle must be non-NULL; passing NULL raises `xxx is expected to have non-NULL pointer`.
  - If an input tensor is not used by the function (statically unreachable), NULL is allowed; the other field checks are executed only when `handle != NULL`.
- Rank (`ndim`)
  - Runtime `ndim` must equal the compile-time rank.
- Data type (`dtype`)
  - Match the triple `(code, bits, lanes)` with tolerance:
    - `float8_e4m3`: accept `e4m3`, `e4m3fn`, `e4m3fnuz`.
    - `float8_e5m2`: accept `e5m2`, `e5m2fnuz`.
    - `bool`: accept `int8/uint8` with `bits=8` (same lanes), `kDLBool(code=6, bits=1 or 8)`, and any `bitwidth=1` (lanes must match).
  - For packed-bit dtypes (e.g., `Int(1)`, `Int(4)`, `UInt(4)`), strict dtype checking is skipped.
- Shape
  - Each runtime dimension is bound to the compile-time shape (constants or symbols) and checked for consistency.
  - Linear equations among symbolic dims can be solved on the fly (when there is only one unknown at a given check point), enabling cross-tensor constraints.
- Strides
  - If `buffer_type = AutoBroadcast`: allow `strides == NULL` and derive strides from `shape`. If explicit `strides` is present, bind it to the compile-time constraints and check for equality.
  - Otherwise: check per dimension; if `strides == NULL`, derive strides from `shape` and compare (e.g., contiguous: `strides[-1] == 1`, `strides[-2] == shape[-1]`); see the sketch after this list.
- `byte_offset`
  - Must be 0 (non-zero raises an error) to keep addressing simple and aligned.
- Device info
  - Assert `device_type == target backend` (CUDA/ROCM/Metal/OneAPI/WebGPU/CPU, etc.). Error messages include a DLPack code legend.
  - When multiple tensors participate, assert that `device_id` matches across them.
- Data pointer
  - Must be non-NULL when the tensor is required to be non-null by the nullability rule.
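
The stride rule can be hard to picture, so here is a minimal Python sketch of the derivation the host performs when `strides == NULL` (row-major, contiguous assumption). It illustrates the rule; it is not the generated C code.

```python
def derived_strides(shape):
    """Strides (in elements) implied by a contiguous row-major layout."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

# For a (1024, 512) tensor: strides[-1] == 1 and strides[-2] == shape[-1] == 512.
print(derived_strides([1024, 512]))  # [512, 1]
```

A transposed view swaps the strides, so `strides[-1]` is no longer 1 and the check fails until you call `.contiguous()`.
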
### 3) Scalar checks
- `T.int*` family: require integer; error: `Expect arg[i] to be int`.
- `T.bool`: require boolean; error: `Expect arg[i] to be boolean`.

---

## Shapes and Symbolic Equations: Linear Solving

When shapes are symbolic, the host binds and (when possible) solves linear relations at runtime (only one unknown per check point). Example:

```python
@T.prim_func
def main(
    A: T.Tensor((m,), dtype),
    B: T.Tensor((m + n,), dtype),
    C: T.Tensor((n * k,), dtype),
):
    ...
```

This enables enforcing cross-tensor relationships like `len(B) == m + n` and `len(C) == n * k` at runtime.
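
As a concrete illustration (the sizes, the `float16` dtype, and the assumption that `main` has been compiled into a callable are all hypothetical), every symbol can be recovered as long as each equation has a single unknown when it is checked:

```python
import torch

# Suppose the runtime sizes are m=64, n=32, k=4.
A = torch.empty((64,), dtype=torch.float16, device="cuda")   # binds m = 64
B = torch.empty((96,), dtype=torch.float16, device="cuda")   # 96 = m + n  ->  n = 32
C = torch.empty((128,), dtype=torch.float16, device="cuda")  # 128 = n * k ->  k = 4
main(A, B, C)  # one unknown per constraint, so all symbols resolve
```

If a relation is reached while it still has two unbound symbols, the host cannot solve it at that point (see the "Dynamic shapes" tip below), so arrange shapes so that each symbol is pinned by an earlier, simpler dimension.
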
---

## Nullability Rules and Examples

Which tensors may be NULL?

- Rule: If an input tensor is not used by the function under static analysis (i.e., the access is statically unreachable), it is considered nullable; otherwise it must be non-NULL.
- Examples:

1) Must be non-NULL (used)
```python
@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
    A[0] = 1
```
Passing `None` raises: `main.A_handle is expected to have non-NULL pointer`.

2) Still must be non-NULL (constant-true branch)
```python
some_cond: bool = True

@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
    if some_cond:
        A[0] = 1
```

3) Nullable (constant-false branch, statically unreachable)
```python
some_cond: bool = False

@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
    if some_cond:
        A[0] = 1
```

4) Must be non-NULL (runtime condition)
```python
@T.prim_func
def main(A: T.Tensor((M, K), dtype), some_cond: T.bool):
    if some_cond:
        A[0] = 1
```
Since `some_cond` is only known at runtime, static analysis cannot prove `A` is unused; `A` is thus non-nullable.

---

## Device Type Codes (DLPack)

Supported and referenced device codes in error messages: `1=CPU, 2=CUDA, 7=Vulkan, 8=Metal, 10=ROCM, 14=OneAPI, 15=WebGPU`.
Kernels assert that `device_type` matches the target backend, and require `device_id` consistency across tensors.
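
If you are unsure which DLPack device a tensor will report, recent PyTorch versions expose it directly via the DLPack protocol (a quick check; availability depends on your PyTorch version):

```python
import torch

A = torch.empty((4, 4), device="cuda:0")
# Returns (device_type, device_id) using the DLPack codes above, e.g. (2, 0) for CUDA device 0.
print(A.__dlpack_device__())
```
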
---

## Common Error Examples (What you’ll see)

- Argument count mismatch (num_args)
  - Trigger: missing/extra argument
  - Error: `<kernel>: num_args should be N; expected: <num_args>, got: N`

- Pointer-typed argument expected
  - Trigger: scalar passed where a tensor is expected
  - Error: `<kernel>: Expect arg[i] to be pointer`

- Rank (ndim) mismatch
  - Trigger: runtime rank differs from compile-time rank
  - Error: `<kernel>.<name>.ndim is expected to equal R, but got mismatched ndim`

- Dtype mismatch
  - Trigger: dtype not equal to the compiled dtype and not within the tolerance set
  - Error: `<kernel>.<name>.dtype is expected to be <dtype>, but got incompatible dtype`

- Shape constraint violation
  - Trigger: a dimension doesn’t match a constant/symbol binding
  - Error: `Argument <kernel>.<name>.shape[i] has an unsatisfied constraint: ... == <expected>`

- Strides check failed (e.g., non-contiguous layout)
  - Trigger: transposed/sliced tensors that violate expected strides
  - Error: `Argument <kernel>.<name>.strides[j] has an unsatisfied constraint: ... == <expected>`

- Device type mismatch
  - Trigger: calling a CUDA kernel with CPU tensors, etc.
  - Error: `<kernel>.<name>.device_type mismatch [expected: <code> (<name>)] ...`

- Device id mismatch
  - Trigger: mixing tensors from different GPUs
  - Error: `Argument <kernel>.<name>.device_id has an unsatisfied constraint: ... == ...`

- NULL data pointer
  - Trigger: tensor required to be non-null has a NULL data pointer
  - Error: `<kernel>.<name> is expected to have non-NULL data pointer, but got NULL`

- Scalar type mismatch
  - Trigger: passing float to `T.int32`, or non-boolean to `T.bool`
  - Error: `<kernel>: Expect arg[i] to be int/boolean`

---

## Troubleshooting Tips
- Print the host source: `print(fn.get_host_source())` to see the exact assertion and expected vs. actual fields.
- Fix strides: call `.contiguous()` for non-contiguous tensors, or avoid generating transposed/sliced layouts that break assumptions.
- Align devices: ensure all participating tensors share the same `device_type` and `device_id`.
- Align dtype: use `.to(<dtype>)` or construct tensors with the correct dtype; pay attention to `float8` and `bool` tolerance (see the snippet after this list).
- Dynamic shapes: ensure cross-tensor linear relations can be uniquely determined at the check point (only one unknown at a time).
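
A minimal normalization sketch that applies the strides/device/dtype tips in one place (the helper name, target dtype, and device are assumptions; adjust them to your kernel's signature):

```python
import torch

def normalize_arg(t: torch.Tensor,
                  dtype: torch.dtype = torch.float16,
                  device: str = "cuda:0") -> torch.Tensor:
    """Return a tensor that satisfies the usual host checks:
    expected dtype, expected device, contiguous row-major layout."""
    return t.to(device=device, dtype=dtype).contiguous()

# Example usage before calling a compiled kernel:
# fn(normalize_arg(A), normalize_arg(B), normalize_arg(C))
```
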
---

## FAQ
- Can I disable the checks?
  - Not recommended and usually not supported. Checks are done on the host to preserve ABI stability and to fail early, close to the device call.
- Is the overhead noticeable?
  - The checks are lightweight (branches and field reads), so they are cheaper than equivalent checks done in Python; the dominating cost remains the Python→C boundary crossing itself.

---

## Reference Example (Matmul + ReLU)

```python
@T.prim_func
def matmul_relu_kernel(
    A: T.Tensor((M, K), dtype),
    B: T.Tensor((K, N), dtype),
    C: T.Tensor((M, N), dtype),
):
    # Initialize Kernel Context
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_K), dtype)
        B_shared = T.alloc_shared((block_K, block_N), dtype)
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
        T.clear(C_local)
        for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
            T.copy(A[by * block_M, ko * block_K], A_shared)
            T.copy(B[ko * block_K, bx * block_N], B_shared)
            T.gemm(A_shared, B_shared, C_local)
        T.copy(C_local, C[by * block_M, bx * block_N])

# For debugging, print the host source
print(matmul_relu_kernel.get_host_source())
```

The host will insert all checks described above for this example.
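
A call that passes all of the generated checks might look like the following (the sizes, dtype, and the assumption that `matmul_relu_kernel` is already compiled into a callable are illustrative):

```python
import torch

M = N = K = 1024
A = torch.randn((M, K), device="cuda", dtype=torch.float16)
B = torch.randn((K, N), device="cuda", dtype=torch.float16)
C = torch.empty((M, N), device="cuda", dtype=torch.float16)

# The host stub validates num_args, pointer kinds, dtype, shape, strides,
# device, and data pointers before launching the device kernel.
matmul_relu_kernel(A, B, C)
```
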
---

## Quick Error Reference (Short List)
- Argument count
  - Trigger: missing/extra args; Error: `num_args should be N; expected: <num_args>, got: N`.
- Pointer kind
  - Trigger: scalar passed to tensor arg; Error: `Expect arg[i] to be pointer`.
- Rank (ndim)
  - Trigger: runtime rank != compile-time; Error: `ndim ... expected to equal R`.
- Dtype
  - Trigger: mismatch and not tolerated; Error: `dtype ... expected to be <dtype>`.
- Shape
  - Trigger: constant/symbol binding violated; Error: `shape[i] ... == <expected>`.
- Strides
  - Trigger: layout mismatch; Error: `strides[j] ... == <expected>`.
- Device type
  - Trigger: wrong backend device; Error: `device_type mismatch [expected: ...]`.
- Device id
  - Trigger: tensors on different GPUs; Error: `device_id ... == ...`.
- Data pointer
  - Trigger: required non-NULL but NULL; Error: `non-NULL data pointer`.
- Scalar types
  - Trigger: wrong scalar type; Error: `Expect arg[i] to be int/boolean`.

---

## Host Error Troubleshooting (Minimal Repros)

Below are minimal repro snippets for common host-side errors, assuming a CUDA-targeted kernel like `matmul_relu_kernel` with:

```python
# Convention:
# A: float16 [M, K]
# B: float16 [K, N]
# C: float16 [M, N]
# Target: CUDA (device_type=2)
fn = matmul_relu_kernel  # your compiled function
M = N = K = 1024
```

Adjust dtype/device if your kernel differs.

### 0. Tip: print the host source
```python
print(fn.get_host_source())
```

### 1. num_args mismatch
```python
import torch

A = torch.empty((M, K), device='cuda', dtype=torch.float16)
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
# Missing C
fn(A, B)
```
Expected: an error like `<kernel>: num_args should be 3`, since only 2 of the 3 arguments were passed.

Fix: pass all arguments per the signature.

### 2. Expect pointer (tensor) but got scalar
```python
import torch

B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(1, B, C)
```
Expected: `<kernel>: Expect arg[0] to be pointer`.

Fix: pass a DLPack-compatible tensor (e.g., torch.Tensor).

### 3. ndim mismatch
```python
import torch

A = torch.empty((M, K, 1), device='cuda', dtype=torch.float16)  # rank=3
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
```
Expected: `<kernel>.A_handle.ndim is expected to equal 2, but got mismatched ndim`.

Fix: ensure runtime rank equals compiled rank.

### 4. dtype mismatch
```python
import torch

A = torch.empty((M, K), device='cuda', dtype=torch.float32)  # should be float16
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
```
Expected: `<kernel>.A_handle.dtype is expected to be float16, but got incompatible dtype`.

Fix: `A = A.to(torch.float16)` or create with the correct dtype.

### 5. Shape constant/symbol mismatch
```python
import torch

A = torch.empty((M, K + 1), device='cuda', dtype=torch.float16)  # K mismatched
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
```
Expected: `Argument <kernel>.A_handle.shape[i] has an unsatisfied constraint: ... == <expected>`.

Fix: satisfy linear constraints and constants across tensors.

### 6. Strides check failure (non-contiguous)
```python
import torch

A = torch.empty((M, K), device='cuda', dtype=torch.float16)
A_nc = A.t()  # transpose -> non-contiguous
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A_nc, B, C)
```
Expected: `Argument <kernel>.A_handle.strides[1] has an unsatisfied constraint: ... == 1`.

Fix: pass `A_nc.contiguous()` or align the layout expectation in the kernel.

### 7. device_type mismatch
```python
import torch

A = torch.empty((M, K), device='cpu', dtype=torch.float16)
B = torch.empty((K, N), device='cpu', dtype=torch.float16)
C = torch.empty((M, N), device='cpu', dtype=torch.float16)
fn(A, B, C)  # CUDA-targeted kernel
```
Expected: `<kernel>.A_handle.device_type mismatch [expected: 2 (cuda)] ...`.

Fix: move tensors to the CUDA device.

### 8. device_id mismatch (multi-GPU)
```python
import torch

A = torch.empty((M, K), device='cuda:0', dtype=torch.float16)
B = torch.empty((K, N), device='cuda:1', dtype=torch.float16)
C = torch.empty((M, N), device='cuda:0', dtype=torch.float16)
fn(A, B, C)
```
Expected: `Argument <kernel>.B_handle.device_id has an unsatisfied constraint: ... == ...`.

Fix: place all tensors on the same GPU (e.g., `cuda:0`).

### 9. NULL data pointer (advanced)
This usually comes from hand-constructed DLTensor/NDArray, or external frameworks passing unallocated/freed storage. Regular `torch.Tensor` allocations rarely hit this.

Expected: `<kernel>.<name> is expected to have non-NULL data pointer, but got NULL`.

Fix: ensure valid underlying storage; in PyTorch scenarios, avoid constructing tensors from invalid external handles.

### 10. Scalar type mismatch (int / bool)
```python
import tilelang.language as T

@T.prim_func
def scalar_check(x: T.int32, flag: T.bool()):
    T.evaluate(0)

scalar_check(1.0, True)  # x is float -> Expect arg[0] to be int
scalar_check(1, 2.5)     # flag is float -> Expect arg[1] to be boolean
```

Fix: pass correct scalar types, e.g., `scalar_check(1, True)`.

---

## Closing Notes
- Cross-check “shape / strides / device / dtype” against the kernel signature to localize issues efficiently.
- For complex symbolic relations, print the host source to confirm binding/solving order, then adjust runtime shapes/layouts accordingly.
