This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 288b752

Thread through the scaling type argument to float8 constructors
ghstack-source-id: aa6f0c0
Pull Request resolved: #301

1 parent 36405a7 · commit 288b752

10 files changed: +217 −61 lines changed

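For orientation, here is a minimal sketch of the constructor pattern this commit threads through, based on the Float8Tensor.to_float8 signature visible in the diffs below. The scale computation shown is illustrative rather than the library's tensor_to_scale helper.

import torch

from float8_experimental.float8_tensor import (
    Float8Tensor,
    ScaledMMConfig,
    ScalingStrategy,
)

x = torch.randn(16, 32, dtype=torch.bfloat16)

# One scale for the whole tensor, matching the TensorWise strategy below.
scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max().float()

x_fp8 = Float8Tensor.to_float8(
    x,
    scale,
    torch.float8_e4m3fn,
    mm_config=ScaledMMConfig(),
    scaling_strategy=ScalingStrategy.TensorWise,  # the argument added in this commit
)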

.pre-commit-config.yaml

Lines changed: 0 additions & 2 deletions
@@ -10,8 +10,6 @@ repos:
      - id: trailing-whitespace
      - id: check-ast
      - id: check-merge-conflict
-      - id: no-commit-to-branch
-        args: ['--branch=main']
      - id: check-added-large-files
        args: ['--maxkb=500']
      - id: end-of-file-fixer

float8_experimental/__init__.py

Lines changed: 6 additions & 2 deletions
@@ -5,11 +5,15 @@
# LICENSE file in the root directory of this source tree.
# Lets define a few top level things here
from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
+from float8_experimental.float8_tensor import (
+    Float8Tensor,
+    ScaledMMConfig,
+    ScalingStrategy,
+)

# Needed to load Float8Tensor with weights_only = True
from torch.serialization import add_safe_globals

-add_safe_globals([Float8Tensor, ScaledMMConfig])
+add_safe_globals([Float8Tensor, ScaledMMConfig, ScalingStrategy])

__all__ = ["Float8Tensor", "Float8Linear"]
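The add_safe_globals change matters for checkpoints: torch.load with weights_only=True uses a restricted unpickler that only rebuilds allowlisted classes, so ScalingStrategy has to be registered alongside Float8Tensor and ScaledMMConfig. A rough sketch of the round trip (the file name and saved contents are made up):

import torch

import float8_experimental  # importing the package runs the add_safe_globals([...]) call above
from float8_experimental.float8_tensor import ScalingStrategy

# Save something that references the newly allowlisted class...
torch.save({"strategy": ScalingStrategy.TensorWise}, "fp8_metadata.pt")

# ...and load it back through the restricted unpickler.
loaded = torch.load("fp8_metadata.pt", weights_only=True)
assert loaded["strategy"] == ScalingStrategy.TensorWise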

float8_experimental/float8_dynamic_linear.py

Lines changed: 33 additions & 12 deletions
@@ -19,6 +19,7 @@
    Float8Tensor,
    merge_mm_configs,
    ScaledMMConfig,
+    ScalingStrategy,
    tensor_already_casted_to_fp8,
    to_fp8_no_autograd,
)
@@ -36,21 +37,27 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
-        tensor,
+        tensor: torch.Tensor,
        mm_config: ScaledMMConfig,
+        scaling_strategy: ScalingStrategy,
    ):
        ctx.mm_config = mm_config
+        ctx.scaling_strategy = scaling_strategy
        return tensor

    @staticmethod
-    def backward(ctx, gradY):
+    def backward(ctx, gradY: torch.Tensor):
        if tensor_already_casted_to_fp8(gradY):
-            return gradY, None
+            return gradY, None, None
        gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
        fp8_tensor = to_fp8_no_autograd(
-            gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config
+            gradY,
+            gradY_scale,
+            e5m2_dtype,
+            mm_config=ctx.mm_config,
+            scaling_strategy=ctx.scaling_strategy,
        )
-        return fp8_tensor, None
+        return fp8_tensor, None, None


class Float8DynamicLinear(torch.nn.Linear):
@@ -63,13 +70,15 @@ def __init__(self, **super_kwargs):
        super().__init__(**super_kwargs)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        x_fp8 = cast_to_float8_e4m3fn(input, self.forward_config)
+        x_fp8 = cast_to_float8_e4m3fn(input, self.forward_config, self.scaling_strategy)
        if isinstance(self.weight, Float8Tensor):  # cast by FSDP
            w_fp8 = self.weight
        else:
-            w_fp8 = cast_to_float8_e4m3fn(self.weight, self.forward_config)
+            w_fp8 = cast_to_float8_e4m3fn(
+                self.weight, self.forward_config, self.scaling_strategy
+            )
        y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
-        y = cast_to_float8_e5m2_bw(y, self.backward_config)
+        y = cast_to_float8_e5m2_bw(y, self.backward_config, self.scaling_strategy)
        return y

    @classmethod
@@ -101,6 +110,9 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
            fp8_output=False,
            pad_inner_dim=config.pad_inner_dim,
        )
+        # TODO: For now hardcode TensorWise scaling
+        new_mod.scaling_strategy = ScalingStrategy.TensorWise
+
        if config.enable_fsdp_fp8_all_gather:
            new_mod.weight = nn.Parameter(
                WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
@@ -112,18 +124,27 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":


def cast_to_float8_e4m3fn(
-    inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
+    inpt_tensor: torch.Tensor,
+    mm_config: ScaledMMConfig,
+    scaling_strategy: ScalingStrategy,
+    reduce_amax: bool = False,
) -> Float8Tensor:
    if tensor_already_casted_to_fp8(inpt_tensor):
        return inpt_tensor
    scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
-    return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)
+    return Float8Tensor.to_float8(
+        inpt_tensor,
+        scale,
+        e4m3_dtype,
+        mm_config=mm_config,
+        scaling_strategy=scaling_strategy,
+    )


def cast_to_float8_e5m2_bw(
-    gradY: torch.Tensor, mm_config: ScaledMMConfig
+    gradY: torch.Tensor, mm_config: ScaledMMConfig, scaling_strategy: ScalingStrategy
) -> torch.Tensor:
-    return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
+    return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config, scaling_strategy)


# FSDP pads its local tensor on dim-0. The subclass should be preserved such
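The extra None in the backward returns above is plain autograd bookkeeping: torch.autograd.Function.backward must return one gradient slot per argument to forward, so adding scaling_strategy as a forward input adds a (non-differentiable) slot. A self-contained toy illustration, unrelated to float8 itself:

import torch


class PassThroughWithConfig(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, mm_config, scaling_strategy):
        # Stash the non-tensor arguments for the backward pass.
        ctx.mm_config = mm_config
        ctx.scaling_strategy = scaling_strategy
        return tensor

    @staticmethod
    def backward(ctx, grad_out):
        # One slot each for (tensor, mm_config, scaling_strategy);
        # only the tensor input gets a real gradient.
        return grad_out * 2.0, None, None


x = torch.ones(3, requires_grad=True)
y = PassThroughWithConfig.apply(x, "dummy_config", "dummy_strategy")
y.sum().backward()
print(x.grad)  # tensor([2., 2., 2.])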

float8_experimental/float8_linear.py

Lines changed: 15 additions & 2 deletions
@@ -18,6 +18,7 @@
from float8_experimental.float8_tensor import (
    Float8Tensor,
    ScaledMMConfig,
+    ScalingStrategy,
    to_fp8_no_autograd,
)

@@ -75,11 +76,13 @@ def forward(
        scale_fn_name,
        is_amax_initialized,
        mm_config: ScaledMMConfig,
+        scaling_strategy: ScalingStrategy,
    ):
        ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
        ctx.scale_fn_name = scale_fn_name
        ctx.is_amax_initialized = is_amax_initialized
        ctx.mm_config = mm_config
+        ctx.scaling_strategy = scaling_strategy
        return tensor

    @staticmethod
@@ -102,9 +105,13 @@ def backward(ctx, go):
        fp8_amax_dL_dY.fill_(tensor_to_amax(go))

        res = to_fp8_no_autograd(
-            go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
+            go,
+            fp8_scale_dL_dY,
+            e5m2_dtype,
+            mm_config=ctx.mm_config,
+            scaling_strategy=ctx.scaling_strategy,
        )
-        empty_grads = None, None, None, None, None, None
+        empty_grads = None, None, None, None, None, None, None
        return res, *empty_grads


@@ -150,6 +157,9 @@ def __init__(self, *args, **kwargs):
        self.forward_config = ScaledMMConfig()
        self.backward_config = ScaledMMConfig()

+        # Defines the scaling strategy for the forward and backwards pass
+        self.scaling_strategy = ScalingStrategy.TensorWise
+
        # Note: is_amax_initialized is not a buffer to avoid data dependent
        # control flow visible to dynamo
        # TODO(future PR): add serialization for this flag
@@ -288,6 +298,7 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
            scale_fn_name,
            self.is_amax_initialized,
            self.backward_config,
+            self.scaling_strategy,
        )
        return y

@@ -353,4 +364,6 @@ def from_float(cls, mod, emulate: bool = False):
        new_mod.backward_config = ScaledMMConfig(
            emulate, False, False, config.pad_inner_dim
        )
+        # TODO: For now hardcode TensorWise scaling
+        new_mod.scaling_strategy = ScalingStrategy.TensorWise
        return new_mod
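At the module level the strategy is currently pinned, as the TODOs note. A small sketch of what from_float produces after this change, assuming the classmethod signature shown in the hunk above (emulate=True keeps the sketch independent of fp8 hardware):

import torch

from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_tensor import ScalingStrategy

m = torch.nn.Linear(64, 128)
fp8_m = Float8Linear.from_float(m, emulate=True)

# Both Float8Linear.from_float and Float8DynamicLinear.from_float hardcode this for now.
assert fp8_m.scaling_strategy == ScalingStrategy.TensorWise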

float8_experimental/float8_ops.py

Lines changed: 52 additions & 8 deletions
@@ -50,7 +50,11 @@ def decorator(func):
def float8_desugar_op(aten_op, args, kwargs=None):
    new_data = aten_op(args[0]._data, *args[1:], **kwargs)
    return Float8Tensor(
-        new_data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
+        new_data,
+        args[0]._scale,
+        args[0]._orig_dtype,
+        args[0]._mm_config,
+        args[0]._scaling_strategy,
    )


@@ -60,7 +64,11 @@ def float8_split(aten_op, args, kwargs=None):

    def make_float8(data):
        return Float8Tensor(
-            data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
+            data,
+            args[0]._scale,
+            args[0]._orig_dtype,
+            args[0]._mm_config,
+            args[0]._scaling_strategy,
        )

    out = map(make_float8, new_data_tensors)
@@ -75,6 +83,7 @@ def float8_cat(aten_op, args, kwargs=None):
    orig_dtype = chunked_tensors[0]._orig_dtype
    scale = chunked_tensors[0]._scale
    mm_config = chunked_tensors[0]._mm_config
+    scaling_strategy = chunked_tensors[0]._scaling_strategy
    fp8_dtype = chunked_tensors[0]._data.dtype
    chunk_data = []
    for chunk in chunked_tensors:
@@ -93,11 +102,14 @@ def float8_cat(aten_op, args, kwargs=None):
        assert (
            chunk._data.dtype == fp8_dtype
        ), "Expecting all chunks to be of the same dtype as a result of a split"
+        assert (
+            chunk._scaling_strategy is scaling_strategy
+        ), "Expecting all chunks to have thee same scaling strategy as a result of a split"
        chunk_data.append(chunk._data.view(torch.uint8))

    new_data = aten_op(chunk_data, *args[1:], **kwargs)
    new_data = new_data.view(fp8_dtype)
-    return Float8Tensor(new_data, scale, orig_dtype, mm_config)
+    return Float8Tensor(new_data, scale, orig_dtype, mm_config, scaling_strategy)


@implements([aten.sum.dim_IntList])
@@ -162,6 +174,11 @@ def float8_mm(aten_op, args, kwargs=None):
        return torch.ops.aten.mm_float8_emulated(
            a._data, a._scale, b._data, b._scale, output_dtype
        )
+    scaling_strategy = a._scaling_strategy
+    # TODO We can enable this by broadcasting to the more generic form
+    assert (
+        scaling_strategy == b._scaling_strategy
+    ), "Scaling strategy are currently required to be the same"
    tensor_out = addmm_float8_unwrapped(
        a_data,
        a_scale,
@@ -191,6 +208,11 @@ def float8_addmm(aten_op, args, kwargs=None):
    a_mm_config: ScaledMMConfig = a._mm_config
    b_mm_config: ScaledMMConfig = b._mm_config
    mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config)
+    scaling_strategy = a._scaling_strategy
+    # TODO We can enable this by broadcasting to the more generic form
+    assert (
+        scaling_strategy == b._scaling_strategy
+    ), "Scaling strategy are currently required to be the same"
    if mm_config.emulate:
        out = torch.ops.aten.mm_float8_emulated(
            a._data, a._scale, b._data, b._scale, output_dtype
@@ -229,7 +251,11 @@ def autocast_to_copy(aten_op, args, kwargs=None):
        torch.bfloat16,
    }, "Only support floating point conversion for autocast w/ Float8Tensor"
    return Float8Tensor(
-        args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._mm_config
+        args[0]._data,
+        args[0]._scale,
+        kwargs["dtype"],
+        args[0]._mm_config,
+        args[0]._scaling_strategy,
    )


@@ -252,7 +278,11 @@ def allgather_fp8(aten_op, args, kwargs=None):
    fp8_data = fp8_data.contiguous()
    fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
    return Float8Tensor(
-        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
+        fp8_out,
+        fp8_input._scale,
+        fp8_input._orig_dtype,
+        fp8_input._mm_config,
+        fp8_input._scaling_strategy,
    )


@@ -264,7 +294,11 @@ def wait_tensor_fp8(aten_op, args, kwargs=None):
    fp8_data = fp8_input._data
    fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
    return Float8Tensor(
-        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
+        fp8_out,
+        fp8_input._scale,
+        fp8_input._orig_dtype,
+        fp8_input._mm_config,
+        fp8_input._scaling_strategy,
    )


@@ -282,7 +316,11 @@ def index_put_fp8(aten_op, args, kwargs=None):
    fp8_values_data = fp8_values._data
    fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs)
    return Float8Tensor(
-        fp8_out, fp8_self._scale, fp8_self._orig_dtype, fp8_self._mm_config
+        fp8_out,
+        fp8_self._scale,
+        fp8_self._orig_dtype,
+        fp8_self._mm_config,
+        fp8_self._scaling_strategy,
    )


@@ -315,6 +353,12 @@ def copy_fp8(aten_op, args, kwargs=None):
            self._data.dtype == src._data.dtype
        ), "Expecting both Float8Tensors to be of the same dtypet"
        fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs)
-        return Float8Tensor(fp8_out, self._scale, self._orig_dtype, self._mm_config)
+        return Float8Tensor(
+            fp8_out,
+            self._scale,
+            self._orig_dtype,
+            self._mm_config,
+            self._scaling_strategy,
+        )
    else:
        raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor")
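The new asserts in float8_mm and float8_addmm mean both operands of a scaled matmul must carry the same ScalingStrategy. A sketch of the happy path, using the cast_to_float8_e4m3fn signature from the float8_dynamic_linear.py diff above; actually executing the matmul also needs a scaled-mm backend (or an emulating ScaledMMConfig), which this sketch does not set up.

import torch

from float8_experimental.float8_dynamic_linear import cast_to_float8_e4m3fn
from float8_experimental.float8_tensor import ScaledMMConfig, ScalingStrategy

cfg = ScaledMMConfig()
strategy = ScalingStrategy.TensorWise

a = cast_to_float8_e4m3fn(torch.randn(16, 32), cfg, strategy)
b = cast_to_float8_e4m3fn(torch.randn(32, 64), cfg, strategy)

# Dispatches to float8_mm, which now asserts a._scaling_strategy == b._scaling_strategy
# before running the scaled matmul.
out = torch.mm(a, b)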
