Skip to content

Commit bda46e3

Browse files
authored
Update global to Shared Memory operation
2 parents a843e46 + 1e58249 commit bda46e3

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

csrc/src/flash_attention_fwd_kernel.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -291,12 +291,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
291291
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
292292
Tensor tZeroHoldgZeroHold = gmem_thr_copy_ZeroHold.partition_S(gZeroHold);
293293
Tensor tZeroHoldsZeroHold = gmem_thr_copy_ZeroHold.partition_D(sZeroHold);
294-
auto tCausalMaskgCausalMask = has_causal_mask ?
295-
gmem_thr_copy_CausalMask.partition_S(gCausalMask) :
296-
make_tensor(static_cast<Element*>(nullptr), make_shape(Int<1>{}, Int<1>{}), make_stride(0,0));
297-
auto tCausalMasksCausalMask = has_causal_mask ?
298-
gmem_thr_copy_CausalMask.partition_D(sCausalMask) :
299-
make_tensor(static_cast<Element*>(nullptr), make_shape(Int<1>{}, Int<1>{}), make_stride(0,0));
294+
decltype(gmem_thr_copy_CausalMask.partition_S(gCausalMask)) tCausalMaskgCausalMask;
295+
decltype(gmem_thr_copy_CausalMask.partition_D(sCausalMask)) tCausalMasksCausalMask;
296+
if (has_causal_mask) {
297+
tCausalMaskgCausalMask = gmem_thr_copy_CausalMask.partition_S(gCausalMask);
298+
tCausalMasksCausalMask = gmem_thr_copy_CausalMask.partition_D(sCausalMask);
299+
}
300300

301301
// Matrix Multiply Accumulate
302302
typename Kernel_traits::TiledMma tiled_mma;

0 commit comments

Comments
 (0)