[FFI] Rebase tvm to v0.22.0 to utilize tvm-ffi (#1108)
* 3rdparty tvm bump
* bump tvm into v0.22.0
* lint fix
* rebase tvm
* Update submodule tvm to latest commit 3085bc4
* Refactor: Update configuration retrieval in CopyNode and adjust test registration in tilelang
* test fix
* add requirement
* atomic_fix
* atomic_fix
* phaseout py39
* optimize
* optimize
* lint fix
* do not clean cache
* do not clean cache
* [Minor] Minor update for Python versions and dependencies
* [Lint] fix lint for py39
* [Lint] fix lint for ROCm
* [Build][CI] Sync CI changes from upstream/sdist
* [Lint] fix lint for ROCm
* [Build][CI] Update `repair-wheel-command`
* [Minor] update abi3audit result format
* [Lint] fix lint for ROCm
* [BugFix] fix build
* [Lint] fix lint for ROCm
* [BugFix] set rpath for libtvm and libtvm_runtime
* [Deps] pin apache-tvm-ffi version
* [Build] set Python 3.9 Limited API for Cython target
* [Build] set Python 3.9 Limited API for Cython target
* [Deps] Restore Python 3.8 support
* [Build] use `apache-tvm-ffi`'s `libtvm_ffi`
* [BugFix] use `;` as delimiter for RPATH on macOS
* [BugFix] use `--ignore-missing-dependencies` for `delocate-wheel`
* [Build] support `sccache` if available
* [Build] add CIBW import test
* [Build][CI] enable ccache for CIBW on Linux
* [BugFix] set rpath for libtvm and libtvm_runtime
* Revert "[Build][CI] enable ccache for CIBW on Linux"
This reverts commit cd9ab57.
* [CI] fix perfbench bot
* [BugFix] use Python 3.9 to build wheel
* [Minor] update perfbench bot envs
* [BugFix] fix CIBW environment on Linux
* [CI] skip import test on CentOS 7
* [CI] use Python urllib to download file instead of Wget
---------
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
**Key Differences vs. Basic Example**

1. **`T.annotate_layout(...)`**: Annotates how data should be organized in shared memory (swizzling).
3. **Parallel Copy Loop** with `T.Parallel(...)`: Distributes the global-to-shared copy across all threads, potentially vectorizing load/store instructions; see the sketch below.
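To make these concrete, here is a minimal sketch combining both annotations. The buffer names (`A`, `A_shared`) and extents (`block_M`, `block_K`, `by`, `ko`) are illustrative assumptions, not taken from the original example:

```python
# Inside a TileLang kernel body (all names illustrative):
# request a swizzled shared-memory layout for A_shared to reduce bank conflicts
T.annotate_layout({A_shared: tilelang.layout.make_swizzled_layout(A_shared)})

# distribute the global-to-shared copy across all threads in the block;
# contiguous inner accesses allow the compiler to vectorize the loads/stores
for i, k in T.Parallel(block_M, block_K):
    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
```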
---
## Fine-grained MMA Computations
For advanced users who require full control over warp-level matrix multiplication operations, TileLang allows you to specify fine-grained MMA (Matrix Multiply-Accumulate) computations in a manner similar to writing raw CUDA. While higher-level abstractions like `T.gemm(...)` or automatic MMA emitters are sufficient for many use cases, specialized workloads (for example, dequantized GEMM, which may require a fine-grained layout transformation in the shared-to-register stage) can benefit from explicitly controlling each MMA instruction, the data layout, and the synchronization points.
### Example Workflow
1. **Set Up Tile Sizes and Thread Bindings**
Just like in CUDA, you typically start by defining how many warps or threads per block you want and how your matrix is subdivided. In TileLang, this is done via `T.Kernel(...)` and `T.thread_binding(...)`, which ensure that the correct number of threads is active and that each thread is bound to a specific role (e.g., warp ID or lane ID). For example:
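A minimal sketch, assuming a 128-thread block and 32-thread warps; the grid extents and the warp/lane arithmetic are illustrative, not taken from the original listing:

```python
# Inside a T.prim_func body: launch a 2D grid with 128 threads per block
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
    # bind a logical index to the hardware thread id
    tx = T.thread_binding(0, 128, thread="threadIdx.x")
    warp_id = tx // 32  # which warp this thread belongs to
    lane_id = tx % 32   # this thread's position within its warp
```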
2. **Allocate Warp-local Fragments**
Instead of using a single shared buffer for partial sums, you allocate local buffers (register fragments) to hold sub-blocks of matrices \(A\) and \(B\). In TileLang, this is done with something like:
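A hedged sketch; the fragment extents (`warp_rows`, `local_size_a`, and so on) and the dtypes are illustrative assumptions rather than values from the original listing:

```python
# Per-thread register fragments that together form the warp's tiles
A_local = T.alloc_local((warp_rows * local_size_a,), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b,), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c,), accum_dtype)
```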
Each of these `local` allocations represents a region of per-thread storage, which collectively forms the warp’s register tiles.
3. **Load Data via `ldmatrix`**
Fine-grained loading instructions allow you to specify exactly how data moves from shared memory to the warp-level fragments. In the example below, `mma_emitter.ldmatrix_a()` and `.ldmatrix_b()` are higher-level wrappers around warp-synchronous intrinsics. You can write your own load logic as well:
```python
# loop body reconstructed from the surrounding prose; buffer names are illustrative
for ki in T.serial(0, (block_K // micro_size_k)):
    # load one micro-tile of A from shared memory into register fragments
    mma_emitter.ldmatrix_a(A_local, A_shared, ki)
    # load the matching micro-tile of B
    mma_emitter.ldmatrix_b(B_local, B_shared, ki)
    # multiply-accumulate the fragments (see step 4)
    mma_emitter.mma(A_local, B_local, C_local)
```
Internally, these calls orchestrate how each thread in the warp issues the correct load instructions, performs address calculations, and stores the data into registers.
4. **Perform the MMA Instruction**
After loading sub-tiles (fragments), the warp executes the `mma` instruction. This operation is essentially:
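In effect, each warp performs a multiply-accumulate on its register fragments, updating the accumulator in place:

\[
C_{\text{frag}} \leftarrow A_{\text{frag}} \times B_{\text{frag}} + C_{\text{frag}}
\]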
Under the hood, this translates into Tensor Core instructions (e.g., `wmma.mma.sync` in PTX), which process multiple data elements per warp in parallel.
5. **Store Results via `stmatrix`**
Finally, you write the results from the warp-level fragments back to shared memory or global memory. This step might happen multiple times in a loop or just once at the end. The code snippet:
```python
mma_emitter.stmatrix(C_local, C_shared)
```

writes each thread's accumulated fragment values into the shared-memory tile `C_shared`, from which a final copy to global memory can proceed. By combining warp-synchronous intrinsics (`ldmatrix`, `mma`, `stmatrix`) with manual control over layouts and thread bindings, you get CUDA-level control over the computation while staying within TileLang.
## References
- [NVIDIA CUTLASS Library](https://github.com/NVIDIA/cutlass): A collection of high-performance CUDA C++ template abstractions for GEMM.
- [NVIDIA CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html): Official documentation for CUDA.
- [PyTorch Documentation](https://pytorch.org/docs): For verifying correctness via CPU- or GPU-based matmul.