Commit b577710

bugfix: fix smem_size in FusedAddRMSNorm which is missed in PR #636 (#646)
Fix smem_size in FusedAddRMSNorm, which was missed in #636. Fixes issue #645.
1 parent 6819a0f commit b577710

1 file changed: +1 -1 lines changed

include/flashinfer/norm.cuh

Lines changed: 1 addition & 1 deletion
@@ -207,7 +207,7 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
   const uint32_t num_warps = ceil_div(block_size, 32);
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
-  const uint32_t smem_size = (num_warps + d) * sizeof(float);
+  const uint32_t smem_size = (ceil_div(num_warps, 4) * 4 + d) * sizeof(float);
   float weight_bias = 0.f;
   void* args[] = {&input, &residual, &weight, &d, &weight_bias, &eps};
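
The arithmetic behind the one-line change can be checked in isolation. The sketch below is not code from the repository: `ceil_div` here is a hypothetical stand-in for the helper used in `norm.cuh`, and `smem_size_old` / `smem_size_new` are illustrative names for the two expressions in the diff. It assumes, as the commit message suggests but the diff does not show, that the kernel side reserves the per-warp scratch region rounded up to a multiple of 4 floats, so the host-side allocation has to match.

```cpp
#include <cstdint>

// Hypothetical stand-in for the ceil_div helper used in norm.cuh.
constexpr uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

// Shared-memory size before the fix: one float per warp plus d floats.
constexpr uint32_t smem_size_old(uint32_t num_warps, uint32_t d) {
  return (num_warps + d) * sizeof(float);
}

// Shared-memory size after the fix: the per-warp region is padded up to
// a multiple of 4 floats (16 bytes) before the d floats are added.
constexpr uint32_t smem_size_new(uint32_t num_warps, uint32_t d) {
  return (ceil_div(num_warps, 4) * 4 + d) * sizeof(float);
}

// With 3 warps the old formula is 4 bytes short of the padded layout.
static_assert(smem_size_old(3, 4096) == 16396, "old size, 3 warps");
static_assert(smem_size_new(3, 4096) == 16400, "new size, 3 warps");

// When num_warps is already a multiple of 4 the two formulas agree,
// which may be why the mismatch was easy to miss.
static_assert(smem_size_old(32, 4096) == smem_size_new(32, 4096),
              "sizes agree for 32 warps");
```

Under that assumption, the old formula under-allocates by up to 3 floats (12 bytes) whenever `num_warps` is not a multiple of 4, and the two formulas coincide otherwise.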
