Skip to content

Commit aa0b109

Browse files
authored
[Language] Support atomic add with ret (#870)
* Add atomic operations for CUDA templates in new atomic.h file - Introduced atomic functions including AtomicMax, AtomicMin, AtomicAdd, and their return variants for various data types. - Implemented support for half, bfloat16, and float types with appropriate memory ordering. - Moved atomic-related utilities from common.h to the new atomic.h file for better organization. - Added Python bindings for atomic operations in tilelang, including atomic_max, atomic_min, atomic_add, and their vectorized counterparts. - Updated customize.py to utilize the new atomic functions, enhancing modularity and maintainability. * Refactor atomic operations in CUDA templates for improved readability - Reformatted atomic operation implementations in atomic.h for better code clarity. - Adjusted function signatures in tilelang's atomic.py to enhance readability by aligning parameters. - Cleaned up unnecessary whitespace and comments in customize.py to streamline the codebase. * Add thread storage synchronization configuration option - Introduced a new configuration option `tl.disable_thread_storage_sync` to control the automatic insertion of thread synchronization barriers in shared memory access. - Updated the `ThreadSync` pass to check this configuration and bypass synchronization if disabled. - Enhanced documentation in `builtin.h` and `pass_config.py` to clarify the purpose and usage of the new option. * Refactor thread storage sync configuration retrieval - Simplified the retrieval of the thread storage sync configuration in the `ThreadSync` pass by removing unnecessary intermediate variables. - Ensured that the inclusion of `builtin.h` is consistent by moving it to the appropriate location in the file. * test fix * Update atomic operations and tests for improved functionality - Updated atomic operations in CUDA templates to remove unnecessary address_of calls, enhancing performance and readability. - Refactored atomic operation signatures in tilelang's atomic.py to accept references instead of pointers. - Added new atomic operations and corresponding test cases for atomic add, max, min, and load/store functionalities in the testing suite. - Updated the TVM subproject to the latest commit for better compatibility. * Update attention sink examples to use 32 heads - Modified the `heads` parameter in both `example_gqa_sink_fwd_bhsd_wgmma_pipelined.py` and `example_mha_sink_fwd_bhsd_wgmma_pipelined.py` from 1 to 32 to enhance performance in attention mechanisms. - Ensured consistency across example scripts for improved usability and testing. * Refactor atomic add handling in vectorization - Simplified the extraction of buffer loads for atomic add operations by removing unnecessary address_of calls, improving code clarity and performance. - Updated the data type retrieval for vectorization size calculation to directly access the buffer load node, enhancing efficiency. * Add loop break functionality and enhance thread synchronization - Introduced a new `loop_break` function in `customize.py` to allow breaking out of loops, returning a call to the `tl.loop_break` intrinsic. - Updated the `sync_threads` function in `builtin.py` to accept optional parameters for `barrier_id` and `arrive_count`, improving its flexibility for thread synchronization. - Added necessary imports in `__init__.py` to include the new `loop_break` function for broader accessibility. * test fix
1 parent 1dfac2e commit aa0b109

17 files changed

+992
-359
lines changed

examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -366,9 +366,9 @@ def gen_inputs(B, H, Sq, Skv, D,
366366

367367
def main(
368368
batch: int = 1,
369-
heads: int = 64,
370-
seq_q: int = 4096,
371-
seq_kv: int = 4096,
369+
heads: int = 32,
370+
seq_q: int = 256,
371+
seq_kv: int = 256,
372372
dim: int = 128,
373373
groups: int = 8,
374374
window_size: int | None = None,

examples/attention_sink/example_mha_sink_fwd_bhsd.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,10 +229,10 @@ def gen_inputs(B, H, Sq, Skv, D) -> tuple[torch.Tensor, torch.Tensor, torch.Tens
229229
return query, key, value, sinks
230230

231231

232-
def main(batch: int = 8,
233-
heads: int = 32,
234-
seq_q: int = 4096,
235-
seq_kv: int = 4096,
232+
def main(batch: int = 1,
233+
heads: int = 1,
234+
seq_q: int = 256,
235+
seq_kv: int = 256,
236236
dim: int = 128,
237237
window_size: int | None = None,
238238
tune: bool = False):

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -354,10 +354,10 @@ def gen_inputs(B, H, Sq, Skv, D) -> tuple[torch.Tensor, torch.Tensor, torch.Tens
354354
return query, key, value, sinks
355355

356356

357-
def main(batch: int = 8,
357+
def main(batch: int = 1,
358358
heads: int = 32,
359-
seq_q: int = 4096,
360-
seq_kv: int = 4096,
359+
seq_q: int = 256,
360+
seq_kv: int = 256,
361361
dim: int = 128,
362362
window_size: int | None = None,
363363
tune: bool = False):

src/op/atomic_add.cc

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,10 +293,7 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
293293
if (dst_predicate.defined())
294294
dst_value = if_then_else(dst_predicate, dst_value, make_zero(dst->dtype));
295295

296-
Call address_of_value =
297-
tvm::tir::Call(DataType::Handle(), builtin::address_of(), {dst_value});
298-
299-
new_args.push_back(address_of_value);
296+
new_args.push_back(dst_value);
300297
new_args.push_back(src_value);
301298

302299
Call atomicadd_call =

src/op/builtin.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDebugMergeSharedMemoryAllocations, Bool);
2020
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableTMALower, Bool);
2121
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableSafeMemoryLegalize, Bool);
2222
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool);
23+
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableThreadStorageSync, Bool);
2324
TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
2425
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
2526
TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);

src/op/builtin.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,20 @@ static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
5555
static constexpr const char *kDisableDynamicTailSplit =
5656
"tl.disable_dynamic_tail_split";
5757

58+
/*!
59+
* \brief Whether to disable thread storage synchronization
60+
*
61+
* When enabled, disables the automatic insertion of thread synchronization
62+
* barriers (e.g., __syncthreads()) for shared memory access coordination.
63+
* This can be useful for performance optimization in cases where manual
64+
* synchronization is preferred or when synchronization is not needed.
65+
*
66+
* kDisableThreadStorageSync = "tl.disable_thread_storage_sync"
67+
*
68+
*/
69+
static constexpr const char *kDisableThreadStorageSync =
70+
"tl.disable_thread_storage_sync";
71+
5872
/*!
5973
* \brief The size of the vectorized dimension in buffer, designed by user
6074
*

src/tl_templates/cuda/atomic.h

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
#pragma once
2+
3+
#ifndef __CUDACC_RTC__
4+
#include <cuda_runtime.h>
5+
#endif
6+
7+
#include <cuda/atomic>
8+
#include <cutlass/numeric_types.h>
9+
10+
using cutlass::bfloat16_t;
11+
using cutlass::half_t;
12+
13+
#define TL_DEVICE __forceinline__ __device__
14+
15+
template <typename T> struct normalize_atomic_type {
16+
using type = T;
17+
};
18+
19+
template <> struct normalize_atomic_type<half_t> {
20+
using type = half;
21+
};
22+
23+
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
24+
template <> struct normalize_atomic_type<bfloat16_t> {
25+
using type = __nv_bfloat16;
26+
};
27+
#endif
28+
29+
template <typename T1, typename T2> TL_DEVICE T1 cuda_cast(T2 val) {
30+
return T1(val);
31+
}
32+
33+
template <> TL_DEVICE half cuda_cast<half, float>(float val) {
34+
return __float2half(val);
35+
}
36+
37+
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
38+
template <> TL_DEVICE __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) {
39+
return __float2bfloat16(val);
40+
}
41+
#endif
42+
43+
template <typename T1, typename T2>
44+
TL_DEVICE void AtomicMax(T1 &ref, T2 val,
45+
int memory_order = int(cuda::memory_order_relaxed)) {
46+
using NT1 = typename normalize_atomic_type<T1>::type;
47+
T1 *address = &ref;
48+
if constexpr (std::is_same_v<NT1, half> ||
49+
std::is_same_v<NT1, __nv_bfloat16>) {
50+
atomicMax(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
51+
} else {
52+
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
53+
aref.fetch_max(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
54+
}
55+
}
56+
57+
template <typename T1, typename T2>
58+
TL_DEVICE T1 AtomicMaxRet(T1 &ref, T2 val,
59+
int memory_order = int(cuda::memory_order_relaxed)) {
60+
using NT1 = typename normalize_atomic_type<T1>::type;
61+
T1 *address = &ref;
62+
if constexpr (std::is_same_v<NT1, half> ||
63+
std::is_same_v<NT1, __nv_bfloat16>) {
64+
return static_cast<T1>(
65+
atomicMax(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
66+
} else {
67+
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
68+
return static_cast<T1>(
69+
aref.fetch_max(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
70+
}
71+
}
72+
73+
template <typename T1, typename T2>
74+
TL_DEVICE void AtomicMin(T1 &ref, T2 val,
75+
int memory_order = int(cuda::memory_order_relaxed)) {
76+
using NT1 = typename normalize_atomic_type<T1>::type;
77+
T1 *address = &ref;
78+
if constexpr (std::is_same_v<NT1, half> ||
79+
std::is_same_v<NT1, __nv_bfloat16>) {
80+
atomicMin(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
81+
} else {
82+
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
83+
aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
84+
}
85+
}
86+
87+
template <typename T1, typename T2>
88+
TL_DEVICE T1 AtomicMinRet(T1 &ref, T2 val,
89+
int memory_order = int(cuda::memory_order_relaxed)) {
90+
using NT1 = typename normalize_atomic_type<T1>::type;
91+
T1 *address = &ref;
92+
if constexpr (std::is_same_v<NT1, half> ||
93+
std::is_same_v<NT1, __nv_bfloat16>) {
94+
return static_cast<T1>(
95+
atomicMin(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
96+
} else {
97+
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
98+
return static_cast<T1>(
99+
aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
100+
}
101+
}
102+
103+
template <typename T1, typename T2>
104+
TL_DEVICE void AtomicAdd(T1 &ref, T2 val,
105+
int memory_order = int(cuda::memory_order_relaxed)) {
106+
using NT1 = typename normalize_atomic_type<T1>::type;
107+
T1 *address = &ref;
108+
if constexpr (std::is_same_v<NT1, half> ||
109+
std::is_same_v<NT1, __nv_bfloat16>) {
110+
atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
111+
} else {
112+
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
113+
aref.fetch_add(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
114+
}
115+
}
116+
117+
template <typename T1, typename T2>
118+
TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val,
119+
int memory_order = int(cuda::memory_order_relaxed)) {
120+
using NT1 = typename normalize_atomic_type<T1>::type;
121+
T1 *address = &ref;
122+
if constexpr (std::is_same_v<NT1, half> ||
123+
std::is_same_v<NT1, __nv_bfloat16>) {
124+
return static_cast<T1>(
125+
atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
126+
} else {
127+
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
128+
return static_cast<T1>(
129+
aref.fetch_add(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
130+
}
131+
}
132+
133+
TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val) {
134+
atomicAdd(reinterpret_cast<half2 *>(ref),
135+
static_cast<half2>(*reinterpret_cast<half2 *>(val)));
136+
}
137+
138+
TL_DEVICE half2 AtomicAddx2Ret(half_t *ref, half_t *val) {
139+
return atomicAdd(reinterpret_cast<half2 *>(ref),
140+
static_cast<half2>(*reinterpret_cast<half2 *>(val)));
141+
}
142+
143+
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
144+
TL_DEVICE void AtomicAddx2(bfloat16_t *ref, bfloat16_t *val) {
145+
atomicAdd(
146+
reinterpret_cast<__nv_bfloat162 *>(ref),
147+
static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
148+
}
149+
150+
TL_DEVICE __nv_bfloat162 AtomicAddx2Ret(bfloat16_t *ref, bfloat16_t *val) {
151+
return atomicAdd(
152+
reinterpret_cast<__nv_bfloat162 *>(ref),
153+
static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
154+
}
155+
#endif
156+
157+
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
158+
TL_DEVICE void AtomicAddx2(float *ref, float *val) {
159+
atomicAdd(reinterpret_cast<float2 *>(ref),
160+
static_cast<float2>(*reinterpret_cast<float2 *>(val)));
161+
}
162+
163+
TL_DEVICE float2 AtomicAddx2Ret(float *ref, float *val) {
164+
return atomicAdd(reinterpret_cast<float2 *>(ref),
165+
static_cast<float2>(*reinterpret_cast<float2 *>(val)));
166+
}
167+
168+
TL_DEVICE void AtomicAddx4(float *ref, float *val) {
169+
atomicAdd(reinterpret_cast<float4 *>(ref),
170+
static_cast<float4>(*reinterpret_cast<float4 *>(val)));
171+
}
172+
173+
TL_DEVICE float4 AtomicAddx4Ret(float *ref, float *val) {
174+
return atomicAdd(reinterpret_cast<float4 *>(ref),
175+
static_cast<float4>(*reinterpret_cast<float4 *>(val)));
176+
}
177+
#endif
178+
179+
template <typename T> TL_DEVICE T AtomicLoad(T &ref, int memory_order) {
180+
cuda::atomic_ref<T, cuda::thread_scope_device> aref(ref);
181+
return aref.load(cuda::memory_order(memory_order));
182+
}
183+
184+
template <typename T1, typename T2>
185+
TL_DEVICE void AtomicStore(T1 &ref, T2 value, int memory_order) {
186+
using NT1 = typename normalize_atomic_type<T1>::type;
187+
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(ref);
188+
aref.store(cuda_cast<NT1>(value), cuda::memory_order(memory_order));
189+
}

0 commit comments

Comments
 (0)