# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2023 DeepSeek. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import paddle
import triton
import triton.language as tl

# from triton import Config


@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
    """
    Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factors in `s_ptr`.

    Args:
        x_ptr (triton.Pointer): Pointer to the input tensor.
        y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
        s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
        BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.

    Returns:
        None
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offs).to(tl.float32)
    s = tl.max(tl.abs(x)) / 448.0
    y = x / s
    y = y.to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + offs, y)
    tl.store(s_ptr + pid, s)

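# Note on act_quant_kernel (editor-added comment): 448.0 is the largest finite value
# representable in float8_e4m3fn, so each block is scaled such that its
# largest-magnitude element maps to +/-448 before the FP8 cast. For example, a block
# whose max |x| is 8.96 gets s = 8.96 / 448 = 0.02, y = x / 0.02 spans [-448, 448],
# and dequantization is simply y * s.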

def act_quant(x: paddle.Tensor, block_size: int = 128) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """
    Quantizes the input tensor `x` using block-wise quantization.

    Args:
        x (paddle.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
        block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.

    Returns:
        Tuple[paddle.Tensor, paddle.Tensor]: A tuple containing:
            - The quantized tensor with dtype `paddle.float8_e4m3fn`.
            - A tensor of scaling factors with dtype `paddle.float32`.
    """
    assert x.is_contiguous(), "Input tensor must be contiguous"
    assert (
        x.shape[-1] % block_size == 0
    ), f"Last dimension size must be divisible by block_size (block_size={block_size})"
    y = paddle.empty_like(x, dtype=paddle.float8_e4m3fn)
    s = paddle.empty((*x.shape[:-1], x.shape[-1] // block_size), dtype=paddle.float32)
    grid = lambda meta: (triton.cdiv(x.numel().item(), meta["BLOCK_SIZE"]),)
    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
    return y, s


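# Example usage of act_quant (an editor-added sketch, not from the original file;
# it assumes a CUDA device with FP8 support, and the shapes are illustrative only):
#
#     x = paddle.randn([4, 1024], dtype="float32")
#     y, s = act_quant(x, block_size=128)
#     # y: float8_e4m3fn tensor of shape [4, 1024]
#     # s: float32 per-block scales of shape [4, 8]  (1024 // 128 == 8)
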
@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    """
    Dequantizes weights using the provided scaling factors and stores the result.

    Args:
        x_ptr (tl.pointer): Pointer to the quantized weights.
        s_ptr (tl.pointer): Pointer to the scaling factors.
        y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
        M (int): Number of rows in the weight matrix.
        N (int): Number of columns in the weight matrix.
        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.

    Returns:
        None
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    n = tl.cdiv(N, BLOCK_SIZE)
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
    s = tl.load(s_ptr + pid_m * n + pid_n)
    y = x * s
    tl.store(y_ptr + offs, y, mask=mask)


def weight_dequant(x: paddle.Tensor, s: paddle.Tensor, block_size: int = 128) -> paddle.Tensor:
    """
    Dequantizes the given weight tensor using the provided scale tensor.

    Args:
        x (paddle.Tensor): The quantized weight tensor of shape (M, N).
        s (paddle.Tensor): The scale tensor of shape (ceil(M / block_size), ceil(N / block_size)),
            holding one scaling factor per (block_size x block_size) tile of `x`.
        block_size (int, optional): The block size to use for dequantization. Defaults to 128.

    Returns:
        paddle.Tensor: The dequantized weight tensor of the same shape as `x`.

    Raises:
        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
    """
    assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous"
    assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions"
    M, N = x.shape
    y = paddle.empty_like(x, dtype=paddle.get_default_dtype())
    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"]))
    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
    return y


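# Example usage of weight_dequant (an editor-added sketch; shapes are illustrative):
#
#     w_fp8 = ...  # float8_e4m3fn weights of shape [1024, 2048]
#     w_s = ...    # float32 scales of shape [8, 16]  (one per 128 x 128 tile)
#     w = weight_dequant(w_fp8, w_s)  # [1024, 2048] in the default dtype
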
# fp8_gemm_configs = [
#     Config({"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": 128}, num_stages=num_stages, num_warps=8)
#     for block_m in [16, 32, 64]
#     for block_n in [32, 64, 128]
#     for num_stages in [3, 4, 5, 6]
# ]
# FIXME @ZHUI: Paddle does not support Triton autotune yet.
# @triton.autotune(configs=fp8_gemm_configs, key=["N", "K"])
@triton.jit
def fp8_gemm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    a_s_ptr,
    b_s_ptr,
    M,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """
    Performs a matrix multiplication operation on FP8 matrices with scaling factors.

    Args:
        a_ptr (tl.tensor): Pointer to the first input matrix A.
        b_ptr (tl.tensor): Pointer to the second input matrix B.
        c_ptr (tl.tensor): Pointer to the output matrix C.
        a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
        b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
        M (int): Number of rows in matrix A and C.
        N (tl.constexpr): Number of columns in matrix B and C.
        K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
        BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
        BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
        BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.

    Returns:
        None
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    k = tl.cdiv(K, BLOCK_SIZE_K)
    offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
    a_s_ptrs = a_s_ptr + offs_m * k
    b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for i in range(k):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
        a_s = tl.load(a_s_ptrs)
        b_s = tl.load(b_s_ptrs)
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K
        b_ptrs += BLOCK_SIZE_K
        a_s_ptrs += 1
        b_s_ptrs += 1
    c = accumulator.to(c_ptr.dtype.element_ty)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, c, mask=mask)

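# Note on the layouts fp8_gemm_kernel assumes (editor-added comment, derived from the
# pointer arithmetic above): `a` is row-major [M, K]; `b` is row-major [N, K], i.e. the
# weight already transposed; `a_s` is [M, ceil(K / BLOCK_SIZE_K)] with one scale per
# 1 x BLOCK_SIZE_K activation block; and `b_s` is
# [ceil(N / BLOCK_SIZE_K), ceil(K / BLOCK_SIZE_K)] with one scale per
# BLOCK_SIZE_K x BLOCK_SIZE_K weight tile.
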

def fp8_gemm(a: paddle.Tensor, a_s: paddle.Tensor, b: paddle.Tensor, b_s: paddle.Tensor):
    """
    FP8 GEMM with block-wise scaling factors.

    Modified so that the weight matrix `b` is passed with shape [K, N]
    (with `b_s` of shape [ceil(K / 128), ceil(N / 128)]); both are transposed
    below into the layout the kernel expects.
    """
    # FIXME @ZHUI: b arrives as [K, N]; transpose b and b_s to the [N, K] layout used by the kernel.
    b = b.T.contiguous()
    b_s = b_s.T.contiguous()
    assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
    assert a_s.is_contiguous() and b_s.is_contiguous(), "Scaling factor tensors must be contiguous"

    K = a.shape[-1]
    M = a.numel().item() // K
    N = b.shape[0]  # after the transpose above, b has shape [N, K]

    c = paddle.empty((*a.shape[:-1], N), dtype=paddle.get_default_dtype())
    grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]), triton.cdiv(N, META["BLOCK_SIZE_N"]))
    fp8_gemm_kernel[grid](
        a,
        b,
        c,
        a_s,
        b_s,
        M,
        N,
        K,
        BLOCK_SIZE_M=32,
        BLOCK_SIZE_N=64,
        BLOCK_SIZE_K=128,
    )
    return c
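

# Editor-added smoke test: a minimal sketch, not part of the upstream DeepSeek/Paddle
# code. It assumes a CUDA device with FP8 support and a Paddle build in which
# `paddle.cast` to `float8_e4m3fn` is available; shapes are illustrative only.
if __name__ == "__main__":
    paddle.set_device("gpu")
    M, K, N = 256, 512, 1024

    # Quantize random activations block-wise along the last dimension.
    x = paddle.randn([M, K], dtype="float32")
    a_fp8, a_s = act_quant(x)  # a_fp8: [M, K] fp8, a_s: [M, K // 128] fp32

    # Random "pre-quantized" weight in [K, N] layout with per-128x128-tile scales.
    w_fp8 = paddle.cast(paddle.randn([K, N], dtype="float32"), paddle.float8_e4m3fn)
    w_s = paddle.rand([K // 128, N // 128], dtype="float32")

    # Reference path: dequantize the weight and multiply in full precision.
    w = weight_dequant(w_fp8, w_s)  # [K, N] in the default dtype
    ref = paddle.matmul(x, w)

    out = fp8_gemm(a_fp8, a_s, w_fp8, w_s)  # [M, N] in the default dtype
    print("fp8_gemm:", out.shape, out.dtype)
    print("max abs diff vs. fp32 reference:", float((out - ref).abs().max()))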