@@ -338,8 +338,6 @@ def num_tiles_executed(self) -> Int32:
     - Type convert C matrix to output type.
     - Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations,
       or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations.
-    - Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor:
-      e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0))

    SM100 tcgen05.mma.kind.block_scale instructions operate as follows:
    - Read matrix A from SMEM
@@ -635,9 +633,9 @@ def __call__(
         sfb_tensor: cute.Tensor,
         c_tensor: cute.Tensor,
         masked_m_tensor: cute.Tensor,
+        alpha_tensor: Optional[cute.Tensor],
         max_active_clusters: cutlass.Constexpr,
         stream: cuda.CUstream,
-        epilogue_op: cutlass.Constexpr = lambda x: x,
     ):
         """Execute the GEMM operation in steps:
         - Setup static attributes before smem/grid/tma computation
@@ -662,8 +660,8 @@ def __call__(
         :type max_active_clusters: cutlass.Constexpr
         :param stream: CUDA stream for asynchronous execution
         :type stream: cuda.CUstream
-        :param epilogue_op: Optional elementwise lambda function to apply to the output tensor
-        :type epilogue_op: cutlass.Constexpr
+        :param alpha_tensor: Optional 1D tensor of shape (l,) containing per-batch scaling factors.
+        :type alpha_tensor: cute.Tensor
         :raises TypeError: If input data types are incompatible with the MMA instruction.
         """
         # Setup static attributes before smem/grid/tma computation
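For reference, the semantics of the new `alpha_tensor` parameter can be checked against a small host-side model (a hedged PyTorch sketch of the intended math only; `reference_masked_gemm` is an illustrative name, and the block-scale factors and low-precision dtypes handled by the real kernel are deliberately omitted):

import torch

def reference_masked_gemm(a, b, masked_m, alpha=None):
    # a: (l, m, k), b: (l, n, k), masked_m: (l,) valid-row counts, alpha: optional (l,)
    l, m, _ = a.shape
    n = b.shape[1]
    out = torch.zeros(l, m, n, dtype=torch.float32, device=a.device)
    for i in range(l):
        rows = int(masked_m[i])                       # only the first masked_m[i] rows are computed
        acc = a[i, :rows].float() @ b[i].float().t()  # NT GEMM: C = A @ B^T
        if alpha is not None:
            acc = alpha[i].float() * acc              # per-batch scaling with alpha of shape (l,)
        out[i, :rows] = acc
    return out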
@@ -856,7 +854,6 @@ class SharedStorage:

         # Launch the kernel synchronously
         self.kernel(
-            masked_m_tensor,  # todo(Yingyi): cleanup?
             tiled_mma,
             tiled_mma_sfb,
             tma_atom_a,
@@ -869,6 +866,7 @@ class SharedStorage:
             tma_tensor_sfb,
             tma_atom_c,
             tma_tensor_c,
+            alpha_tensor,
             self.cluster_layout_vmnk,
             self.cluster_layout_sfb_vmnk,
             self.a_smem_layout_staged,
@@ -878,7 +876,6 @@ class SharedStorage:
             self.c_smem_layout_staged,
             self.epi_tile,
             self.tile_sched_params,
-            epilogue_op,
         ).launch(
             grid=grid,
             block=[self.threads_per_cta, 1, 1],
@@ -892,7 +889,6 @@ class SharedStorage:
     @cute.kernel
     def kernel(
         self,
-        masked_m: cute.Tensor,  # todo(Yingyi): cleanup?
         tiled_mma: cute.TiledMma,
         tiled_mma_sfb: cute.TiledMma,
         tma_atom_a: cute.CopyAtom,
@@ -905,6 +901,7 @@ def kernel(
         mSFB_nkl: cute.Tensor,
         tma_atom_c: Optional[cute.CopyAtom],
         mC_mnl: cute.Tensor,
+        alpha: Optional[cute.Tensor],
         cluster_layout_vmnk: cute.Layout,
         cluster_layout_sfb_vmnk: cute.Layout,
         a_smem_layout_staged: cute.ComposedLayout,
@@ -914,7 +911,6 @@ def kernel(
         c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
         epi_tile: cute.Tile,
         tile_sched_params: MaskedSchedulerParams,
-        epilogue_op: cutlass.Constexpr,
     ):
         """
         GPU device kernel performing the Persistent batched GEMM computation.
@@ -1616,7 +1612,10 @@ def kernel(
                 # Convert to C type
                 #
                 acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
-                acc_vec = epilogue_op(acc_vec.to(self.c_dtype))
+                if cutlass.const_expr(alpha is not None):
+                    acc_vec = acc_vec * alpha[work_tile.tile_idx[2]]
+
+                acc_vec = acc_vec.to(self.c_dtype)
                 tRS_rC.store(acc_vec)

                 #
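The key change above is the ordering: the accumulator is scaled in its accumulator precision first and only then converted to the output dtype, and because the check is wrapped in `cutlass.const_expr`, the branch is resolved at compile time, so a kernel built without an alpha tensor contains no extra multiply. A host-side sketch of the same ordering (hedged; `scale_and_convert` and its arguments are illustrative and not part of the kernel API):

import torch

def scale_and_convert(acc, alpha, batch_idx, c_dtype=torch.bfloat16):
    # acc: accumulator tile (typically fp32); alpha: optional (l,) tensor; batch_idx: batch this tile belongs to
    if alpha is not None:
        acc = acc * alpha[batch_idx]  # scale before narrowing, mirroring acc_vec * alpha[work_tile.tile_idx[2]]
    return acc.to(c_dtype)            # then convert to the C/output dtype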
@@ -2447,6 +2446,7 @@ def run_cute_ptr(
         sfb_ptr: cute.Pointer,
         c_ptr: cute.Pointer,
         masked_mptr: cute.Pointer,
+        alpha_ptr: cute.Pointer,
         current_stream: cuda.CUstream,
     ):
         a_tensor = cute.make_tensor(
@@ -2522,6 +2522,16 @@ def ceil_div(a, b):
             layout=cute.make_ordered_layout((self._l,), order=(0,)),
         )

+        # Use const_expr for compile-time conditional
+        alpha_tensor = (
+            cute.make_tensor(
+                alpha_ptr,
+                layout=cute.make_ordered_layout((self._l,), order=(0,)),
+            )
+            if cutlass.const_expr(alpha_ptr is not None)
+            else None
+        )
+
         Sm100BlockScaledPersistentDenseGemmKernel(
             sf_vec_size=self._sf_vec_size,
             mma_tiler_mn=self._mma_tiler_mn,
@@ -2533,6 +2543,7 @@ def ceil_div(a, b):
             sfb_tensor,
             c_tensor,
             masked_m_tensor,
+            alpha_tensor,
             self._max_active_clusters,
             current_stream,
         )
@@ -2555,6 +2566,8 @@ def run(
         sfb_tensor_gpu: torch.Tensor,
         masked_m_tensor_gpu: torch.Tensor,
         c_tensor_gpu: Optional[torch.Tensor] = None,
+        alpha_dtype: Optional[torch.dtype] = None,
+        alpha_tensor_gpu: Optional[torch.Tensor] = None,
         sf_vec_size: int = 16,
         mma_tiler_mn: Tuple[int, int] = (128, 128),
         cluster_shape_mn: Tuple[int, int] = (1, 1),
@@ -2609,9 +2622,6 @@ def run(
         self._a_major = a_major
         self._b_major = b_major
         self._c_major = c_major
-        self._ab_dtype = ab_dtype
-        self._sf_dtype = sf_dtype
-        self._c_dtype = c_dtype
         self._sf_vec_size = sf_vec_size
         self._mma_tiler_mn = mma_tiler_mn
         self._cluster_shape_mn = cluster_shape_mn
@@ -2648,25 +2658,25 @@ def dtype(cutlass_dtype):
             # fp4 gemm output is not supported
             c_tensor_gpu = torch.empty(
                 (self._l, self._m, self._n),
-                dtype=dtype(self._c_dtype),
+                dtype=dtype(c_dtype),
                 device="cuda",
             )

         # fp4 or fp8 torch tensor to cute tensor
         a_ptr = make_ptr(
-            self._ab_dtype,
+            ab_dtype,
             a_tensor_gpu.data_ptr(),
             cute.AddressSpace.gmem,
             assumed_align=16,
         )
         b_ptr = make_ptr(
-            self._ab_dtype,
+            ab_dtype,
             b_tensor_gpu.data_ptr(),
             cute.AddressSpace.gmem,
             assumed_align=16,
         )
         c_ptr = make_ptr(
-            self._c_dtype,
+            c_dtype,
             c_tensor_gpu.data_ptr(),
             cute.AddressSpace.gmem,
             assumed_align=16,
@@ -2678,17 +2688,27 @@ def dtype(cutlass_dtype):
             assumed_align=16,
         )
         sfa_ptr = make_ptr(
-            self._sf_dtype,
+            sf_dtype,
             sfa_tensor_gpu.data_ptr(),
             cute.AddressSpace.gmem,
             assumed_align=16,
         )
         sfb_ptr = make_ptr(
-            self._sf_dtype,
+            sf_dtype,
             sfb_tensor_gpu.data_ptr(),
             cute.AddressSpace.gmem,
             assumed_align=16,
         )
+        alpha_ptr = (
+            make_ptr(
+                alpha_dtype,
+                alpha_tensor_gpu.data_ptr(),
+                cute.AddressSpace.gmem,
+                assumed_align=16,
+            )
+            if alpha_tensor_gpu is not None
+            else None
+        )
         # todo(Yingyi): might add cute.assume() for shape alignment?
         current_stream = cutlass_torch.default_stream()

@@ -2699,6 +2719,7 @@ def dtype(cutlass_dtype):
             sfb_ptr,
             c_ptr,
             masked_m_ptr,
+            alpha_ptr,
             current_stream,
         )

@@ -2718,7 +2739,7 @@ def grouped_gemm_nt_masked(
     **kwargs,
 ):
     """
-    Executes a masked, batched matrix multiplication (GEMM) with scale factors.
+    Executes a masked, batched matrix multiplication (GEMM) with scale factors and optional per-batch alpha scaling of the output.

     Args:
         lhs (Tuple[torch.Tensor, torch.Tensor]): Tuple containing the left-hand side input tensor (A) and its scale factor tensor (SFA).
@@ -2735,6 +2756,8 @@ def grouped_gemm_nt_masked(
         sf_vec_size (int): Vector size for scale factors. Typically 16 or 32.
         mma_tiler_mn (Tuple[int, int], optional): Shape of the MMA tiler (M, N). Default: (128, 128).
         cluster_shape_mn (Tuple[int, int], optional): Shape of the CTA cluster (ClusterM, ClusterN). Default: (1, 1).
+        alpha_dtype (str, optional): Data type of the alpha scaling factors.
+        alpha (torch.Tensor, optional): Optional 1D tensor of shape (l,) containing per-batch scaling factors. Performs per-batch scaling: out = alpha * out.

     Notes:
         - Legends of the input tensors:
@@ -2744,6 +2767,7 @@ def grouped_gemm_nt_masked(
             * `n32 * n4 * rn` should be same as `N`, which is `n` padded up to the nearest multiple of 128.
             * `k4 * rk` should be same as `K`, which is `k / sf_vec_size` padded up to the nearest multiple of 4.
         - The function applies masking per batch using masked_m.
+        - If alpha is provided, each batch's output is multiplied by its corresponding alpha value: out = alpha * (A @ B).
         - The result is written to c_tensor.
     """

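Building on the Notes above, per-batch output scaling is requested purely through `**kwargs`: pass a 1D `alpha` tensor of shape (l,) together with its `alpha_dtype`. A minimal, hedged sketch of preparing those two keyword arguments (the batch count `l` and the `"float32"` dtype string are illustrative assumptions; all positional arguments to `grouped_gemm_nt_masked` are passed exactly as before):

import torch

l = 8  # number of groups/batches (illustrative)
alpha = torch.rand(l, dtype=torch.float32, device="cuda")  # per-batch scale factors, shape (l,)

# Forwarded to grouped_gemm_nt_masked(...) via **kwargs; alpha_dtype names the dtype used
# when wrapping alpha's data pointer, so supply it whenever alpha is passed.
extra_kwargs = {"alpha": alpha, "alpha_dtype": "float32"}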
@@ -2762,6 +2786,9 @@ def grouped_gemm_nt_masked(
     mma_tiler_mn = kwargs.get("mma_tiler_mm", (128, 128))
     cluster_shape_mn = kwargs.get("cluster_shape_mm", (1, 1))

+    alpha = kwargs.get("alpha")
+    alpha_dtype = kwargs.get("alpha_dtype")
+
     # TODO(kaixih@nvidia): do we need `use_cuda_graph`?
     wrapper = MaskedBatchedMatmulCuteDSL(use_cuda_graph=False)
     wrapper.run(
@@ -2784,4 +2811,6 @@ def grouped_gemm_nt_masked(
         sfb_tensor_gpu=sfb_torch,
         c_tensor_gpu=c_torch,
         masked_m_tensor_gpu=masked_m,
+        alpha_dtype=get_cutlass_dtype(alpha_dtype),
+        alpha_tensor_gpu=alpha,
     )