This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 19b82e0

Add rowwise scaling to Float8Inference module

ghstack-source-id: e5e6c73
Pull Request resolved: #305

1 parent 52e5d0a · commit 19b82e0

File tree

4 files changed: +111 −22 lines changed

float8_experimental/float8_python_api.py

Lines changed: 7 additions & 0 deletions

@@ -51,6 +51,13 @@ def addmm_float8_unwrapped(
         )
         output += bias
         return output
+    # Weight tensors are stored in N, K format. We call tensor_to_scale(dim=0)
+    # which produces a (N, 1) Tensor. However scaled_mm syntactically expects
+    # M X K @ K X N, and scales (M, 1) and (1, N)
+    b_inverse_scale = (
+        b_inverse_scale.T if b_inverse_scale.dim() == 2 else b_inverse_scale
+    )
+
     output = torch._scaled_mm(
         a_data,
         b_data,
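To make the shape handling concrete, here is a minimal hand-rolled emulation sketch (assumed shapes, with the fake-quantization arithmetic inlined; this is not the library's torch._scaled_mm kernel) of why a 2-D weight inverse scale needs the transpose before the matmul:

import torch

M, K, N = 4, 16, 8
a = torch.randn(M, K)
w = torch.randn(N, K)  # weights stored as (N, K)

fp8_max = torch.finfo(torch.float8_e4m3fn).max

# Row-wise scales: (M, 1) for the activation, (N, 1) for the weight.
a_scale = fp8_max / a.abs().amax(dim=1, keepdim=True)
b_scale = fp8_max / w.abs().amax(dim=1, keepdim=True)

a_data = (a * a_scale).to(torch.float8_e4m3fn)
b_data = (w * b_scale).to(torch.float8_e4m3fn)

# The matmul computes (M, K) @ (K, N), so the weight's (N, 1) inverse scale
# must become (1, N) to broadcast across the output columns -- the same
# transpose the diff above inserts before calling torch._scaled_mm.
a_inverse_scale = a_scale.reciprocal()    # (M, 1)
b_inverse_scale = b_scale.reciprocal().T  # (N, 1) -> (1, N)

out = (a_data.float() @ b_data.float().T) * a_inverse_scale * b_inverse_scale
print((out - a @ w.T).abs().max())  # small quantization error remains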

float8_experimental/float8_tensor.py

Lines changed: 9 additions & 7 deletions

@@ -109,7 +109,7 @@ def to_fp8_no_autograd(
         mm_config: Defines the configuration for the scaled_mm
     """
 
-    x_scaled = x * x_scale
+    x_scaled = x * x_scale.to(dtype=x.dtype)
     bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)
 
     if isinstance(bits_fp8, DTensor):
@@ -195,7 +195,9 @@ class FromFloat8ConstrFunc(torch.autograd.Function):
 
     @staticmethod
    def forward(ctx, tensor):
-        return tensor._data.to(tensor._orig_dtype) / tensor._scale
+        return tensor._data.to(tensor._orig_dtype) / tensor._scale.to(
+            tensor._orig_dtype
+        )
 
     @staticmethod
     def backward(ctx, g):
@@ -253,11 +255,11 @@ def __init__(
         orig_dtype: torch.dtype,
         mm_config: Optional[ScaledMMConfig],
     ):
-        assert (
-            scale.numel() == 1
-        ), "Scale should contain a single value, but got: {} elements".format(
-            scale.numel()
-        )
+        # assert (
+        #     scale.numel() == 1
+        # ), "Scale should contain a single value, but got: {} elements".format(
+        #     scale.numel()
+        # )
 
         self._data = data
         self._scale = scale
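Two things change here: the scale is cast to the tensor's dtype before the multiply/divide, and the numel() == 1 assert is disabled so the constructor now accepts per-row (N, 1) scales. The cast matters because scales are typically kept in float32 while the data may be bfloat16, and PyTorch's type promotion would otherwise silently upcast the product. A small illustration (assumed shapes, not library code):

import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)
scale = torch.rand(4, 1, dtype=torch.float32) + 0.5  # per-row scale, no longer a scalar

print((x * scale).dtype)                    # torch.float32 -- silent promotion
print((x * scale.to(dtype=x.dtype)).dtype)  # torch.bfloat16 -- dtype preserved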

float8_experimental/float8_utils.py

Lines changed: 30 additions & 6 deletions

@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Iterable, Literal, Tuple, Union
+from typing import Iterable, Literal, Optional, Tuple, Union
 
 import float8_experimental.config as config
 import torch
@@ -32,6 +32,12 @@
 e5m2_dtype = torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz
 
 
+def get_supported_granularity():
+    from float8_experimental.float8_tensor import ScalingGranularity
+
+    return [ScalingGranularity.TensorWise, ScalingGranularity.AxisWise]
+
+
 @torch.no_grad()
 def amax_to_scale(
     amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
@@ -103,20 +109,34 @@ def amax_history_to_scale_stack(
 def tensor_to_amax(
     x: torch.Tensor,
     scaling_granularity,
+    dim: Optional[int] = None,
     reduce_amax: bool = False,
 ) -> torch.Tensor:
     """Calculates the amax of a tensor.
     Args:
         x: The tensor to calculate the amax for.
         scaling_granularity: The granularity with which to calculate the tensor amax
+        dim: The dimension along which to calculate the amax. This is only used if scaling_granularity is AxisWise.
         reduce_amax: Whether to perform a distributed reduction on the amax.
     """
     from float8_experimental.float8_tensor import ScalingGranularity
 
-    assert (
-        scaling_granularity == ScalingGranularity.TensorWise
-    ), f"Currently only TensorWise is supported for but given scaling_granularity: {scaling_granularity}"
-    amax = torch.max(torch.abs(x))
+    supported_granularities = get_supported_granularity()
+
+    if scaling_granularity not in supported_granularities:
+        raise ValueError(
+            f"Currently only {supported_granularities} are supported. Given scaling_granularity: {scaling_granularity}"
+        )
+
+    if scaling_granularity == ScalingGranularity.TensorWise:
+        amax = torch.max(torch.abs(x))
+    elif scaling_granularity == ScalingGranularity.AxisWise:
+        if dim is None:
+            raise ValueError("For AxisWise scaling, a dim must be passed in!")
+        amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values
+    else:
+        # This should never be reached due to the earlier check, but it's here for completeness
+        raise ValueError(f"Unsupported scaling_granularity: {scaling_granularity}")
 
     # If the user asked for distributed reduction, do it.
     # If the user did not ask for it, assume that it will
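A sketch of the shapes the two branches above produce (plain torch, mirroring the reduction in tensor_to_amax):

import torch

x = torch.randn(4, 8)

# TensorWise: a single amax over the whole tensor -> 0-dim result.
amax_tensorwise = torch.max(torch.abs(x))                            # shape ()

# AxisWise with dim=1: one amax per row; keepdim=True keeps it broadcastable.
amax_axiswise = torch.max(torch.abs(x), dim=1, keepdim=True).values  # shape (4, 1)

print(amax_tensorwise.shape, amax_axiswise.shape)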
@@ -132,16 +152,20 @@ def tensor_to_scale(
     x: torch.Tensor,
     float8_dtype: torch.dtype,
     scaling_granularity,
+    dim: Optional[int] = None,
     reduce_amax: bool = False,
+    collapse_leading_dims: bool = False,
 ) -> torch.Tensor:
     """Calculates the scale that will be used for quantization to Float8Tensor
     Args:
         x: The tensor to calculate the scale for.
         float8_dtype: The Float8 dtype to use.
         scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.
+        dim: The dimension along which to calculate the scale. This is only used if scaling_granularity is AxisWise.
         reduce_amax: Whether to perform a distributed reduction on the amax.
+        collapse_leading_dims: Whether to collapse leading dimensions of the tensor.
     """
-    amax = tensor_to_amax(x, scaling_granularity, reduce_amax=reduce_amax)
+    amax = tensor_to_amax(x, scaling_granularity, dim=dim, reduce_amax=reduce_amax)
     return amax_to_scale(amax, float8_dtype, x.dtype)
float8_experimental/inference.py

Lines changed: 65 additions & 9 deletions

@@ -25,7 +25,13 @@
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
-from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale
+from float8_experimental.float8_utils import (
+    e4m3_dtype,
+    get_supported_granularity,
+    tensor_to_scale,
+)
+
+SUPPORTED_GRANULARITY = get_supported_granularity()
 
 
 class ActivationCasting(Enum):
@@ -75,7 +81,7 @@ def __init__(
         # FP8 specific arguments
         quant_config: QuantConfig,
         forward_config: ScaledMMConfig,
-        scaling_granularity: ScalingGranularity,
+        scaling_granularity: Optional[ScalingGranularity],
         # nn.Linear arguments
         in_features: int,
         out_features: int,
@@ -86,7 +92,26 @@ def __init__(
         # Construct the superclass; this will create dummy weights and biases
         super().__init__(in_features, out_features, bias, device, dtype)
         self.forward_config = forward_config
-        self.scaling_granularity = scaling_granularity
+        if scaling_granularity is None:
+            self.scaling_granularity = (
+                ScalingGranularity.AxisWise
+                if dtype == torch.bfloat16
+                and quant_config.static_quantization_scale is None
+                else ScalingGranularity.TensorWise
+            )
+        else:
+            assert (
+                scaling_granularity in SUPPORTED_GRANULARITY
+            ), f"scaling_granularity must be in {SUPPORTED_GRANULARITY} but got {scaling_granularity}"
+            if (
+                scaling_granularity == ScalingGranularity.AxisWise
+                and dtype != torch.bfloat16
+            ):
+                raise ValueError(
+                    "AxisWise scaling granularity is only supported for bfloat16."
+                )
+            self.scaling_granularity = scaling_granularity
+
         self.activation_casting = quant_config.activation_casting
         if self.activation_casting == ActivationCasting.STATIC:
             self.register_buffer(
@@ -101,13 +126,22 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
                 input, self.weight.to_original_precision()
             )
 
+        # TODO: we aren't folding leading dims yet, but need to in order to calculate the proper scale
+        original_m = input.shape[:-1]
+        input = input.view(-1, input.shape[-1])
+
         x_fp8 = cast_to_float8_e4m3_inference(
             input,
             self.forward_config,
             static_quantization_scale=self.static_quantization_scale,
             scaling_granularity=self.scaling_granularity,
         )
-        return torch.nn.functional.linear(x_fp8, self.weight, self.bias)
+        return torch.nn.functional.linear(x_fp8, self.weight, self.bias).view(
+            *original_m, -1
+        )
+
+    def extra_repr(self):
+        return f"{super().extra_repr()},activation_casting={self.activation_casting.name},scaling_granularity={self.scaling_granularity.name}"
 
     # Builder functions for Float8LinearInference
     def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
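The fold/unfold around the cast in forward() exists because an AxisWise (row-wise) scale is only well defined on a 2-D tensor. A sketch with assumed shapes of what the view pair does:

import torch

x = torch.randn(2, 3, 8)                  # batched input, in_features = 8
original_m = x.shape[:-1]                 # (2, 3)
x2d = x.view(-1, x.shape[-1])             # (6, 8): rows now well defined,
                                          # so an AxisWise scale is (6, 1)
w = torch.randn(16, 8)
y2d = torch.nn.functional.linear(x2d, w)  # (6, 16)
y = y2d.view(*original_m, -1)             # (2, 3, 16): leading dims restored
print(y.shape)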
@@ -124,7 +158,12 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
         assert not isinstance(
             self.weight, Float8Tensor
         ), "Weight has already been quantized, cannot quantize again."
-        scale = tensor_to_scale(self.weight, dtype, self.scaling_granularity)
+
+        # For weight tensors + AxisWise we calculate scales along columns
+        dim = None
+        if self.scaling_granularity == ScalingGranularity.AxisWise:
+            dim = 1
+        scale = tensor_to_scale(self.weight, dtype, self.scaling_granularity, dim=dim)
         quantized_weight = to_fp8_no_autograd(
             self.weight, scale, dtype, self.forward_config
         )
@@ -143,19 +182,20 @@ def from_float(
         module: nn.Module,
         quant_config: QuantConfig,
         use_fast_accum: bool,
+        scaling_granularity: Optional[ScalingGranularity],
     ) -> "Float8InferenceLinear":
         """
         Create an nn.Linear with fp8 compute from another nn.Linear
 
         Args:
             mod (torch.nn.Linear): nn.Linear to convert
             quant_config (QuantConfig): Configuration for the weight and activation casting
+            use_fast_accum (bool): Whether to enable fast accumulation for the Float8InferenceLinear.
+            scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.
         """
         forward_config = ScaledMMConfig(
             False, use_fast_accum, pad_inner_dim=config.pad_inner_dim
         )
-        # TODO: For now hardcode TensorWise scaling
-        scaling_granularity = ScalingGranularity.TensorWise
         linear = cls(
             quant_config,
             forward_config,
@@ -164,6 +204,7 @@ def from_float(
             module.out_features,
             False,
             device=torch.device("meta"),
+            dtype=module.weight.dtype,
         )
         linear.set_weight_and_bias(module.weight, module.bias)
         linear.quantize_weight()
@@ -194,18 +235,29 @@ def cast_to_float8_e4m3_inference(
     """
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
+
+    # For input tensors + AxisWise we calculate scales along rows
+    dim = None
+    if scaling_granularity == ScalingGranularity.AxisWise:
+        dim = 1
+
     scale = (
         static_quantization_scale
         if static_quantization_scale is not None
         else tensor_to_scale(
-            inpt_tensor, e4m3_dtype, scaling_granularity, reduce_amax=reduce_amax
+            inpt_tensor,
+            e4m3_dtype,
+            scaling_granularity,
+            dim=dim,
+            reduce_amax=reduce_amax,
        )
    )
    return Float8Tensor.to_float8(
        inpt_tensor,
        scale,
        e4m3_dtype,
        mm_config=mm_config,
+        scaling_granularity=scaling_granularity,
    )
 
 
@@ -215,6 +267,7 @@ def quantize_to_float8(
     *,
     skip_fqn_list: Optional[List[str]] = None,
     use_fast_accum: bool = True,
+    scaling_granularity: Optional[ScalingGranularity] = None,
 ) -> Optional[nn.Module]:
     """
     Converts torch.nn.Linear layers in the given module to Float8InferenceLinear.
@@ -228,6 +281,7 @@ def quantize_to_float8(
         quant_config (QuantConfig): Quantization configuration for Float8 conversion.
         skip_fqn_list (List[str], optional): List of module FQNs to skip during conversion.
         use_fast_accum : Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
+        scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.
 
     Returns:
         nn.Module: The modified module with applicable Linear layers converted to Float8.
@@ -237,6 +291,8 @@ def quantize_to_float8(
     """
     return swap_linear_layers(
         module,
-        lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
+        lambda m: Float8InferenceLinear.from_float(
+            m, quant_config, use_fast_accum, scaling_granularity
+        ),
         skip_fqn_list=skip_fqn_list,
     )
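Putting it together, a hypothetical end-to-end use of the new scaling_granularity keyword. This assumes an fp8-capable GPU and the QuantConfig/ActivationCasting definitions from this module; the DYNAMIC member and QuantConfig's positional signature are assumptions based on the surrounding code, not shown in this diff:

import torch
import torch.nn as nn

from float8_experimental.float8_tensor import ScalingGranularity
from float8_experimental.inference import (
    ActivationCasting,
    QuantConfig,
    quantize_to_float8,
)

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 16))
model = model.to(device="cuda", dtype=torch.bfloat16)

# Dynamic activation casting; leaving scaling_granularity=None would also
# select AxisWise here, since the model is bfloat16 with no static scale.
quant_config = QuantConfig(ActivationCasting.DYNAMIC)
quantize_to_float8(
    model, quant_config, scaling_granularity=ScalingGranularity.AxisWise
)

with torch.inference_mode():
    y = model(torch.randn(2, 64, device="cuda", dtype=torch.bfloat16))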
