
Commit e8c5769

[float8] Allow specifying arbitrary dtype for each tensor
ghstack-source-id: d8300e2
ghstack-comment-id: 2517857809
Pull Request resolved: #1378
1 parent: 11b2a2f

11 files changed: +194 −108 lines
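
For context, the capability this commit adds looks roughly like the sketch below: each operand's CastConfig can carry its own float8 dtype instead of the previously hardcoded e4m3 (input/weight) and e5m2 (grad_output). This is a hedged usage sketch, not part of the diff; the CastConfig/Float8LinearConfig field names come from the config.py changes further down, while the convert_to_float8_training call, its config= keyword, and the toy model shapes are assumptions about torchao's existing entry points at this commit.

# Hedged usage sketch: per-tensor float8 dtypes via the new CastConfig.dtype field.
# Assumes torchao at (or near) this commit and a CUDA build with float8 support.
import torch
import torch.nn as nn
from torchao.float8 import Float8LinearConfig
from torchao.float8.config import CastConfig, e4m3_dtype, e5m2_dtype
from torchao.float8.float8_linear_utils import convert_to_float8_training

config = Float8LinearConfig(
    cast_config_input=CastConfig(dtype=e4m3_dtype),
    cast_config_weight=CastConfig(dtype=e4m3_dtype),
    # grad_output no longer has to be e5m2; any 1-byte float dtype passes the new check
    cast_config_grad_output=CastConfig(dtype=e5m2_dtype),
)

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64)).cuda()
convert_to_float8_training(model, config=config)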

test/float8/test_base.py

Lines changed: 2 additions & 2 deletions

@@ -26,6 +26,8 @@
 
 from torchao.float8.config import (
     CastConfig,
+    e4m3_dtype,
+    e5m2_dtype,
     Float8LinearConfig,
     Float8LinearRecipeName,
     ScalingGranularity,
@@ -53,8 +55,6 @@
 from torchao.float8.float8_utils import (
     FP8_TYPES,
     compute_error,
-    e4m3_dtype,
-    e5m2_dtype,
     fp8_tensor_statistics,
     tensor_to_scale,
 )
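
Net effect of this diff (mirrored in the two test files below): the e4m3_dtype / e5m2_dtype aliases now live in torchao.float8.config, re-exported from Float8TypeConfig, instead of torchao.float8.float8_utils. Callers import them as:

# new import location for the dtype aliases (previously torchao.float8.float8_utils)
from torchao.float8.config import e4m3_dtype, e5m2_dtype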

test/float8/test_compile.py

Lines changed: 1 addition & 1 deletion

@@ -27,6 +27,7 @@
 
 from torchao.float8.config import (
     CastConfig,
+    e4m3_dtype,
     Float8LinearConfig,
     Float8LinearRecipeName,
     ScalingType,
@@ -47,7 +48,6 @@
     LinearMMConfig,
     ScaledMMConfig,
 )
-from torchao.float8.float8_utils import e4m3_dtype
 from torchao.testing.float8.test_utils import get_test_float8_linear_config
 
 
test/float8/test_dtensor.py

Lines changed: 4 additions & 4 deletions

@@ -31,9 +31,9 @@
 from tqdm import tqdm
 
 from torchao.float8 import Float8LinearConfig
-from torchao.float8.config import CastConfig, ScalingType
+from torchao.float8.config import CastConfig, e4m3_dtype, ScalingType
 from torchao.float8.float8_linear_utils import convert_to_float8_training
-from torchao.float8.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic
+from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -45,7 +45,7 @@
     Float8RowwiseParallel,
     PrepareFloat8ModuleInput,
 )
-from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale
+from torchao.float8.float8_utils import tensor_to_scale
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from torchao.testing.float8.dtensor_utils import ToyModel
 
@@ -173,7 +173,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     )
 
     out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8)
-    out = NoopFwToFloat8E5M2BwDynamic.apply(out, LinearMMConfig())
+    out = NoopFwToFloat8BwDynamic.apply(out, LinearMMConfig(), fp8_dtype)
     assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}"
     loss = torch.sum(torch.abs(out - dist_target))
     loss.backward()
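
The renamed NoopFwToFloat8BwDynamic now takes the target float8 dtype as an explicit argument instead of baking in e5m2, as the test change above shows. The toy autograd function below is a minimal, self-contained sketch of that pattern (identity forward, gradient cast to a caller-chosen float8 dtype with a dynamic per-tensor scale); the class name and the dequantize-on-return shortcut are illustrative only, not torchao's implementation, which returns a scaled Float8Tensor instead.

# Minimal sketch of the "no-op forward, float8-cast backward" pattern with a
# configurable dtype. Requires a PyTorch build with float8 dtypes (2.1+).
import torch

class NoopFwCastGradToFp8(torch.autograd.Function):  # illustrative name, not torchao's
    @staticmethod
    def forward(ctx, t: torch.Tensor, target_dtype: torch.dtype):
        ctx.target_dtype = target_dtype  # dtype is now a parameter, not hardcoded e5m2
        return t

    @staticmethod
    def backward(ctx, grad):
        finfo = torch.finfo(ctx.target_dtype)
        # dynamic per-tensor scale: map the gradient's amax onto the dtype's max
        scale = finfo.max / grad.abs().max().clamp(min=1e-12)
        grad_fp8 = (grad * scale).clamp(-finfo.max, finfo.max).to(ctx.target_dtype)
        # dequantize so the toy stays drop-in; torchao keeps the scale alongside the data
        return grad_fp8.to(grad.dtype) / scale, None

x = torch.randn(8, 8, requires_grad=True)
y = NoopFwCastGradToFp8.apply(x, torch.float8_e4m3fn)  # forward is an identity
y.sum().backward()
print(x.grad.unique())  # gradients have been round-tripped through e4m3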

torchao/float8/config.py

Lines changed: 28 additions & 2 deletions

@@ -62,6 +62,7 @@ class CastConfig:
     scaling_type: ScalingType = ScalingType.DYNAMIC
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
     static_scale: Optional[torch.Tensor] = None
+    dtype: Optional[torch.dtype] = None
 
     def short_str(self):
         return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}"
@@ -75,6 +76,9 @@ def __post_init__(self):
             assert (
                 self.scaling_type is ScalingType.DYNAMIC
             ), "only dynamic scaling type is supported for axiswise scaling granularity"
+        assert self.dtype is None or (
+            self.dtype.is_floating_point and self.dtype.itemsize == 1
+        ), "must specify a 8-bit floating-point dtype"
 
 
 @dataclass(frozen=True)
@@ -124,6 +128,12 @@ def __post_init__(self):
             self.e5m2_dtype = torch.float8_e5m2fnuz
 
 
+# User defined type for using the individual F8 type based on config
+type_config = Float8TypeConfig()
+e4m3_dtype = type_config.e4m3_dtype
+e5m2_dtype = type_config.e5m2_dtype
+
+
 @dataclass(frozen=True)
 class Float8GemmConfig:
     """
@@ -276,6 +286,20 @@ def __post_init__(self):
                 is_disabled_1 == is_disabled_2
             ), f"incompatible operand precision for {gemm_name}"
 
+        for cc1, cc2, operand_name, default_dtype in [
+            (cc_i, cc_i_gw, "input", e4m3_dtype),
+            (cc_w, cc_w_gi, "weight", e4m3_dtype),
+            (cc_go, cc_go_gw, "grad_output", e5m2_dtype),
+        ]:
+            # Override the dataclass being frozen
+            if cc1.dtype is None:
+                object.__setattr__(cc1, "dtype", default_dtype)
+            if cc2.dtype is None:
+                object.__setattr__(cc2, "dtype", default_dtype)
+            assert (
+                cc1.dtype == cc2.dtype
+            ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+
         if self.use_fp8_all_gather_only:
             assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"
 
@@ -340,12 +364,14 @@ def recipe_name_to_linear_config(
         cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
 
         # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
-        cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+        cc_go = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+        )
         cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)
 
         # grad_weight_hp = input_t_hp @ grad_output_hp
         cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
-        cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)
+        cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED, dtype=e4m3_dtype)
 
         return Float8LinearConfig(
             cast_config_input=cc_i,
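
Two behaviors worth calling out in the config changes above: CastConfig.__post_init__ only admits 1-byte floating-point dtypes, and Float8LinearConfig.__post_init__ fills unspecified dtypes with e4m3 for input/weight and e5m2 for grad_output, then asserts that a given operand uses the same dtype in both matmuls it feeds. A small, hedged check of the dtype predicate (pure torch, no torchao needed; assumes a PyTorch version where torch.dtype exposes itemsize, which the diff itself relies on):

# Which stock torch dtypes satisfy the new `is_floating_point and itemsize == 1`
# predicate used by CastConfig.__post_init__? Only the float8 variants do.
import torch

candidates = [
    torch.float8_e4m3fn, torch.float8_e5m2,        # pass
    torch.float8_e4m3fnuz, torch.float8_e5m2fnuz,  # pass (ROCm/MI300 flavors)
    torch.bfloat16, torch.float16,                 # fail: 2 bytes
    torch.uint8,                                   # fail: not floating point
]
for dt in candidates:
    print(dt, dt.is_floating_point and dt.itemsize == 1)

With the defaults filled in __post_init__, the recipe change above only needs to override the grad_output configs (dtype=e4m3_dtype) to express "everything in e4m3" for that recipe.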

torchao/float8/float8_linear.py

Lines changed: 44 additions & 39 deletions

@@ -15,9 +15,9 @@
 from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
 from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
 from torchao.float8.float8_scaling_utils import (
-    NoopFwToFloat8E5M2BwDelayed,
-    NoopFwToFloat8E5M2BwDynamic,
-    NoopFwToFloat8E5M2BwStatic,
+    NoopFwToFloat8BwDelayed,
+    NoopFwToFloat8BwDynamic,
+    NoopFwToFloat8BwStatic,
     _maybe_initialize_amaxes_scales_for_float8_cast,
     get_maybe_axiswise_dim,
     hp_tensor_to_float8_delayed,
@@ -32,8 +32,6 @@
     hp_tensor_and_scale_to_float8,
 )
 from torchao.float8.float8_utils import (
-    e4m3_dtype,
-    e5m2_dtype,
     tensor_to_amax,
     tensor_to_scale,
 )
@@ -136,7 +134,7 @@ def forward(
         else:
             input_maybe_fp8 = hp_tensor_to_float8_dynamic(
                 input_hp,
-                e4m3_dtype,
+                c.cast_config_input.dtype,
                 linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
                 scaling_granularity=c.cast_config_input.scaling_granularity,
@@ -150,7 +148,7 @@
         else:
             weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
                 weight_hp_t,
-                e4m3_dtype,
+                c.cast_config_weight.dtype,
                 linear_mm_config,
                 gemm_input_role=GemmInputRole.WEIGHT,
                 scaling_granularity=c.cast_config_weight.scaling_granularity,
@@ -186,7 +184,7 @@ def backward(ctx, grad_output):
         else:
             grad_output_reshaped_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
                 grad_output_reshaped,
-                e5m2_dtype,
+                c.cast_config_grad_output.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.GRAD_OUTPUT,
                 scaling_granularity=c.cast_config_grad_output.scaling_granularity,
@@ -204,7 +202,7 @@
             # the entire tensor.
             weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
                 weight_hp_t,
-                e4m3_dtype,
+                c.cast_config_weight_for_grad_input.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.WEIGHT,
                 scaling_granularity=c.cast_config_weight_for_grad_input.scaling_granularity,
@@ -236,7 +234,7 @@
         else:
             grad_output_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
                 grad_output_reshaped,
-                e5m2_dtype,
+                c.cast_config_grad_output_for_grad_weight.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.GRAD_OUTPUT,
                 scaling_granularity=c.cast_config_grad_output_for_grad_weight.scaling_granularity,
@@ -250,7 +248,7 @@
         else:
             input_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
                 input_hp_reshaped,
-                e4m3_dtype,
+                c.cast_config_input_for_grad_weight.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
                 scaling_granularity=c.cast_config_input_for_grad_weight.scaling_granularity,
@@ -347,11 +345,9 @@ def create_buffers(self):
         # Default values for history buffers, see above TODO
         history_len = self.config.delayed_scaling_config.history_len
        device = self.weight.device
-        # TODO(future PR): dtype values below don't have the other float8
-        # flavors, fix it
-        default_input = torch.finfo(torch.float8_e4m3fn).max
-        default_weight = torch.finfo(torch.float8_e4m3fn).max
-        default_grad_output = torch.finfo(torch.float8_e5m2).max
+        default_input = torch.finfo(self.config.cast_config_input.dtype).max
+        default_weight = torch.finfo(self.config.cast_config_weight.dtype).max
+        default_grad_output = torch.finfo(self.config.cast_config_grad_output.dtype).max
 
         # Note: for now, create all the buffers if any are needed, to postpone
         # the work to make the scale and amax syncing and history calculation
@@ -438,29 +434,32 @@ def cast_input_to_float8(
                 self.fp8_amax_history_input,
                 self.fp8_scale_input,
                 scale_fn_name,
-                e4m3_dtype,
+                self.config.cast_config_input.dtype,
                 is_amax_initialized,
                 reduce_amax=True,
             )
             input_fp8 = hp_tensor_to_float8_delayed(
                 input,
                 self.fp8_scale_input,
-                e4m3_dtype,
+                self.config.cast_config_input.dtype,
                 self.fp8_amax_input,
                 linear_mm_config=self.linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
             )
         elif self.scaling_type_input is ScalingType.DYNAMIC:
             input_fp8 = hp_tensor_to_float8_dynamic(
                 input,
-                e4m3_dtype,
+                self.config.cast_config_input.dtype,
                 self.linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
             )
         else:
             assert self.scaling_type_input is ScalingType.STATIC
             input_fp8 = hp_tensor_to_float8_static(
-                input, self.fp8_static_scale_input, e4m3_dtype, self.linear_mm_config
+                input,
+                self.fp8_static_scale_input,
+                self.config.cast_config_input.dtype,
+                self.linear_mm_config,
             )
 
         return input_fp8
@@ -476,14 +475,14 @@ def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
                 self.fp8_amax_history_weight,
                 self.fp8_scale_weight,
                 scale_fn_name,
-                e4m3_dtype,
+                self.config.cast_config_weight.dtype,
                 self.is_amax_initialized,
                 reduce_amax=True,
             )
             self.fp8_amax_weight.fill_(tensor_to_amax(weight))
             return self.fp8_scale_weight
         elif self.scaling_type_weight is ScalingType.DYNAMIC:
-            return tensor_to_scale(weight, e4m3_dtype)
+            return tensor_to_scale(weight, self.config.cast_config_weight.dtype)
         else:
             assert self.scaling_type_weight is ScalingType.STATIC
             return self.fp8_static_scale_weight
@@ -499,7 +498,7 @@ def cast_weight_to_float8_t(
         weight_fp8 = hp_tensor_and_scale_to_float8(
             weight,
             weight_scale,
-            e4m3_dtype,
+            self.config.cast_config_weight.dtype,
             self.linear_mm_config,
             gemm_input_role=GemmInputRole.WEIGHT,
         )
@@ -514,23 +513,29 @@ def cast_weight_to_original_t(self, weight: torch.Tensor):
     def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
         if self.scaling_type_grad_output is ScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
-            output = NoopFwToFloat8E5M2BwDelayed.apply(
+            output = NoopFwToFloat8BwDelayed.apply(
                 output,
                 self.fp8_amax_grad_output,
                 self.fp8_amax_history_grad_output,
                 self.fp8_scale_grad_output,
                 scale_fn_name,
                 self.is_amax_initialized,
                 self.linear_mm_config,
+                self.config.cast_config_grad_output.dtype,
             )
         elif self.scaling_type_grad_output is ScalingType.DYNAMIC:
-            output = NoopFwToFloat8E5M2BwDynamic.apply(output, self.linear_mm_config)
+            output = NoopFwToFloat8BwDynamic.apply(
+                output,
+                self.linear_mm_config,
+                self.config.cast_config_grad_output.dtype,
+            )
         else:
             assert self.scaling_type_grad_output is ScalingType.STATIC
-            output = NoopFwToFloat8E5M2BwStatic.apply(
+            output = NoopFwToFloat8BwStatic.apply(
                 output,
                 self.fp8_static_scale_grad_output,
                 self.linear_mm_config,
+                self.config.cast_config_grad_output.dtype,
             )
         return output
 
@@ -547,19 +552,16 @@ def float8_post_forward(self):
             return
 
     def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor:
-        has_any_axiswise_scaling = (
-            self.config.cast_config_input.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_weight.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_grad_output.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_input_for_grad_weight.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_weight_for_grad_input.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_grad_output_for_grad_weight.scaling_granularity
-            is ScalingGranularity.AXISWISE
+        has_any_axiswise_scaling = any(
+            cc.scaling_granularity is ScalingGranularity.AXISWISE
+            for cc in [
+                self.config.cast_config_input,
+                self.config.cast_config_weight,
+                self.config.cast_config_grad_output,
+                self.config.cast_config_input_for_grad_weight,
+                self.config.cast_config_weight_for_grad_input,
+                self.config.cast_config_grad_output_for_grad_weight,
+            ]
         )
 
         if not has_any_axiswise_scaling:
@@ -682,6 +684,7 @@ def from_float(
                 WeightWithDynamicFloat8CastTensor(
                     new_mod.weight,
                     new_mod.linear_mm_config,
+                    new_mod.config.cast_config_weight.dtype,
                 )
             )
         elif config.cast_config_weight.scaling_type is ScalingType.DELAYED:
@@ -692,6 +695,7 @@ def from_float(
                     new_mod.fp8_amax_history_weight,
                     new_mod.fp8_scale_weight,
                     new_mod.linear_mm_config,
+                    new_mod.config.cast_config_weight.dtype,
                     new_mod.is_amax_initialized,
                 )
             )
@@ -702,6 +706,7 @@ def from_float(
                     new_mod.weight,
                     new_mod.fp8_static_scale_weight,
                     new_mod.linear_mm_config,
+                    new_mod.config.cast_config_weight.dtype,
                 )
            )
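
The create_buffers change above now seeds the delayed-scaling amax buffers with torch.finfo(dtype).max of whichever float8 flavor the config selects, rather than hardcoding e4m3fn/e5m2. For reference, the saturation values involved (plain torch, no assumptions beyond a float8-capable PyTorch build):

# finfo.max for the four float8 dtypes the config can now select
import torch

for dt in (torch.float8_e4m3fn, torch.float8_e5m2,
           torch.float8_e4m3fnuz, torch.float8_e5m2fnuz):
    print(dt, torch.finfo(dt).max)
# float8_e4m3fn -> 448.0, float8_e5m2 -> 57344.0,
# float8_e4m3fnuz -> 240.0, float8_e5m2fnuz -> 57344.0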
