@@ -97,16 +97,18 @@ __device__ __forceinline__ void compute_qk(const typename AttentionVariant::Para
     st.m = max(st.m, s[j]);
   }
 
-  float o_scale = math::ptx_exp2(m_prev - st.m);
-  st.d *= o_scale;
+  if constexpr (variant.use_softmax) {
+    float o_scale = math::ptx_exp2(m_prev - st.m);
+    st.d *= o_scale;
 #pragma unroll
-  for (uint32_t j = 0; j < tile_size; ++j) {
-    s[j] = math::ptx_exp2(s[j] - st.m);
-    st.d += s[j];
-  }
+    for (uint32_t j = 0; j < tile_size; ++j) {
+      s[j] = math::ptx_exp2(s[j] - st.m);
+      st.d += s[j];
+    }
 #pragma unroll
-  for (uint32_t i = 0; i < vec_size; ++i) {
-    st.o[i] = st.o[i] * o_scale;
+    for (uint32_t i = 0; i < vec_size; ++i) {
+      st.o[i] = st.o[i] * o_scale;
+    }
   }
 }
 
@@ -148,23 +150,38 @@ __device__ __forceinline__ void update_local_state(const T* smem, const float* s
  * \param smem The pointer to shared memory buffer for o
  * \param smem_md The pointer to shared memory buffer for m/d
  */
-template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz>
-__device__ __forceinline__ void sync_state(state_t<vec_size>& st, float* smem, float* smem_md) {
+template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz, typename AttentionVariant>
+__device__ __forceinline__ void sync_state(AttentionVariant variant, state_t<vec_size>& st,
+                                           float* smem, float* smem_md) {
   if constexpr (bdz > 1) {
     constexpr uint32_t head_dim = bdx * vec_size;
     auto block = cg::this_thread_block();
     uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z;
     st.o.store(smem + (tz * bdy + ty) * head_dim + tx * vec_size);
-    smem_md[(tz * bdy + ty) * 2] = st.m;
-    smem_md[(tz * bdy + ty) * 2 + 1] = st.d;
-    block.sync();
-    st.init();
+    if constexpr (variant.use_softmax) {
+      smem_md[(tz * bdy + ty) * 2] = st.m;
+      smem_md[(tz * bdy + ty) * 2 + 1] = st.d;
+      block.sync();
+      st.init();
 #pragma unroll
-    for (uint32_t j = 0; j < bdz; ++j) {
-      float mz = smem_md[(j * bdy + ty) * 2], dz = smem_md[(j * bdy + ty) * 2 + 1];
-      vec_t<float, vec_size> oz;
-      oz.load(smem + (j * bdy + ty) * head_dim + tx * vec_size);
-      st.merge(oz, mz, dz);
+      for (uint32_t j = 0; j < bdz; ++j) {
+        float mz = smem_md[(j * bdy + ty) * 2], dz = smem_md[(j * bdy + ty) * 2 + 1];
+        vec_t<float, vec_size> oz;
+        oz.load(smem + (j * bdy + ty) * head_dim + tx * vec_size);
+        st.merge(oz, mz, dz);
+      }
+    } else {
+      block.sync();
+      st.init();
+#pragma unroll
+      for (uint32_t j = 0; j < bdz; ++j) {
+        vec_t<float, vec_size> oz;
+        oz.load(smem + (j * bdy + ty) * head_dim + tx * vec_size);
+#pragma unroll
+        for (uint32_t i = 0; i < vec_size; ++i) {
+          st.o[i] += oz[i];
+        }
+      }
     }
   }
 }
@@ -338,8 +355,10 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__
   block.sync();
 
   // sync local state of all warps inside a threadblock
-  sync_state<vec_size, bdx, bdy, bdz>(st_local, reinterpret_cast<float*>(smem), smem_md);
-  st_local.normalize();
+  sync_state<vec_size, bdx, bdy, bdz>(variant, st_local, reinterpret_cast<float*>(smem), smem_md);
+  if constexpr (variant.use_softmax) {
+    st_local.normalize();
+  }
 
   st_local.o.cast_store(o + (kv_chunk_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
   if (lse != nullptr) {
@@ -557,8 +576,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
   block.sync();
 
   // sync local state of all warps inside a threadblock
-  sync_state<vec_size, bdx, bdy, bdz>(st, reinterpret_cast<float*>(smem), smem_md);
-  st.normalize();
+  sync_state<vec_size, bdx, bdy, bdz>(variant, st, reinterpret_cast<float*>(smem), smem_md);
+  if constexpr (variant.use_softmax) {
+    st.normalize();
+  }
 
   if (tz == 0) {
     st.o.cast_store(o + (bx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);