
Commit 51278b7

yyihuang and yzh119 authored
feat: scaling at fp4 gemm epilogue (#1498)
## Description

Apply a per-batch scaling factor at the FP4 GEMM epilogue: `out = out * scale`.

Co-authored-by: Zihao Ye <expye@outlook.com>
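The semantics of the new scaling can be stated as a minimal reference in plain PyTorch (a sketch only: shapes, dtypes, and random data below are illustrative, and dense fp32 tensors stand in for the FP4/FP8 block-scaled operands the kernel actually consumes):

```python
import torch

l, m, n, k = 4, 16, 24, 32                 # l = batch count
a = torch.randn(m, k, l)                   # A in (m, k, l) layout, per the docstring legends
b = torch.randn(n, k, l)                   # B in (n, k, l) layout
alpha = torch.randn(l)                     # one scaling factor per batch

out = torch.einsum("mkl,nkl->mnl", a, b)         # unscaled batched GEMM
scaled = torch.einsum("mnl,l->mnl", out, alpha)  # epilogue scaling: out = alpha * out

# Per-batch view of the same operation:
for batch in range(l):
    assert torch.allclose(scaled[:, :, batch], alpha[batch] * out[:, :, batch])
```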
1 parent 6518ce4 commit 51278b7

2 files changed: +61 -20 lines

flashinfer/cute_dsl/blockscaled_gemm.py

Lines changed: 49 additions & 20 deletions
@@ -338,8 +338,6 @@ def num_tiles_executed(self) -> Int32:
     - Type convert C matrix to output type.
     - Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations,
       or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations.
-    - Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor:
-      e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0))
 
     SM100 tcgen05.mma.kind.block_scale instructions operate as follows:
     - Read matrix A from SMEM
@@ -635,9 +633,9 @@ def __call__(
         sfb_tensor: cute.Tensor,
         c_tensor: cute.Tensor,
         masked_m_tensor: cute.Tensor,
+        alpha_tensor: Optional[cute.Tensor],
         max_active_clusters: cutlass.Constexpr,
         stream: cuda.CUstream,
-        epilogue_op: cutlass.Constexpr = lambda x: x,
     ):
         """Execute the GEMM operation in steps:
         - Setup static attributes before smem/grid/tma computation
@@ -662,8 +660,8 @@ def __call__(
         :type max_active_clusters: cutlass.Constexpr
         :param stream: CUDA stream for asynchronous execution
         :type stream: cuda.CUstream
-        :param epilogue_op: Optional elementwise lambda function to apply to the output tensor
-        :type epilogue_op: cutlass.Constexpr
+        :param alpha_tensor: Optional 1D tensor of shape (l,) containing per-batch scaling factors.
+        :type alpha_tensor: cute.Tensor
         :raises TypeError: If input data types are incompatible with the MMA instruction.
         """
         # Setup static attributes before smem/grid/tma computation
@@ -856,7 +854,6 @@ class SharedStorage:
 
         # Launch the kernel synchronously
         self.kernel(
-            masked_m_tensor,  # todo(Yingyi): cleanup?
            tiled_mma,
            tiled_mma_sfb,
            tma_atom_a,
@@ -869,6 +866,7 @@ class SharedStorage:
            tma_tensor_sfb,
            tma_atom_c,
            tma_tensor_c,
+           alpha_tensor,
            self.cluster_layout_vmnk,
            self.cluster_layout_sfb_vmnk,
            self.a_smem_layout_staged,
@@ -878,7 +876,6 @@ class SharedStorage:
            self.c_smem_layout_staged,
            self.epi_tile,
            self.tile_sched_params,
-           epilogue_op,
        ).launch(
            grid=grid,
            block=[self.threads_per_cta, 1, 1],
@@ -892,7 +889,6 @@ class SharedStorage:
    @cute.kernel
    def kernel(
        self,
-       masked_m: cute.Tensor,  # todo(Yingyi): cleanup?
        tiled_mma: cute.TiledMma,
        tiled_mma_sfb: cute.TiledMma,
        tma_atom_a: cute.CopyAtom,
@@ -905,6 +901,7 @@ def kernel(
        mSFB_nkl: cute.Tensor,
        tma_atom_c: Optional[cute.CopyAtom],
        mC_mnl: cute.Tensor,
+       alpha: Optional[cute.Tensor],
        cluster_layout_vmnk: cute.Layout,
        cluster_layout_sfb_vmnk: cute.Layout,
        a_smem_layout_staged: cute.ComposedLayout,
@@ -914,7 +911,6 @@ def kernel(
        c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
        epi_tile: cute.Tile,
        tile_sched_params: MaskedSchedulerParams,
-       epilogue_op: cutlass.Constexpr,
    ):
        """
        GPU device kernel performing the Persistent batched GEMM computation.
@@ -1616,7 +1612,10 @@ def kernel(
            # Convert to C type
            #
            acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
-           acc_vec = epilogue_op(acc_vec.to(self.c_dtype))
+           if cutlass.const_expr(alpha is not None):
+               acc_vec = acc_vec * alpha[work_tile.tile_idx[2]]
+
+           acc_vec = acc_vec.to(self.c_dtype)
            tRS_rC.store(acc_vec)
 
            #
@@ -2447,6 +2446,7 @@ def run_cute_ptr(
        sfb_ptr: cute.Pointer,
        c_ptr: cute.Pointer,
        masked_mptr: cute.Pointer,
+       alpha_ptr: cute.Pointer,
        current_stream: cuda.CUstream,
    ):
        a_tensor = cute.make_tensor(
@@ -2522,6 +2522,16 @@ def ceil_div(a, b):
            layout=cute.make_ordered_layout((self._l,), order=(0,)),
        )
 
+       # Use const_expr for compile-time conditional
+       alpha_tensor = (
+           cute.make_tensor(
+               alpha_ptr,
+               layout=cute.make_ordered_layout((self._l,), order=(0,)),
+           )
+           if cutlass.const_expr(alpha_ptr is not None)
+           else None
+       )
+
        Sm100BlockScaledPersistentDenseGemmKernel(
            sf_vec_size=self._sf_vec_size,
            mma_tiler_mn=self._mma_tiler_mn,
@@ -2533,6 +2543,7 @@ def ceil_div(a, b):
            sfb_tensor,
            c_tensor,
            masked_m_tensor,
+           alpha_tensor,
            self._max_active_clusters,
            current_stream,
        )
@@ -2555,6 +2566,8 @@ def run(
        sfb_tensor_gpu: torch.Tensor,
        masked_m_tensor_gpu: torch.Tensor,
        c_tensor_gpu: Optional[torch.Tensor] = None,
+       alpha_dtype: Optional[torch.dtype] = None,
+       alpha_tensor_gpu: Optional[torch.Tensor] = None,
        sf_vec_size: int = 16,
        mma_tiler_mn: Tuple[int, int] = (128, 128),
        cluster_shape_mn: Tuple[int, int] = (1, 1),
@@ -2609,9 +2622,6 @@ def run(
        self._a_major = a_major
        self._b_major = b_major
        self._c_major = c_major
-       self._ab_dtype = ab_dtype
-       self._sf_dtype = sf_dtype
-       self._c_dtype = c_dtype
        self._sf_vec_size = sf_vec_size
        self._mma_tiler_mn = mma_tiler_mn
        self._cluster_shape_mn = cluster_shape_mn
@@ -2648,25 +2658,25 @@ def dtype(cutlass_dtype):
            # fp4 gemm output is not supported
            c_tensor_gpu = torch.empty(
                (self._l, self._m, self._n),
-               dtype=dtype(self._c_dtype),
+               dtype=dtype(c_dtype),
                device="cuda",
            )
 
        # fp4 or fp8 torch tensor to cute tensor
        a_ptr = make_ptr(
-           self._ab_dtype,
+           ab_dtype,
            a_tensor_gpu.data_ptr(),
            cute.AddressSpace.gmem,
            assumed_align=16,
        )
        b_ptr = make_ptr(
-           self._ab_dtype,
+           ab_dtype,
            b_tensor_gpu.data_ptr(),
            cute.AddressSpace.gmem,
            assumed_align=16,
        )
        c_ptr = make_ptr(
-           self._c_dtype,
+           c_dtype,
            c_tensor_gpu.data_ptr(),
            cute.AddressSpace.gmem,
            assumed_align=16,
@@ -2678,17 +2688,27 @@ def dtype(cutlass_dtype):
            assumed_align=16,
        )
        sfa_ptr = make_ptr(
-           self._sf_dtype,
+           sf_dtype,
            sfa_tensor_gpu.data_ptr(),
            cute.AddressSpace.gmem,
            assumed_align=16,
        )
        sfb_ptr = make_ptr(
-           self._sf_dtype,
+           sf_dtype,
            sfb_tensor_gpu.data_ptr(),
            cute.AddressSpace.gmem,
            assumed_align=16,
        )
+       alpha_ptr = (
+           make_ptr(
+               alpha_dtype,
+               alpha_tensor_gpu.data_ptr(),
+               cute.AddressSpace.gmem,
+               assumed_align=16,
+           )
+           if alpha_tensor_gpu is not None
+           else None
+       )
        # todo(Yingyi): might add cute.assume() for shape alignment?
        current_stream = cutlass_torch.default_stream()
 
@@ -2699,6 +2719,7 @@ def dtype(cutlass_dtype):
            sfb_ptr,
            c_ptr,
            masked_m_ptr,
+           alpha_ptr,
            current_stream,
        )
 
@@ -2718,7 +2739,7 @@ def grouped_gemm_nt_masked(
    **kwargs,
 ):
    """
-   Executes a masked, batched matrix multiplication (GEMM) with scale factors.
+   Executes a masked, batched matrix multiplication (GEMM) with scale factors and optional alpha scaling at output.
 
    Args:
        lhs (Tuple[torch.Tensor, torch.Tensor]): Tuple containing the left-hand side input tensor (A) and its scale factor tensor (SFA).
@@ -2735,6 +2756,8 @@ def grouped_gemm_nt_masked(
        sf_vec_size (int): Vector size for scale factors. Typically 16 or 32.
        mma_tiler_mn (Tuple[int, int], optional): Shape of the MMA tiler (M, N). Default: (128, 128).
        cluster_shape_mn (Tuple[int, int], optional): Shape of the CTA cluster (ClusterM, ClusterN). Default: (1, 1).
+       alpha_dtype (str, optional): Data type for alpha scaling factors.
+       alpha (torch.Tensor, optional): Optional 1D tensor of shape (l,) containing per-batch scaling factors. Perform per-batch scaling out = alpha * out.
 
    Notes:
        - Legends of the input tensors:
@@ -2744,6 +2767,7 @@ def grouped_gemm_nt_masked(
            * `n32 * n4 * rn` should be same as `N`, which is `n` padded up to the nearest multiple of 128.
            * `k4 * rk` should be same as `K`, which is `k / sf_vec_size` padded up to the nearest multiple of 4.
        - The function applies masking per batch using masked_m.
+       - If alpha is provided, each batch output is multiplied by its corresponding alpha value. out = alpha * (A @ B).
        - The result is written to c_tensor.
    """
 
@@ -2762,6 +2786,9 @@ def grouped_gemm_nt_masked(
    mma_tiler_mn = kwargs.get("mma_tiler_mm", (128, 128))
    cluster_shape_mn = kwargs.get("cluster_shape_mm", (1, 1))
 
+   alpha = kwargs.get("alpha")
+   alpha_dtype = kwargs.get("alpha_dtype")
+
    # TODO(kaixih@nvidia): do we need `use_cuda_graph`?
    wrapper = MaskedBatchedMatmulCuteDSL(use_cuda_graph=False)
    wrapper.run(
@@ -2784,4 +2811,6 @@ def grouped_gemm_nt_masked(
        sfb_tensor_gpu=sfb_torch,
        c_tensor_gpu=c_torch,
        masked_m_tensor_gpu=masked_m,
+       alpha_dtype=get_cutlass_dtype(alpha_dtype),
+       alpha_tensor_gpu=alpha,
    )

tests/test_cute_dsl_blockscaled_gemm.py

Lines changed: 12 additions & 0 deletions
@@ -53,6 +53,8 @@
 @pytest.mark.parametrize("a_major", ["k"])
 @pytest.mark.parametrize("b_major", ["k"])
 @pytest.mark.parametrize("c_major", ["n"])
+@pytest.mark.parametrize("fuse_alpha", [False, True])
+@pytest.mark.parametrize("alpha_dtype", ["float32"])
 @pytest.mark.parametrize("mma_tiler_mn", [(128, 128)])
 @pytest.mark.parametrize("cluster_shape_mn", [(1, 1)])
 @pytest.mark.parametrize("tolerance", [1e-01])
@@ -67,6 +69,8 @@ def test_blockscaled_gemm_python_interface(
    a_major: str,
    b_major: str,
    c_major: str,
+   fuse_alpha: bool,
+   alpha_dtype: cutlass.dtype,
    mma_tiler_mn: Tuple[int, int],
    cluster_shape_mn: Tuple[int, int],
    tolerance: float,
@@ -164,6 +168,9 @@ def create_torch_tensor(l, mode0, mode1, is_mode0_major, cutlass_dtype, device):
        is_dynamic_layout=True,
        assumed_align=16,
    )
+   alpha_tensor = (
+       torch.randn(l, dtype=torch.float32, device="cuda") if fuse_alpha else None
+   )
 
    # for deepgemm-like python interface
    if ab_dtype == "float4_e2m1fn":
@@ -206,13 +213,18 @@ def create_torch_tensor(l, mode0, mode1, is_mode0_major, cutlass_dtype, device):
        sf_vec_size=sf_vec_size,
        mma_tiler_mn=mma_tiler_mn,
        cluster_shape_mn=cluster_shape_mn,
+       alpha=alpha_tensor,
+       alpha_dtype=alpha_dtype,
    )
    torch.cuda.synchronize()
 
    # compute ref output
+   if not fuse_alpha:
+       alpha_tensor = torch.ones(l, dtype=torch.float32, device="cuda")
    res_a = torch.einsum("mkl,mkl->mkl", a_ref, sfa_ref)
    res_b = torch.einsum("nkl,nkl->nkl", b_ref, sfb_ref)
    ref = torch.einsum("mkl,nkl->mnl", res_a, res_b)
+   ref = torch.einsum("mnl,l->mnl", ref, alpha_tensor.cpu())
 
    # Convert c back to f32 for comparison.
    c_ref_device = c_ref.cuda()
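When `fuse_alpha` is False the test still routes the reference through the same einsum, just with a vector of ones, so both parametrizations share one comparison path; a quick standalone check of that identity (illustrative shapes):

```python
import torch

ref = torch.randn(8, 8, 3)
ones = torch.ones(3)
# Scaling by a ones vector leaves the reference unchanged.
assert torch.allclose(torch.einsum("mnl,l->mnl", ref, ones), ref)
```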
