
Commit 7e4fb0d

add some miss file
1 parent 1cfc09c commit 7e4fb0d

3 files changed: +571 -0 lines changed

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

original_linear = paddle.nn.functional.linear

from typing import Literal, Optional

# from ..linear_utils import RowParallelLinear as PD_RowParallelLinear
from ..linear_utils import ColumnParallelLinear as PD_ColumnParallelLinear
from ..linear_utils import (
    ColumnSequenceParallelLinear as PD_ColumnSequenceParallelLinear,
)
from ..linear_utils import Linear as PD_Linear
from ..linear_utils import RowParallelLinear as PD_RowParallelLinear
from ..linear_utils import RowSequenceParallelLinear as PD_RowSequenceParallelLinear

try:
    from .kernel import act_quant, fp8_gemm, weight_dequant
except:
    pass


__all__ = [
    "Linear",
    "ColumnParallelLinear",
    "RowParallelLinear",
    "ColumnSequenceParallelLinear",
    "RowSequenceParallelLinear",
]

gemm_impl: Literal["bf16", "fp8"] = "bf16"
block_size = 128


def fp8_linear(
    x: paddle.Tensor, weight: paddle.Tensor, bias: Optional[paddle.Tensor] = None, name=None
) -> paddle.Tensor:
    """
    Applies a linear transformation to the incoming data: y = xW + b.

    This function supports specialized implementations based on quantization
    and tensor formats.

    Args:
        x (paddle.Tensor): The input tensor.
        weight (paddle.Tensor): The weight tensor. It may be quantized and
            require dequantization in certain cases.
        bias (Optional[paddle.Tensor]): The bias tensor to be added. Default is None.

    Returns:
        paddle.Tensor: The result of the linear transformation, which may involve
            quantization-aware computations depending on the input parameters.

    Notes:
        - If `weight` is not quantized (`element_size() > 1`), the original linear kernel is used.
        - If `weight` is quantized (`element_size() == 1`) and `gemm_impl == "bf16"`,
          the weight is dequantized with `weight_dequant` and a `bf16` GEMM operation is applied.
        - Otherwise, `x` is quantized with `act_quant` and `fp8_gemm` is used for computation.
    """

    if paddle.in_dynamic_mode():
        if weight.element_size() > 1:
            return original_linear(x, weight, bias)
        elif gemm_impl == "bf16":
            weight = weight_dequant(weight, weight._scale)
            return original_linear(x, weight, bias)
        else:
            x, scale = act_quant(x, block_size)
            y = fp8_gemm(x, scale, weight, weight._scale)
            if bias is not None:
                y += bias
            return y
    else:
        return original_linear(x, weight, bias)


paddle.nn.functional.linear = fp8_linear
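
# Illustrative example (not part of the original file): with the patch above applied,
# every call to paddle.nn.functional.linear is routed through fp8_linear. A regular
# (non-FP8) weight has element_size() > 1, so the call simply falls back to
# original_linear:
#
#   x = paddle.randn([2, 8])
#   w = paddle.randn([8, 16])
#   y = paddle.nn.functional.linear(x, w)  # dispatches via fp8_linear -> original_linear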


def register_scale(self):
    if self.weight.element_size() == 1:
        in_features, out_features = self.weight.shape
        scale_out_features = (out_features + self.block_size - 1) // self.block_size
        scale_in_features = (in_features + self.block_size - 1) // self.block_size
        self.weight_scale_inv = self.create_parameter(
            shape=[scale_in_features, scale_out_features],
            attr=self._weight_attr,
            dtype="float32",
            is_bias=False,
        )
        self.weight._scale = self.weight_scale_inv


class Linear(PD_Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.block_size = kwargs.get("block_size", 128)
        register_scale(self)


class ColumnParallelLinear(PD_ColumnParallelLinear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.block_size = kwargs.get("block_size", 128)
        register_scale(self)


class RowParallelLinear(PD_RowParallelLinear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.block_size = kwargs.get("block_size", 128)
        register_scale(self)


class ColumnSequenceParallelLinear(PD_ColumnSequenceParallelLinear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.block_size = kwargs.get("block_size", 128)
        register_scale(self)


class RowSequenceParallelLinear(PD_RowSequenceParallelLinear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.block_size = kwargs.get("block_size", 128)
        register_scale(self)
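
The subclasses above only attach an extra `weight_scale_inv` parameter when the checkpoint weight is already stored in FP8 (`element_size() == 1`). A minimal sketch of the resulting scale shape, assuming the default `block_size=128` (plain arithmetic only, independent of this module):

in_features, out_features, block_size = 1024, 4096, 128
scale_in_features = (in_features + block_size - 1) // block_size    # -> 8
scale_out_features = (out_features + block_size - 1) // block_size  # -> 32
# register_scale would create weight_scale_inv with shape [8, 32] for such a layer
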
Lines changed: 227 additions & 0 deletions
@@ -0,0 +1,227 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2023 DeepSeek. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import paddle
import triton
import triton.language as tl

# from triton import Config


@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
    """
    Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.

    Args:
        x_ptr (triton.Pointer): Pointer to the input tensor.
        y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
        s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
        BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.

    Returns:
        None
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offs).to(tl.float32)
    s = tl.max(tl.abs(x)) / 448.0
    y = x / s
    y = y.to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + offs, y)
    tl.store(s_ptr + pid, s)


def act_quant(x: paddle.Tensor, block_size: int = 128) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """
    Quantizes the input tensor `x` using block-wise quantization.

    Args:
        x (paddle.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
        block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.

    Returns:
        Tuple[paddle.Tensor, paddle.Tensor]: A tuple containing:
            - The quantized tensor with dtype `paddle.float8_e4m3fn`.
            - A tensor of scaling factors with dtype `paddle.float32`.
    """
    assert x.is_contiguous(), "Input tensor must be contiguous"
    assert (
        x.shape[-1] % block_size == 0
    ), f"Last dimension size must be divisible by block_size (block_size={block_size})"
    y = paddle.empty_like(x, dtype=paddle.float8_e4m3fn)
    s = paddle.empty((*x.shape[:-1], x.shape[-1] // block_size), dtype=paddle.float32)
    grid = lambda meta: (triton.cdiv(x.numel().item(), meta["BLOCK_SIZE"]),)
    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
    return y, s
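
# Illustrative example (not part of the original file): quantize a [4, 256] activation
# block-wise; the result keeps the input shape in float8_e4m3fn and carries one float32
# scale per 128-wide block along the last dimension:
#
#   x = paddle.randn([4, 256])
#   y, s = act_quant(x, block_size=128)  # y: [4, 256] float8_e4m3fn, s: [4, 2] float32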


@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    """
    Dequantizes weights using the provided scaling factors and stores the result.

    Args:
        x_ptr (tl.pointer): Pointer to the quantized weights.
        s_ptr (tl.pointer): Pointer to the scaling factors.
        y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
        M (int): Number of rows in the weight matrix.
        N (int): Number of columns in the weight matrix.
        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.

    Returns:
        None
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    n = tl.cdiv(N, BLOCK_SIZE)
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
    s = tl.load(s_ptr + pid_m * n + pid_n)
    y = x * s
    tl.store(y_ptr + offs, y, mask=mask)


def weight_dequant(x: paddle.Tensor, s: paddle.Tensor, block_size: int = 128) -> paddle.Tensor:
    """
    Dequantizes the given weight tensor using the provided scale tensor.

    Args:
        x (paddle.Tensor): The quantized weight tensor of shape (M, N).
        s (paddle.Tensor): The scale tensor of shape (ceil(M / block_size), ceil(N / block_size)).
        block_size (int, optional): The block size to use for dequantization. Defaults to 128.

    Returns:
        paddle.Tensor: The dequantized weight tensor of the same shape as `x`.

    Raises:
        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
    """
    assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous"
    assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions"
    M, N = x.shape
    y = paddle.empty_like(x, dtype=paddle.get_default_dtype())
    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"]))
    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
    return y


# fp8_gemm_configs = [
#     Config({"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": 128}, num_stages=num_stages, num_warps=8)
#     for block_m in [16, 32, 64]
#     for block_n in [32, 64, 128]
#     for num_stages in [3, 4, 5, 6]
# ]
# FIXME @ZHUI, paddle does not support triton autotune for now.
# # @triton.autotune(configs=fp8_gemm_configs, key=["N", "K"])
@triton.jit
def fp8_gemm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    a_s_ptr,
    b_s_ptr,
    M,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """
    Performs a matrix multiplication operation on FP8 matrices with scaling factors.

    Args:
        a_ptr (tl.tensor): Pointer to the first input matrix A.
        b_ptr (tl.tensor): Pointer to the second input matrix B.
        c_ptr (tl.tensor): Pointer to the output matrix C.
        a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
        b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
        M (int): Number of rows in matrix A and C.
        N (tl.constexpr): Number of columns in matrix B and C.
        K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
        BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
        BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
        BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.

    Returns:
        None
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    k = tl.cdiv(K, BLOCK_SIZE_K)
    offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
    a_s_ptrs = a_s_ptr + offs_m * k
    b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for i in range(k):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
        a_s = tl.load(a_s_ptrs)
        b_s = tl.load(b_s_ptrs)
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K
        b_ptrs += BLOCK_SIZE_K
        a_s_ptrs += 1
        b_s_ptrs += 1
    c = accumulator.to(c_ptr.dtype.element_ty)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, c, mask=mask)


def fp8_gemm(a: paddle.Tensor, a_s: paddle.Tensor, b: paddle.Tensor, b_s: paddle.Tensor):
    """
    FP8 GEMM modified for a B matrix with shape [K, N]; B and its block-wise scales
    are transposed internally before the kernel launch.
    """
    # FIXME @ZHUI, transposed
    b = b.T.contiguous()
    b_s = b_s.T.contiguous()
    assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
    assert a_s.is_contiguous() and b_s.is_contiguous(), "Scaling factor tensors must be contiguous"

    K = a.shape[-1]
    M = a.numel().item() // K
    # N = b.shape[-1]  # Get N from the second dimension of B
    N = b.shape[0]  # N is the first dimension of the transposed B (second dimension of the original [K, N] B)

    c = paddle.empty((*a.shape[:-1], N), dtype=paddle.get_default_dtype())
    grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]), triton.cdiv(N, META["BLOCK_SIZE_N"]))
    fp8_gemm_kernel[grid](
        a,
        b,
        c,
        a_s,
        b_s,
        M,
        N,
        K,
        BLOCK_SIZE_M=32,
        BLOCK_SIZE_N=64,
        BLOCK_SIZE_K=128,
    )
    return c
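
For context, this is the call pattern used by `fp8_linear` in the companion layer file of this commit (sketch only; `weight` is assumed to be an FP8 weight of shape [K, N] whose block-wise scales are attached as `weight._scale` with shape [ceil(K/128), ceil(N/128)]):

x_q, x_scale = act_quant(x, block_size=128)        # x: [..., K] activations
y = fp8_gemm(x_q, x_scale, weight, weight._scale)  # y: [..., N]
if bias is not None:
    y += bias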
