
Commit 3fda5c0

Merge branch 'main' into sk-build-core

2 parents be422e7 + 7fb0677


57 files changed: +2003 −408 lines

.github/workflows/metal_ci.yml

Lines changed: 91 additions & 0 deletions (new file)

```yaml
name: CI Test on Metal
on: [pull_request]

env:
  PYTHON_VERSION: '3.12'
  VENV_DIR: tilelang_ci

jobs:
  format-check:
    runs-on: [macos-latest]

    permissions:
      contents: write

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
          submodules: recursive

      - name: Install python via uv
        uses: astral-sh/setup-uv@v6
        with:
          enable-cache: true
          ignore-nothing-to-cache: true
          activate-environment: true
          python-version: ${{ env.PYTHON_VERSION }}

      - name: Ensure venv (local & persistent)
        run: |
          [[ -f requirements-test.txt ]] && \
            uv pip install -r requirements-test.txt --no-build-isolation

      - name: Run format check
        run: |
          set -ex
          mkdir -p build
          # run cmake to create the build directory with compile_commands.json
          cd build; cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DUSE_METAL=ON; cd ..
          if ! output=$(./format.sh 2>&1); then
            echo "------------------------------------"
            echo "message:"
            echo "$output"
            printf '%s\n' "$output"
            echo "------------------------------------"
            exit 1
          fi

  build-test-metal:
    runs-on: [macos-latest]
    needs: format-check
    permissions:
      contents: read
    env:
      CMAKE_C_COMPILER_LAUNCHER: ccache
      CMAKE_CXX_COMPILER_LAUNCHER: ccache
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 1
          submodules: recursive

      - name: ccache
        uses: hendrikmuhs/ccache-action@v1.2
        with:
          create-symlink: true
          key: ${{ github.job }}-${{ matrix.os }}

      - name: Install python via uv
        uses: astral-sh/setup-uv@v6
        with:
          enable-cache: true
          ignore-nothing-to-cache: true
          activate-environment: true
          python-version: ${{ env.PYTHON_VERSION }}

      - name: Ensure venv (local & persistent)
        run: uv pip install -r requirements-test.txt -r requirements-build.txt

      - name: Build wheel
        run: |
          source .venv/bin/activate
          uv pip install -v --no-build-isolation .

      - name: Run metal test
        run: |
          cd testing/python
          unset PYTHONPATH
          python -m pytest -k metal -v -r fE --durations=0 --timeout=3600
```
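
For local debugging, the `Run metal test` step can be reproduced from Python with the same pytest flags. This is only a convenience sketch, not part of the workflow; it assumes `pytest` (plus a timeout plugin, as implied by the `--timeout` flag) is installed from requirements-test.txt, and it passes the test directory as an argument instead of `cd`-ing into it.

```python
import sys

import pytest

# Mirror the CI invocation: select tests whose names match "metal",
# report failures/errors, list durations, and apply the same per-test timeout.
exit_code = pytest.main([
    "-k", "metal",
    "-v",
    "-r", "fE",
    "--durations=0",
    "--timeout=3600",
    "testing/python",
])
sys.exit(exit_code)
```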

CMakeLists.txt

Lines changed: 24 additions & 0 deletions

```diff
@@ -26,6 +26,7 @@ if(NOT TVM_FOUND)
   endif()
 endif()
 
+<<<<<<< HEAD
 # Backend-specific checks and configs
 if(APPLE)
   message(STATUS "Enable Metal support by default.")
@@ -41,6 +42,29 @@ elseif($ENV{USE_ROCM})
   add_compile_definitions(__HIP_PLATFORM_AMD__ __HIP_PLATFORM_HCC__=1)
 
   include_directories(SYSTEM ${ROCM_INCLUDE_DIRS})
+=======
+  # Handle TVM prebuild or build TVM from source
+  if(DEFINED TVM_PREBUILD_PATH)
+    message(STATUS "Using prebuilt TVM from ${TVM_PREBUILD_PATH}")
+    add_library(tvm SHARED IMPORTED)
+    find_library(TVM_LIBRARY_LOCATION
+      NAMES tvm
+      HINTS "${TVM_PREBUILD_PATH}"
+    )
+    set_target_properties(tvm PROPERTIES
+      IMPORTED_LOCATION "${TVM_LIBRARY_LOCATION}"
+      INTERFACE_INCLUDE_DIRECTORIES "${TVM_PREBUILD_PATH}/../include"
+    )
+    add_library(tvm_runtime SHARED IMPORTED)
+    find_library(TVM_RUNTIME_LIBRARY_LOCATION
+      NAMES tvm_runtime
+      HINTS "${TVM_PREBUILD_PATH}"
+    )
+    set_target_properties(tvm_runtime PROPERTIES
+      IMPORTED_LOCATION "${TVM_RUNTIME_LIBRARY_LOCATION}"
+      INTERFACE_INCLUDE_DIRECTORIES "${TVM_PREBUILD_PATH}/../include"
+    )
+>>>>>>> main
 else()
   if($ENV{USE_CUDA})
     set(USE_CUDA ON)
```

CMakeLists.txt.bak

Lines changed: 20 additions & 2 deletions

```diff
@@ -108,13 +108,21 @@ endif()
 if(DEFINED TVM_PREBUILD_PATH)
   message(STATUS "Using prebuilt TVM from ${TVM_PREBUILD_PATH}")
   add_library(tvm SHARED IMPORTED)
+  find_library(TVM_LIBRARY_LOCATION
+    NAMES tvm
+    HINTS "${TVM_PREBUILD_PATH}"
+  )
   set_target_properties(tvm PROPERTIES
-    IMPORTED_LOCATION "${TVM_PREBUILD_PATH}/libtvm.so"
+    IMPORTED_LOCATION "${TVM_LIBRARY_LOCATION}"
     INTERFACE_INCLUDE_DIRECTORIES "${TVM_PREBUILD_PATH}/../include"
   )
   add_library(tvm_runtime SHARED IMPORTED)
+  find_library(TVM_RUNTIME_LIBRARY_LOCATION
+    NAMES tvm_runtime
+    HINTS "${TVM_PREBUILD_PATH}"
+  )
   set_target_properties(tvm_runtime PROPERTIES
-    IMPORTED_LOCATION "${TVM_PREBUILD_PATH}/libtvm_runtime.so"
+    IMPORTED_LOCATION "${TVM_RUNTIME_LIBRARY_LOCATION}"
     INTERFACE_INCLUDE_DIRECTORIES "${TVM_PREBUILD_PATH}/../include"
   )
 else()
@@ -157,6 +165,13 @@ if(USE_ROCM)
   list(APPEND TILE_LANG_SRCS ${TILE_LANG_HIP_SRCS})
 endif()
 
+if(USE_METAL)
+  tilelang_file_glob(GLOB TILE_LANG_METAL_SRCS
+    src/target/rt_mod_metal.cc
+  )
+  list(APPEND TILE_LANG_SRCS ${TILE_LANG_METAL_SRCS})
+endif()
+
 message(STATUS "Collected source files: ${TILE_LANG_SRCS}")
 
 # Add TileLang object library
@@ -221,6 +236,9 @@ target_compile_definitions(tilelang_objs PRIVATE -DTILE_LANG_EXPORTS)
 # Shared library
 add_library(tilelang SHARED $<TARGET_OBJECTS:tilelang_objs>)
 target_link_libraries(tilelang PUBLIC tvm_runtime)
+if(USE_METAL)
+  target_link_libraries(tilelang PUBLIC tvm)
+endif()
 
 # Static library
 add_library(tilelang_static STATIC $<TARGET_OBJECTS:tilelang_objs>)
```

README.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -13,6 +13,7 @@ Tile Language (**tile-lang**) is a concise domain-specific language designed to
 <img src=./images/MatmulExample.png />
 
 ## Latest News
+- 10/07/2025 🍎: Added Apple Metal Device support, check out [Pull Request #799](https://github.com/tile-ai/tilelang/pull/799) for details.
 - 09/29/2025 🎉: Thrilled to announce that AscendC and AscendNPU IR backends targeting Huawei Ascend chips are now supported!
   Check out the preview here:
   🔗 [link](https://github.com/tile-ai/tilelang-ascend).
```

examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py

Lines changed: 6 additions & 4 deletions

```diff
@@ -38,9 +38,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
     v += (bos * H + i_h) * V
     block_indices += (bos + i_t) * H * S + i_h * S
 
-    # if USE_BLOCK_COUNTS:
-    #     NS = tl.load(block_counts + (bos + i_t) * H + i_h)
-    # else:
     NS = S
 
     p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK),
@@ -452,7 +449,12 @@ def get_configs():
 
 
 @tilelang.autotune(configs=get_configs(),)
-@tilelang.jit
+@tilelang.jit(
+    pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+    })
 def tilelang_sparse_attention(batch,
                               heads,
                               seq_len,
```

examples/deepseek_nsa/example_tilelang_nsa_bwd.py

Lines changed: 6 additions & 3 deletions

```diff
@@ -17,9 +17,12 @@
 import tilelang
 
 
-@tilelang.jit(pass_configs={
-    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-})
+@tilelang.jit(
+    pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+    })
 def tilelang_kernel_fwd(
         batch,
         heads,
```

examples/deepseek_nsa/example_tilelang_nsa_fwd.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -9,8 +9,11 @@
 
 
 @tilelang.jit(
-    out_idx=[-1], pass_configs={
+    out_idx=[-1],
+    pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
     })
 def native_sparse_attention(batch,
                             heads,
```

examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py

Lines changed: 6 additions & 3 deletions

```diff
@@ -16,9 +16,12 @@
 from einops import rearrange
 
 
-@tilelang.jit(pass_configs={
-    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-})
+@tilelang.jit(
+    pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+    })
 def native_sparse_attention_varlen(batch,
                                    heads,
                                    c_seq_len,
```
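
The four NSA examples above now carry the same three pass-config entries. A shared constant could express that in one place; the sketch below is illustrative only — the `NSA_PASS_CONFIGS` name is not part of the commit — and uses only the `tilelang.jit` and `tilelang.PassConfigKey` names that appear in the diffs.

```python
import tilelang

# Hypothetical shared constant (not introduced by this commit): the three pass
# configs that every updated NSA example passes to @tilelang.jit.
NSA_PASS_CONFIGS = {
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}

# Each kernel would then be decorated as, e.g.:
#   @tilelang.jit(pass_configs=NSA_PASS_CONFIGS)
#   def native_sparse_attention(...): ...
# which is equivalent to the expanded decorators shown in the diffs above.
```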

examples/deepseek_v32/README.md

Lines changed: 56 additions & 1 deletion

````diff
@@ -6,6 +6,7 @@ deepseek_v32/
 ├── figures/                     # Figures and diagrams
 ├── inference/                   # Inference implementation folder
 ├── fp8_lighting_indexer.py      # FP8 lighting indexer
+├── sparse_mla_bwd.py            # Sparse MLA backward implementation
 ├── sparse_mla_fwd.py            # Sparse MLA forward implementation
 ├── sparse_mla_fwd_pipelined.py  # Pipelined implementation of sparse MLA forward pass
 ├── topk_selector.py             # Top-k selector implementation
@@ -21,7 +22,7 @@ The architecture diagram above highlights three key components (shown in green)
 
 1. **Lightning Indexer** (`fp8_lighting_indexer.py`) - Efficiently indexes and processes sparse attention patterns using FP8 precision
 2. **Top-k Selector** (`topk_selector.py`) - Selects the top-k most relevant tokens for sparse attention computation
-3. **Multi-Query Attention** (`sparse_mla_fwd.py` and `sparse_mla_fwd_pipelined.py`) - Core attention mechanism implementation with sparse MLA (Multi-Latent Attention) forward pass
+3. **Multi-Query Attention** (`sparse_mla_fwd.py`, `sparse_mla_fwd_pipelined.py`, and `sparse_mla_bwd.py`) - Core attention mechanism implementation with sparse MLA (Multi-Latent Attention) forward and backward passes
 
 ### Lightning Indexer
 
@@ -166,3 +167,57 @@
 ```
 
 Consumer threads wait on barriers and process buffers as they become ready. This manual orchestration hides memory latency behind compute, which is why it outperforms the simpler auto-pipelined version. The output dimension is also split in half so that the two consumer groups can work in parallel on different parts of the matmul.
+
+### Sparse MLA Backward
+
+The Sparse MLA backward kernel (`sparse_mla_bwd.py`) computes gradients with respect to queries (dQ) and key-values (dKV) for the sparse attention mechanism. Like the forward pass, it processes only the selected top-k indices, maintaining O(seq_len * topk) complexity.
+
+The backward pass consists of three main stages:
+
+**1. Preprocessing**: Computes delta values (row-wise dot products of output and output gradient):
+
+```python
+for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages):
+    T.copy(O[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], o)
+    T.copy(dO[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], do)
+    for i, j in T.Parallel(block_ND, block_ND):
+        acc[i, j] += o[i, j] * do[i, j]
+T.reduce_sum(acc, delta, 1)
+```
+
+**2. Main Backward Computation**: Computes gradients through sparse attention:
+
+```python
+# Sparse MLA backward: iterate over selected indices only
+for i_i in T.Pipelined(NI, num_stages=num_stages):
+    # Load KV data for selected indices
+    for bi_i, d_i in T.Parallel(BI, D):
+        KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BI + bi_i], bz, d_i]
+
+    # Recompute attention scores for backward
+    T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
+
+    # Apply softmax gradient: dP = P * (dP_raw - Delta)
+    for h_i, bi_i in T.Parallel(padded_H, BI):
+        acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale
+```
+
+The key gradient computations are:
+- **dQ = dP @ K** (query gradients)
+- **dK = dP^T @ Q** (key gradients)
+- **dV = P^T @ dO** (value gradients)
+
+**3. Atomic Sparse Updates**: Uses atomic operations for dKV accumulation:
+
+```python
+# Atomically update dKV at selected indices
+for bi_i, d_i in T.Parallel(BI // split_store, D // 4):
+    T.atomic_addx4(dKV[by, Indices[by, s_i, bz, i_i * BI + bi_i + s * (BI // split_store)], bz, d_i * 4],
+                   acc_dkv_shared[bi_i, d_i * 4])
+```
+
+**Performance**: The sparse MLA backward achieves excellent performance:
+- **H800 SXM**: ~100 TFlops
+- **H200 SXM**: ~115 TFlops
+
+The implementation efficiently handles the irregular memory access patterns inherent in sparse attention while maintaining high compute utilization through careful memory management and atomic update strategies. Note that this is a relatively naive implementation that requires further optimization.
````
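
For readers who want to check the gradient formulas quoted in the added README section against something simpler, here is a dense, single-head NumPy reference (not part of the commit; the function name and shapes are illustrative). It mirrors the three stages of the kernel: the Delta preprocessing (row-wise dot product of O and dO), the softmax gradient dP = P * (dP_raw - Delta) with sm_scale folded in, and the final dQ/dK/dV matmuls; the sparse kernel performs the same math restricted to the selected top-k key/value rows.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def dense_attention_bwd_reference(Q, K, V, dO, sm_scale):
    """Dense reference for the sparse MLA backward formulas (illustrative only).

    Shapes assumed for illustration: Q (Lq, D), K (Lk, D), V (Lk, D), dO (Lq, D).
    """
    P = softmax(Q @ K.T * sm_scale, axis=-1)        # recomputed attention probabilities
    O = P @ V                                       # forward output, needed for Delta

    dV = P.T @ dO                                   # dV = P^T @ dO
    dP_raw = dO @ V.T                               # gradient w.r.t. P before the softmax backward
    Delta = np.sum(O * dO, axis=-1, keepdims=True)  # stage 1: row-wise <O, dO>
    dS = P * (dP_raw - Delta) * sm_scale            # stage 2: dP = P * (dP_raw - Delta), scaled
    dQ = dS @ K                                     # dQ = dP @ K
    dK = dS.T @ Q                                   # dK = dP^T @ Q
    return dQ, dK, dV
```

In the quoted kernel, K and V come from a single latent KV tensor, so the dK and dV contributions are accumulated into one dKV buffer by the atomic updates of stage 3.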
