| 
312 | 312 |             allow_tf32=ALLOW_TF32,  | 
313 | 313 |         )  | 
314 | 314 | 
  | 
315 |  | -        {% if ki == k_tiles - 1 %}  | 
316 |  | -        # rematerialize rm and rn to save registers  | 
317 |  | -        rcm = rm + tl.arange(0, BLOCK_M)  | 
318 |  | -        rcn = rn + tl.arange(0, BLOCK_N)  | 
319 |  | -        idx_m = rcm[:, None]  | 
320 |  | -        idx_n = rcn[None, :]  | 
321 |  | -        mask = (idx_m < M) & (idx_n < N)  | 
322 |  | -
  | 
323 |  | -        # inductor generates a suffix  | 
324 |  | -        {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}}  | 
325 |  | -        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)  | 
326 |  | -        {% endif %}  | 
 | 315 | +        if ki == k_tiles - 1:  | 
 | 316 | +            # rematerialize rm and rn to save registers  | 
 | 317 | +            rcm = rm + tl.arange(0, BLOCK_M)  | 
 | 318 | +            rcn = rn + tl.arange(0, BLOCK_N)  | 
 | 319 | +            idx_m = rcm[:, None]  | 
 | 320 | +            idx_n = rcn[None, :]  | 
 | 321 | +            mask = (idx_m < M) & (idx_n < N)  | 
 | 322 | +
  | 
 | 323 | +            # inductor generates a suffix  | 
 | 324 | +            {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}}  | 
 | 325 | +            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)  | 
 | 326 | +
  | 
327 | 327 | """,  | 
328 | 328 | )  | 
329 | 329 | 
 
  | 
@@ -467,31 +467,30 @@ def apply_scaling(  | 
467 | 467 |         else:  | 
468 | 468 |             accumulator += tl.dot(a, b.T)  | 
469 | 469 | 
  | 
470 |  | -        {% if ki == k_tiles - 1 %}  | 
471 |  | -        # Apply inverse scaling  | 
472 |  | -        offs_cm = offs_am + tl.arange(0, BLOCK_M)  | 
473 |  | -        offs_cn = offs_bn + tl.arange(0, BLOCK_N)  | 
474 |  | -        # Apply scaling  | 
475 |  | -        accumulator = apply_scaling(  | 
476 |  | -            accumulator,  | 
477 |  | -            a_scale,  | 
478 |  | -            b_scale,  | 
479 |  | -            SCALING_ROWWISE,  | 
480 |  | -            offs_cm,  | 
481 |  | -            offs_cn,  | 
482 |  | -            M,  | 
483 |  | -            N,  | 
484 |  | -            stride_a_scale_m,  | 
485 |  | -            stride_b_scale_n,  | 
486 |  | -        )  | 
 | 470 | +        if ki == k_tiles - 1:  | 
 | 471 | +            # Apply inverse scaling  | 
 | 472 | +            offs_cm = offs_am + tl.arange(0, BLOCK_M)  | 
 | 473 | +            offs_cn = offs_bn + tl.arange(0, BLOCK_N)  | 
 | 474 | +            # Apply scaling  | 
 | 475 | +            accumulator = apply_scaling(  | 
 | 476 | +                accumulator,  | 
 | 477 | +                a_scale,  | 
 | 478 | +                b_scale,  | 
 | 479 | +                SCALING_ROWWISE,  | 
 | 480 | +                offs_cm,  | 
 | 481 | +                offs_cn,  | 
 | 482 | +                M,  | 
 | 483 | +                N,  | 
 | 484 | +                stride_a_scale_m,  | 
 | 485 | +                stride_b_scale_n,  | 
 | 486 | +            )  | 
487 | 487 | 
  | 
488 |  | -        idx_m = offs_cm[:, None]  | 
489 |  | -        idx_n = offs_cn[None, :]  | 
490 |  | -        mask = (idx_m < M) & (idx_n < N)  | 
491 |  | -        # inductor generates a suffix  | 
492 |  | -        {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}}  | 
493 |  | -        accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)  | 
494 |  | -        {% endif %}  | 
 | 488 | +            idx_m = offs_cm[:, None]  | 
 | 489 | +            idx_n = offs_cn[None, :]  | 
 | 490 | +            mask = (idx_m < M) & (idx_n < N)  | 
 | 491 | +            # inductor generates a suffix  | 
 | 492 | +            {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}}  | 
 | 493 | +            accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)  | 
495 | 494 | """  | 
496 | 495 | 
 
  | 
497 | 496 | 
 
  | 
 | 
0 commit comments