
Commit 8fbe1b3

[Refactor] Add kernel selection option for GEMM v1 in environment settings (#1200)
* Add kernel selection option for GEMM v1 in environment settings
  - Introduced the `TILELANG_USE_GEMM_V1` environment variable to control the selection of the GEMM version.
  - Added a `use_gemm_v1` method to the `Environment` class to determine whether GEMM v1 should be used, based on the environment variable.
  - Updated the GEMM function assignment to default to v2, allowing v1 to be forced via the new environment variable.
* bug fix
* Add kernel selection option for GEMM in environment settings
  - Introduced the `TILELANG_USE_GEMM_V1` environment variable to let users select between the GEMM v1 and v2 implementations.
  - Updated the `gemm` function to default to v2 but switch to v1 if the environment variable is set to a truthy value.
  - Added a `use_gemm_v1` method to the `Environment` class to facilitate this selection.
* Refactor GEMM macro generator to use BufferRegion instead of Buffer
  - Updated the `wgmma` and `wgmma_rs` methods in `TensorCoreIntrinEmitter` to accept `BufferRegion` parameters instead of `Buffer`.
  - Adjusted related calls in `GemmWGMMA` to ensure compatibility with the new parameter types.
  - Simplified buffer access logic for better clarity and maintainability.
* Refactor GEMM functions to utilize BufferRegion for improved memory handling
  - Updated `run_gemm`, `run_gemm_rs`, `run_gemm_sr`, and `run_gemm_rr` to set `num_stages` based on block dimensions, improving performance for larger matrices.
  - Simplified calls to GEMM functions by removing redundant parameters and ensuring compatibility with BufferRegion.
  - Introduced utility functions for converting between Buffer, BufferLoad, and BufferRegion, improving code clarity and maintainability.
  - Enhanced error handling for full-region checks in GEMM operations to ensure correct memory access.
* Refactor GEMM code for improved readability and consistency
  - Cleaned up formatting and spacing in GEMM-related files.
  - Standardized comments and code structure across GEMM functions and macros.
  - Enhanced error messages for clarity in buffer-region checks.
  - Removed redundant lines and improved overall maintainability.
* Update GEMM correctness evaluation and macro generator for improved functionality
  - Modified `N_VALUES` in `correctness_evaluation_sm70.py` to include only the sizes relevant for tests.
  - Updated the test function call in `correctness_evaluation.py` to use `test_gemm_false_true` for better accuracy in testing.
  - Refactored buffer handling in `mma_sm70_macro_generator.py` to improve clarity and consistency in shared-buffer access.
  - Enhanced `gemm_mma_sm70.py` to enforce full-region checks for input and output buffers, improving correctness of GEMM operations.
* Refactor GEMM and intrinsic files for improved clarity and functionality
  - Removed the unused variable `A_stride_last` in `mma_sm70_macro_generator.py`.
  - Adjusted function signature formatting in `swizzle.py` for better readability.
  - Restored the return of `GemmWGMMA` in `__init__.py` for correct GEMM instantiation.
  - Removed the unused variable `B_buf` in `gemm_mma_sm70.py`.
  - Improved function signature formatting in `language.py` for consistency.
* Enhance GEMM and MMA functionality for FP64 support
  - Refactored `GemmNode` to streamline the decision-making process for GEMM instruction selection.
  - Added support for FP64 inputs in the MMA dispatcher, enabling new tensor operations.
  - Introduced a new layout function for FP64 in `mma_layout.py` to facilitate shared-memory storage.
  - Updated `TensorCoreIntrinEmitter` to handle FP64 data types, including adjustments to micro-tile dimensions and loading mechanisms.
  - Enhanced utility functions to accommodate FP64 index mapping for shared-memory operations.
* lint fix
* Refactor GEMM correctness evaluation and shared memory alignment handling
  - Reverted the GEMM function call in `correctness_evaluation.py` to the original implementation for consistency.
  - Added a helper function in `merge_shared_memory_allocations.cc` to streamline the marking of shared variables under alignment scope.
  - Enhanced the `VisitExpr_` methods to ensure proper handling of shared-memory alignment for `BufferLoadNode` and `VarNode` types.
  - Cleaned up commented-out test code in `correctness_evaluation.py` for better readability.
* Enhance GEMM and MMA implementations with region-based memory handling
  - Updated GEMM and MMA classes to use BufferRegion for input and output buffers, improving memory management and supporting strided GEMM operations.
  - Added checks to ensure full-region compliance for input buffers, enhancing correctness of matrix multiplication.
  - Implemented clear-accumulation functionality to reset output buffers before accumulation, ensuring accurate GEMM results.
* Refactor test_tilelang_example_deepseek_v32.py to improve import structure and function calls
  - Updated import statements to reference modules directly instead of individual test functions, enhancing clarity.
  - Modified function calls to use the new module structure for better organization and maintainability.
* Enhance OnArrayDeclaration method to handle repeated buffer declarations
  - Updated the `OnArrayDeclaration` method to merge metadata for buffers that may appear in multiple Allocate statements, improving robustness against upstream transformations.
  - Added logic to prefer concrete element data types and to record extents when previously unknown.
* Add abbreviation for bfloat16 data type in mfma_macro_generator.py
  - Introduced the abbreviation "bf16" for the bfloat16 data type in `mfma_macro_generator.py`, improving clarity and consistency of data-type representation.
* Refactor CodeGenTileLangHIP to enhance dtype handling and mfma call generation
  - Introduced a mapping function to normalize input data types to their corresponding scalar types, improving compatibility with MfmaTraits.
  - Updated the mfma call generation to use the new mapping, streamlining the code.
  - Removed the outdated dtype mapping in favor of a more flexible approach that supports additional data types such as FP8.
* lint fix
* Enhance backend configuration in CMakeLists.txt and improve dtype handling in CodeGenTileLangHIP
  - Introduced a macro to define backend options for CUDA, ROCm, and Metal, allowing user overrides and caching of settings.
  - Updated the logic to track user-selected backends and conditionally enable defaults based on environment variables.
  - Refactored dtype handling in CodeGenTileLangHIP to streamline mfma call generation.
  - Added bfloat16 support in `mfma_macro_generator.py` for consistent data-type representation.
* Update bfloat16 handling in CodeGenTileLangHIP and mfma_macro_generator.py
  - Changed the representation of bfloat16 in CodeGenTileLangHIP from "bfloat16x4" to "bfloat16x4_vec" for improved clarity.
  - Adjusted `mfma_suffix` generation in `mfma_macro_generator.py` to drop the underscore before "bf16", matching HIP intrinsic naming.
* Change logging level from WARNING to DLOG in LegalizeNegativeIndex for non-negative index checks, reducing log verbosity.
* Refactor attention sink examples to simplify index calculations
  - Updated index handling in `example_gqa_sink_bwd_bhsd.py` and `example_mha_sink_bwd_bhsd.py` to eliminate unnecessary local allocations and streamline the logic for determining start and end indices.
  - Improved readability by computing index bounds for pipelined loops directly instead of through local variables.
* Refactor attention sink examples to streamline index calculations
  - Simplified index handling in `example_gqa_sink_bwd_bhsd.py`, `example_gqa_sink_fwd_bhsd_wgmma_pipelined.py`, `example_mha_sink_bwd_bhsd.py`, `example_mha_sink_fwd_bhsd_wgmma_pipelined.py`, and `example_mha_sink_fwd_bhsd.py` by removing unnecessary local allocations for start and end indices.
  - Enhanced readability by directly calculating index bounds for pipelined loops.
* lint fix
* bugfix
* Refactor reduce operation handling in CUDA and Python
  - Removed outdated shared-memory reduction logic from `reduce.cc`.
  - Introduced fragment allocation and improved buffer handling in `reduce.py` to support shared and fragment scopes.
  - Updated the CUDA header to define a wider accumulator type for better numerical accuracy.
  - Enhanced error handling for buffer scope validation in the reduction process.
* Fix ReduceOpNode to correctly compute AbsMax by using absolute values of inputs
* Enhance unit loop handling by refining annotation checks
  - Updated the condition for identifying effectively empty annotations in unit loops to include cases where only the `pragma_unroll_explicit` hint is present.
  - Introduced a new method, `IsEffectivelyEmptyAnnotation`, to encapsulate this logic, improving clarity and maintainability.
* clean code
1 parent 2b1f599 commit 8fbe1b3
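
The first entry in the message above describes the selection mechanism only in prose. Below is a minimal sketch of how an environment-driven switch like this can be wired up, assuming only the names the commit message gives (`TILELANG_USE_GEMM_V1` and `Environment.use_gemm_v1`); the stub implementations, truthy-value handling, and module layout are illustrative, not tilelang's actual code:

```python
import os


def _gemm_v1(*args, **kwargs):
    ...  # stand-in for the legacy v1 implementation


def _gemm_v2(*args, **kwargs):
    ...  # stand-in for the default v2 implementation


class Environment:
    """Hypothetical sketch of the environment-variable plumbing."""

    @staticmethod
    def use_gemm_v1() -> bool:
        # Accept common truthy spellings; unset or falsy selects v2.
        value = os.environ.get("TILELANG_USE_GEMM_V1", "0").strip().lower()
        return value in ("1", "true", "on", "yes")


# Default to v2; v1 is forced by exporting TILELANG_USE_GEMM_V1=1.
gemm = _gemm_v1 if Environment.use_gemm_v1() else _gemm_v2
```

Binding `gemm` once at import time, as sketched here, means the variable must be set before the module is imported; re-checking inside each call would trade a small per-call overhead for runtime switchability.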


49 files changed: +2029 additions, -1319 deletions

.gitignore

Lines changed: 3 additions & 0 deletions

```diff
@@ -105,3 +105,6 @@ cmake-build-*/
 
 # Git version for sdist
 .git_commit.txt
+
+# pre-commit cache
+.pre-commit-cache/*
```

CMakeLists.txt

Lines changed: 81 additions & 20 deletions

```diff
@@ -65,9 +65,50 @@ else()
 endif()
 
 # Configs
-set(USE_CUDA OFF)
-set(USE_ROCM OFF)
-set(USE_METAL OFF)
+set(TILELANG_BACKENDS CUDA ROCM METAL)
+
+set(TILELANG_BACKEND_DOC_CUDA "Enable CUDA backend (ON/OFF/or CUDA SDK path)")
+set(TILELANG_BACKEND_DOC_ROCM "Enable ROCm backend (ON/OFF/or ROCm SDK path)")
+set(TILELANG_BACKEND_DOC_METAL "Enable Metal backend")
+
+# TVM's config.cmake redefines USE_* options later, so we cache the user's choice
+# (including explicit -DUSE_XXX arguments) before we include TVM and restore it
+# afterwards.
+
+macro(tilelang_define_backend_option BACKEND)
+  set(_backend_var "USE_${BACKEND}")
+  set(_doc "${TILELANG_BACKEND_DOC_${BACKEND}}")
+  set(_user_override_var "TILELANG_USER_OVERRIDE_${_backend_var}")
+
+  set(_user_override OFF)
+  if(DEFINED ${_user_override_var})
+    set(_user_override "${${_user_override_var}}")
+  endif()
+
+  if(DEFINED CACHE{${_backend_var}})
+    get_property(_cache_type CACHE ${_backend_var} PROPERTY TYPE)
+    if(_cache_type STREQUAL "UNINITIALIZED")
+      set(_user_override ON)
+    endif()
+  endif()
+
+  set(_default OFF)
+  if(DEFINED ${_backend_var})
+    set(_default "${${_backend_var}}")
+  endif()
+
+  option(${_backend_var} "${_doc}" "${_default}")
+  # Remember if the user explicitly set this option so that later logic
+  # won't auto-toggle backends they configured on the command line.
+  set(${_user_override_var} ${_user_override} CACHE INTERNAL
+      "User explicitly set ${_backend_var} during configuration" FORCE)
+  set(TILELANG_OPTION_${_backend_var} "${${_backend_var}}")
+endmacro()
+
+foreach(BACKEND IN LISTS TILELANG_BACKENDS)
+  tilelang_define_backend_option(${BACKEND})
+endforeach()
+
 set(PREBUILD_CYTHON ON)
 # Configs end
 
@@ -78,6 +119,14 @@ if(EXISTS ${TVM_SOURCE}/cmake/config.cmake)
 else()
   message(FATAL_ERROR "Nor tvm provided or submodule checkout-ed.")
 endif()
+# Re-apply TileLang's preferred backend settings after TVM's config may have
+# overridden the USE_* cache entries.
+foreach(BACKEND IN LISTS TILELANG_BACKENDS)
+  set(_backend_var "USE_${BACKEND}")
+  set(_doc "${TILELANG_BACKEND_DOC_${BACKEND}}")
+  set(${_backend_var} ${TILELANG_OPTION_${_backend_var}} CACHE STRING "${_doc}" FORCE)
+  set(${_backend_var} ${TILELANG_OPTION_${_backend_var}})
+endforeach()
 
 # Include directories for TileLang
 set(TILE_LANG_INCLUDES ${TVM_INCLUDES})
@@ -95,23 +144,35 @@ file(GLOB TILE_LANG_SRCS
   src/target/intrin_rule*.cc
 )
 
-# Backend-specific checks and configs
-if($ENV{USE_METAL})
-  set(USE_METAL ON)
-elseif(APPLE)
-  message(STATUS "Enable Metal support by default.")
-  set(USE_METAL ON)
-elseif($ENV{USE_ROCM})
-  set(USE_ROCM ON)
-else()
-  if($ENV{USE_CUDA})
-    set(USE_CUDA ON)
-  elseif(DEFINED ENV{USE_CUDA} AND NOT $ENV{USE_CUDA})
-    # Build CPU-only when we explicitly disable CUDA
-    set(USE_CUDA OFF)
+# Track if the user explicitly selected a backend via cache options.
+set(TILELANG_BACKEND_USER_SELECTED OFF)
+foreach(BACKEND IN LISTS TILELANG_BACKENDS)
+  set(_backend_var "USE_${BACKEND}")
+  set(_override_var "TILELANG_USER_OVERRIDE_${_backend_var}")
+  if(${_backend_var} OR ${_override_var})
+    set(TILELANG_BACKEND_USER_SELECTED ON)
+  endif()
+endforeach()
+
+# Only auto-select a backend when the user didn't specify one explicitly.
+if(NOT TILELANG_BACKEND_USER_SELECTED)
+  if($ENV{USE_METAL})
+    set(USE_METAL ON)
+  elseif(APPLE)
+    message(STATUS "Enable Metal support by default.")
+    set(USE_METAL ON)
+  elseif($ENV{USE_ROCM})
+    set(USE_ROCM ON)
   else()
-    message(STATUS "Enable CUDA support by default.")
-    set(USE_CUDA ON)
+    if($ENV{USE_CUDA})
+      set(USE_CUDA ON)
+    elseif(DEFINED ENV{USE_CUDA} AND NOT $ENV{USE_CUDA})
+      # Build CPU-only when we explicitly disable CUDA
+      set(USE_CUDA OFF)
+    else()
+      message(STATUS "Enable CUDA support by default.")
+      set(USE_CUDA ON)
+    endif()
   endif()
 endif()
 
@@ -125,7 +186,7 @@ if(USE_METAL)
 elseif(USE_ROCM)
   set(CMAKE_HIP_STANDARD 17)
   include(${TVM_SOURCE}/cmake/utils/FindROCM.cmake)
-  find_rocm($ENV{USE_ROCM})
+  find_rocm(${USE_ROCM})
   add_compile_definitions(__HIP_PLATFORM_AMD__ __HIP_PLATFORM_HCC__=1)
 
   file(GLOB TILE_LANG_HIP_SRCS
```

examples/attention_sink/example_gqa_sink_bwd_bhsd.py

Lines changed: 8 additions & 14 deletions

```diff
@@ -81,13 +81,10 @@ def flash_fwd(
                 sinks[i] = Sinks[by]
 
             end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0,
+                          (bx * block_M - window_size) // block_N) if window_size is not None else 0
 
-            for k in T.Pipelined(start[0], end, num_stages=num_stages):
+            for k in T.Pipelined(start, end, num_stages=num_stages):
                 T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared)
                 for i, j in T.Parallel(block_M, block_N):
                     q_idx = bx * block_M + i
@@ -266,14 +263,11 @@ def flash_bwd(
             T.clear(dk)
 
             loop_st = T.floordiv(by * block_M, block_N)
-            loop_ed = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                loop_ed[0] = T.min(
-                    T.ceildiv((by + 1) * block_M + window_size, block_N),
-                    T.ceildiv(seq_len, block_N))
-            else:
-                loop_ed[0] = T.ceildiv(seq_len, block_N)
-            for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
+            loop_ed = T.min(
+                T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(
+                    seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)
+
+            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                 T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
                 T.clear(qkT)
                 T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
```
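
This diff and the remaining attention-sink diffs below all apply the same pattern: a one-element `T.alloc_local` buffer plus an `if`/`else` collapses into a Python conditional expression, which is valid because `window_size` is a plain Python value (`int` or `None`) fixed when the kernel is traced, not a device-side quantity. A plain-Python stand-in for the start-bound computation, with hypothetical example values in place of tilelang expressions:

```python
def loop_start(bx: int, block_M: int, block_N: int, window_size: int | None = None) -> int:
    # One expression, evaluated at kernel-trace time, replaces the
    # local buffer and the branch from the old code.
    return max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0


# Unwindowed attention starts at block 0; a 128-wide window over
# 64-row M-blocks and 32-column N-blocks starts at block 2 for bx=3.
assert loop_start(bx=3, block_M=64, block_N=32) == 0
assert loop_start(bx=3, block_M=64, block_N=32, window_size=128) == 2
```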

examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 3 additions & 6 deletions

```diff
@@ -172,14 +172,11 @@ def main(
             end = T.min(
                 T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
 
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0, (bx * block_M + past_len - window_size) //
+                          block_N) if window_size is not None else 0
 
             for k in T.Pipelined(
-                    start[0],
+                    start,
                     end,
                     num_stages=num_stages,
                     order=[-1, 0, 3, 1, -1, 2],
```

examples/attention_sink/example_mha_sink_bwd_bhsd.py

Lines changed: 7 additions & 14 deletions

```diff
@@ -78,13 +78,10 @@ def flash_fwd(
                 sinks[i] = Sinks[by]
 
             end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0,
+                          (bx * block_M - window_size) // block_N) if window_size is not None else 0
 
-            for k in T.Pipelined(start[0], end, num_stages=num_stages):
+            for k in T.Pipelined(start, end, num_stages=num_stages):
                 T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
                 for i, j in T.Parallel(block_M, block_N):
                     q_idx = bx * block_M + i
@@ -267,14 +264,10 @@ def flash_bwd(
             T.clear(dk)
 
             loop_st = T.floordiv(by * block_M, block_N)
-            loop_ed = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                loop_ed[0] = T.min(
-                    T.ceildiv((by + 1) * block_M + window_size, block_N),
-                    T.ceildiv(seq_len, block_N))
-            else:
-                loop_ed[0] = T.ceildiv(seq_len, block_N)
-            for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
+            loop_ed = T.min(
+                T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(
+                    seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)
+            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                 T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
                 T.clear(qkT)
                 T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
```

examples/attention_sink/example_mha_sink_fwd_bhsd.py

Lines changed: 3 additions & 6 deletions

```diff
@@ -162,13 +162,10 @@ def main(
             end = T.min(
                 T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
 
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0, (bx * block_M + past_len - window_size) //
+                          block_N) if window_size is not None else 0
 
-            for k in T.Pipelined(start[0], end, num_stages=num_stages):
+            for k in T.Pipelined(start, end, num_stages=num_stages):
                 MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                 Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
                         logsum)
```

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 3 additions & 6 deletions

```diff
@@ -165,14 +165,11 @@ def main(
             end = T.min(
                 T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
 
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0, (bx * block_M + past_len - window_size) //
+                          block_N) if window_size is not None else 0
 
             for k in T.Pipelined(
-                    start[0],
+                    start,
                     end,
                     num_stages=num_stages,
                     order=[-1, 0, 3, 1, -1, 2],
```

examples/deepseek_v32/test_tilelang_example_deepseek_v32.py

Lines changed: 10 additions & 10 deletions

```diff
@@ -1,41 +1,41 @@
 # ruff: noqa
 import tilelang.testing
 
-from topk_selector import test_topk_selector
-from fp8_lighting_indexer import test_fp8_lighting_indexer
-from sparse_mla_fwd import test_sparse_mla_fwd
-from sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined
-from sparse_mla_bwd import test_sparse_mla_bwd
+import topk_selector
+import fp8_lighting_indexer
+import sparse_mla_fwd
+import sparse_mla_fwd_pipelined
+import sparse_mla_bwd
 
 
 def test_example_topk_selector():
-    test_topk_selector()
+    topk_selector.test_topk_selector()
 
 
 def test_example_fp8_lighting_indexer():
-    test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1)
+    fp8_lighting_indexer.test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1)
 
 
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_sparse_mla_fwd():
     # small shapes for testing
-    test_sparse_mla_fwd(
+    sparse_mla_fwd.test_sparse_mla_fwd(
         S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
 
 
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_sparse_mla_fwd_pipelined():
     # small shapes for testing
-    test_sparse_mla_fwd_pipelined(
+    sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined(
         S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
 
 
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_sparse_mla_bwd():
-    test_sparse_mla_bwd(
+    sparse_mla_bwd.test_sparse_mla_bwd(
         S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
 
 
```
examples/linear_attention/example_linear_attn_fwd.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -80,7 +80,6 @@ def fused_chunk_linear_attn_fwd(
                 T.atomic_add(
                     O[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
                     o_shared)
-                #TODO: consider using vectorized atomic add or tma reduce for sm90
 
             # Output final state
             T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV])
@@ -91,6 +90,7 @@
 def tl_fused_chunk_fwd(q, k, v):
     B, S, H, D = q.shape
     kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
+    print(kernel.get_kernel_source())
     o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32)
     h = kernel(q, k, v, o)
     return o, h
```

examples/linear_attention/example_retention_fwd.py

Lines changed: 0 additions & 7 deletions

```diff
@@ -51,13 +51,6 @@ def chunk_retention_fwd(
             o = T.alloc_fragment([chunk_size, BV], accum_dtype)
             T.clear(h)
 
-            T.annotate_layout({
-                q: tl.layout.make_swizzled_layout(q),
-                k: tl.layout.make_swizzled_layout(k),
-                v: tl.layout.make_swizzled_layout(v),
-                h_shared: tl.layout.make_swizzled_layout(h_shared),
-                s_shared: tl.layout.make_swizzled_layout(s_shared),
-            })
             T.use_swizzle(10)
 
             for i in T.Pipelined(0, NT):
```
