
Commit 29c28ed

Add rowwise scaling to Float8Inference module

ghstack-source-id: 20cfa0a
Pull Request resolved: #305

1 parent 653e120, commit 29c28ed

File tree: 4 files changed, +102 −23 lines

    float8_experimental/float8_python_api.py
    float8_experimental/float8_tensor.py
    float8_experimental/float8_utils.py
    float8_experimental/inference.py

float8_experimental/float8_python_api.py (1 addition, 1 deletion)

@@ -55,7 +55,7 @@ def addmm_float8_unwrapped(
         a_data,
         b_data,
         scale_a=a_inverse_scale,
-        scale_b=b_inverse_scale,
+        scale_b=b_inverse_scale.T,
         bias=bias,
         scale_result=output_scale,
         out_dtype=output_dtype,
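
A minimal pure-PyTorch sketch of why the transpose matters under row-wise scaling (the names, shapes, and manual rescale below are illustrative assumptions, not taken from this diff and not a call to torch._scaled_mm): if a_inverse_scale holds one value per row of the activation, shape (M, 1), and b_inverse_scale one value per row of the weight, shape (N, 1), then only b_inverse_scale.T, shape (1, N), broadcasts against the (M, N) output.

    import torch

    # Illustrative shapes only.
    M, K, N = 4, 8, 16
    a = torch.randn(M, K, dtype=torch.bfloat16)
    b = torch.randn(N, K, dtype=torch.bfloat16)  # weight laid out as (out_features, in_features)

    # Hypothetical row-wise inverse scales, one per row of each operand.
    a_inverse_scale = torch.rand(M, 1) + 0.5  # (M, 1)
    b_inverse_scale = torch.rand(N, 1) + 0.5  # (N, 1)

    out = a.float() @ b.float().t()           # (M, N)

    # (M, 1) broadcasts over columns and (1, N) over rows; without the .T the
    # second scale would stay (N, 1) and fail to broadcast against (M, N).
    rescaled = out * a_inverse_scale * b_inverse_scale.T
    print(rescaled.shape)  # torch.Size([4, 16])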

float8_experimental/float8_tensor.py (9 additions, 7 deletions)

@@ -109,7 +109,7 @@ def to_fp8_no_autograd(
         mm_config: Defines the configuration for the scaled_mm
     """

-    x_scaled = x * x_scale
+    x_scaled = x * x_scale.to(dtype=x.dtype)
     bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)

     if isinstance(bits_fp8, DTensor):

@@ -195,7 +195,9 @@ class FromFloat8ConstrFunc(torch.autograd.Function):

     @staticmethod
     def forward(ctx, tensor):
-        return tensor._data.to(tensor._orig_dtype) / tensor._scale
+        return tensor._data.to(tensor._orig_dtype) / tensor._scale.to(
+            tensor._orig_dtype
+        )

     @staticmethod
     def backward(ctx, g):

@@ -253,11 +255,11 @@ def __init__(
         orig_dtype: torch.dtype,
         mm_config: Optional[ScaledMMConfig],
     ):
-        assert (
-            scale.numel() == 1
-        ), "Scale should contain a single value, but got: {} elements".format(
-            scale.numel()
-        )
+        # assert (
+        #     scale.numel() == 1
+        # ), "Scale should contain a single value, but got: {} elements".format(
+        #     scale.numel()
+        # )

         self._data = data
         self._scale = scale
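
As a side note, here is a small standalone sketch (illustrative tensors, not the repo's code path) of what the x_scale.to(dtype=x.dtype) cast above avoids: multiplying a bfloat16 tensor by a float32 scale promotes the product to float32, whereas casting the scale first keeps the intermediate in the original dtype. It also shows why the one-element assertion no longer holds once the scale becomes an axis-wise tensor.

    import torch

    x = torch.randn(4, 8, dtype=torch.bfloat16)
    # With axis-wise scaling the scale is a tensor, e.g. one value per row (assumed shape).
    x_scale = torch.rand(4, 1, dtype=torch.float32) + 0.5

    print(x_scale.numel())                        # 4, so a numel() == 1 assert would fail
    print((x * x_scale).dtype)                    # torch.float32 (type promotion)
    print((x * x_scale.to(dtype=x.dtype)).dtype)  # torch.bfloat16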

float8_experimental/float8_utils.py (30 additions, 6 deletions)

@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Iterable, Literal, Tuple, Union
+from typing import Iterable, Literal, Optional, Tuple, Union

 import float8_experimental.config as config
 import torch

@@ -32,6 +32,12 @@
 e5m2_dtype = torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz


+def get_supported_granularity():
+    from float8_experimental.float8_tensor import ScalingGranularity
+
+    return [ScalingGranularity.TensorWise, ScalingGranularity.AxisWise]
+
+
 @torch.no_grad()
 def amax_to_scale(
     amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype

@@ -103,20 +109,34 @@ def amax_history_to_scale_stack(
 def tensor_to_amax(
     x: torch.Tensor,
     scaling_granularity,
+    dim: Optional[int] = None,
     reduce_amax: bool = False,
 ) -> torch.Tensor:
     """Calculates the amax of a tensor.
     Args:
         x: The tensor to calculate the amax for.
         scaling_granularity: The granularity of with which to calcualte the tensor amax
+        dim: The dimension along which to calculate the amax. This is only used if scaling_granularity is AxisWise.
         reduce_amax: Whether to perform a distributed reduction on the amax.
     """
     from float8_experimental.float8_tensor import ScalingGranularity

-    assert (
-        scaling_granularity == ScalingGranularity.TensorWise
-    ), f"Currently only TensorWise is supported for but given scaling_granularity: {scaling_granularity}"
-    amax = torch.max(torch.abs(x))
+    supported_granularities = get_supported_granularity()
+
+    if scaling_granularity not in supported_granularities:
+        raise ValueError(
+            f"Currently only {supported_granularities} are supported. Given scaling_granularity: {scaling_granularity}"
+        )
+
+    if scaling_granularity == ScalingGranularity.TensorWise:
+        amax = torch.max(torch.abs(x))
+    elif scaling_granularity == ScalingGranularity.AxisWise:
+        if dim is None:
+            raise ValueError("For AxisWise scaling, a dim must be passed in!")
+        amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values
+    else:
+        # This should never be reached due to the earlier check, but it's here for completeness
+        raise ValueError(f"Unsupported scaling_granularity: {scaling_granularity}")

     # If the user asked for distributed reduction, do it.
     # If the user did not ask for it, assume that it will

@@ -132,16 +152,20 @@ def tensor_to_scale(
     x: torch.Tensor,
     float8_dtype: torch.dtype,
     scaling_granularity,
+    dim: Optional[int] = None,
     reduce_amax: bool = False,
+    collapse_leading_dims: bool = False,
 ) -> torch.Tensor:
     """Calculates the scale that will be used for quantization to Float8Tensor
     Args:
         x: The tensor to calculate the scale for.
         float8_dtype: The Float8 dtype to use.
         scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.
+        dim: The dimension along which to calculate the scale. This is only used if scaling_granularity is AxisWise.
         reduce_amax: Whether to perform a distributed reduction on the amax.
+        collapse_leading_dims: Whether to collapse leading dimensions of the tensor.
     """
-    amax = tensor_to_amax(x, scaling_granularity, reduce_amax=reduce_amax)
+    amax = tensor_to_amax(x, scaling_granularity, dim=dim, reduce_amax=reduce_amax)
     return amax_to_scale(amax, float8_dtype, x.dtype)
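
A minimal standalone sketch of the two amax paths added above (plain torch, not the repo's helpers; tensor shapes are assumptions for illustration): tensor-wise scaling reduces over the whole tensor, while axis-wise scaling reduces along a chosen dim with keepdim=True so the resulting amax, and therefore the scale derived from it, broadcasts back against the input.

    import torch

    x = torch.randn(4, 8, dtype=torch.bfloat16)

    # TensorWise: a single amax for the whole tensor.
    amax_tensorwise = torch.max(torch.abs(x))                         # shape ()

    # AxisWise: one amax per row (dim=1), kept as (4, 1) so it broadcasts.
    amax_axiswise = torch.max(torch.abs(x), dim=1, keepdim=True).values

    print(amax_tensorwise.shape, amax_axiswise.shape)  # torch.Size([]) torch.Size([4, 1])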

float8_experimental/inference.py (62 additions, 9 deletions)

@@ -25,7 +25,13 @@
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
-from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale
+from float8_experimental.float8_utils import (
+    e4m3_dtype,
+    get_supported_granularity,
+    tensor_to_scale,
+)
+
+SUPPORTED_GRANULARITY = get_supported_granularity()


 class ActivationCasting(Enum):

@@ -75,7 +81,7 @@ def __init__(
         # FP8 specific arguments
         quant_config: QuantConfig,
         forward_config: ScaledMMConfig,
-        scaling_granularity: ScalingGranularity,
+        scaling_granularity: Optional[ScalingGranularity],
         # nn.Linear arguments
         in_features: int,
         out_features: int,

@@ -86,7 +92,26 @@ def __init__(
         # Construct the superclass this will create dummy weights and biases
         super().__init__(in_features, out_features, bias, device, dtype)
         self.forward_config = forward_config
-        self.scaling_granularity = scaling_granularity
+        if scaling_granularity is None:
+            self.scaling_granularity = (
+                ScalingGranularity.AxisWise
+                if dtype == torch.bfloat16
+                and quant_config.static_quantization_scale is None
+                else ScalingGranularity.TensorWise
+            )
+        else:
+            assert (
+                scaling_granularity in SUPPORTED_GRANULARITY
+            ), f"scaling_granularity must be in {SUPPORTED_GRANULARITY} but got {scaling_granularity}"
+            if (
+                scaling_granularity == ScalingGranularity.AxisWise
+                and dtype != torch.bfloat16
+            ):
+                raise ValueError(
+                    "AxisWise scaling granularity is only supported for bfloat16."
+                )
+            self.scaling_granularity = scaling_granularity
+
         self.activation_casting = quant_config.activation_casting
         if self.activation_casting == ActivationCasting.STATIC:
             self.register_buffer(

@@ -101,13 +126,19 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
                 input, self.weight.to_original_precision()
             )

+        # TODO we arent folding leading dims yet, but need it to calculate the proper scale.. this sucks
+        original_m = input.shape[:-1]
+        input = input.view(-1, input.shape[-1])
+
         x_fp8 = cast_to_float8_e4m3_inference(
             input,
             self.forward_config,
             static_quantization_scale=self.static_quantization_scale,
             scaling_granularity=self.scaling_granularity,
         )
-        return torch.nn.functional.linear(x_fp8, self.weight, self.bias)
+        return torch.nn.functional.linear(x_fp8, self.weight, self.bias).view(
+            *original_m, -1
+        )

     # Builder functions for Float8LinearInference
     def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:

@@ -124,7 +155,12 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
         assert not isinstance(
             self.weight, Float8Tensor
         ), "Weight has already been quantized, cannot quantize again."
-        scale = tensor_to_scale(self.weight, dtype, self.scaling_granularity)
+
+        # For weight tensors + AxisWise we calculate scales along columns
+        dim = None
+        if self.scaling_granularity == ScalingGranularity.AxisWise:
+            dim = 1
+        scale = tensor_to_scale(self.weight, dtype, self.scaling_granularity, dim=dim)
         quantized_weight = to_fp8_no_autograd(
             self.weight, scale, dtype, self.forward_config
         )

@@ -143,19 +179,20 @@ def from_float(
         module: nn.Module,
         quant_config: QuantConfig,
         use_fast_accum: bool,
+        scaling_granularity: Optional[ScalingGranularity],
     ) -> "Float8InferenceLinear":
         """
         Create an nn.Linear with fp8 compute from another nn.Linear

         Args:
             mod (torch.nn.Linear): nn.Linear to convert
             quant_config (QuantConfig): Configuration for the weight and activation casting
+            use_fast_accum (bool): Whether to enable fast accumulation for the Float8InferenceLinear.
+            scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.
         """
         forward_config = ScaledMMConfig(
             False, use_fast_accum, pad_inner_dim=config.pad_inner_dim
         )
-        # TODO: For now hardcode TensorWise scaling
-        scaling_granularity = ScalingGranularity.TensorWise
         linear = cls(
             quant_config,
             forward_config,

@@ -164,6 +201,7 @@ def from_float(
             module.out_features,
             False,
             device=torch.device("meta"),
+            dtype=module.weight.dtype,
         )
         linear.set_weight_and_bias(module.weight, module.bias)
         linear.quantize_weight()

@@ -194,18 +232,29 @@ def cast_to_float8_e4m3_inference(
     """
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
+
+    # For input tensors + AxisWise we calculate scales along rows
+    dim = None
+    if scaling_granularity == ScalingGranularity.AxisWise:
+        dim = 1
+
     scale = (
         static_quantization_scale
         if static_quantization_scale is not None
         else tensor_to_scale(
-            inpt_tensor, e4m3_dtype, scaling_granularity, reduce_amax=reduce_amax
+            inpt_tensor,
+            e4m3_dtype,
+            scaling_granularity,
+            dim=dim,
+            reduce_amax=reduce_amax,
         )
     )
     return Float8Tensor.to_float8(
         inpt_tensor,
         scale,
         e4m3_dtype,
         mm_config=mm_config,
+        scaling_granularity=scaling_granularity,
     )

@@ -215,6 +264,7 @@ def quantize_to_float8(
     *,
     skip_fqn_list: Optional[List[str]] = None,
     use_fast_accum: bool = True,
+    scaling_granularity: Optional[ScalingGranularity] = None,
 ) -> Optional[nn.Module]:
     """
     Converts torch.nn.Linear layers in the given module to Float8InferenceLinear.

@@ -228,6 +278,7 @@ def quantize_to_float8(
         quant_config (QuantConfig): Quantization configuration for Float8 conversion.
         skip_fqn_list (List[str], optional): List of module FQNs to skip during conversion.
         use_fast_accum : Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
+        scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.

     Returns:
         nn.Module: The modified module with applicable Linear layers converted to Float8.

@@ -237,6 +288,8 @@ def quantize_to_float8(
     """
     return swap_linear_layers(
         module,
-        lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
+        lambda m: Float8InferenceLinear.from_float(
+            m, quant_config, use_fast_accum, scaling_granularity
+        ),
         skip_fqn_list=skip_fqn_list,
     )
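
Putting it together, a hedged usage sketch of the new scaling_granularity argument (the QuantConfig(ActivationCasting.DYNAMIC) construction, the toy model, and the CUDA placement are assumptions on my part; only the quantize_to_float8(..., scaling_granularity=...) signature comes from this diff):

    import torch
    import torch.nn as nn

    from float8_experimental.float8_tensor import ScalingGranularity
    from float8_experimental.inference import (
        ActivationCasting,
        QuantConfig,
        quantize_to_float8,
    )

    # AxisWise (row-wise) scaling is only accepted for bfloat16 modules in this commit.
    model = (
        nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 32))
        .to(torch.bfloat16)
        .cuda()
    )

    quant_config = QuantConfig(ActivationCasting.DYNAMIC)  # assumed constructor
    model = quantize_to_float8(
        model, quant_config, scaling_granularity=ScalingGranularity.AxisWise
    )

    with torch.inference_mode():
        out = model(torch.randn(8, 64, dtype=torch.bfloat16, device="cuda"))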
