@@ -154,15 +154,31 @@ void main() {
154154        }
155155
156156        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
157-             tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
158-             tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
159-             tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
157+             bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
160158
161-             coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
159+             if (nem1_bounds_check) {
160+                 tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
161+                 tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
162+                 tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
162163
163-             coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br , Br, j *  Bc, Bc)) ;
164+                 coopmat<float16_t, gl_ScopeWorkgroup , Br, Bc, gl_MatrixUseAccumulator> mv ;
164165
165-             S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
166+                 coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
167+ 
168+                 S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
169+             } else {
170+                 tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
171+                 // Don't clamp against nem1 when GQA is enabled
172+                 uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
173+                 tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
174+                 tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
175+ 
176+                 coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
177+ 
178+                 coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
179+ 
180+                 S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
181+             }
166182        }
167183
168184        // Clear padding elements to -inf, so they don't contribute to rowmax
0 commit comments