
Commit 7d28acf

[float8] Allow specifying arbitrary dtype for each tensor
ghstack-source-id: 4b3a2f0
Pull Request resolved: #1326
1 parent: bc0a29a

11 files changed: +197, -108 lines
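
Summary: this change adds an optional dtype field to CastConfig, so each of the six casts in a float8 linear (input, weight, grad_output, and their counterparts in the two backward matmuls) can target a user-chosen 8-bit float dtype instead of the previously hard-coded e4m3/e5m2 pair; unspecified dtypes keep the old defaults. A minimal usage sketch, assuming the existing torchao.float8 conversion API and that the unspecified *_for_grad_* cast configs inherit their base config as before (the module and shapes here are illustrative):

import torch
import torch.nn as nn

from torchao.float8.config import CastConfig, Float8LinearConfig
from torchao.float8.float8_linear_utils import convert_to_float8_training

# Cast grad_output to e4m3 instead of the default e5m2; input and weight keep
# their e4m3 default because their CastConfig.dtype is left as None.
config = Float8LinearConfig(
    cast_config_grad_output=CastConfig(dtype=torch.float8_e4m3fn),
)

m = nn.Sequential(nn.Linear(64, 64, bias=False))
convert_to_float8_training(m, config=config)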

test/float8/test_base.py (2 additions & 2 deletions)

@@ -25,6 +25,8 @@
 
 from torchao.float8.config import (
     CastConfig,
+    e4m3_dtype,
+    e5m2_dtype,
     Float8LinearConfig,
     Float8LinearRecipeName,
     recipe_name_to_linear_config,
@@ -51,8 +53,6 @@
 )
 from torchao.float8.float8_utils import (
     compute_error,
-    e4m3_dtype,
-    e5m2_dtype,
     fp8_tensor_statistics,
     FP8_TYPES,
     tensor_to_scale,
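
The only change in this test is the import location: e4m3_dtype and e5m2_dtype are now defined in torchao.float8.config (see the config.py diff below) rather than in torchao.float8.float8_utils. Downstream code needs the same one-line update, for example:

# After this commit the dtype aliases come from the config module; they resolve
# to torch.float8_e4m3fn / torch.float8_e5m2, or to the fnuz variants when
# Float8TypeConfig selects those for the current platform.
from torchao.float8.config import e4m3_dtype, e5m2_dtype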

test/float8/test_compile.py (1 addition & 1 deletion)

@@ -21,6 +21,7 @@
 import torch.nn as nn
 from torchao.float8.config import (
     CastConfig,
+    e4m3_dtype,
     Float8LinearConfig,
     ScalingType,
     Float8LinearRecipeName,
@@ -41,7 +42,6 @@
     GemmInputRole,
     ScaledMMConfig,
 )
-from torchao.float8.float8_utils import e4m3_dtype
 from torchao.testing.float8.test_utils import get_test_float8_linear_config
 
 from torch._dynamo.test_case import TestCase as DynamoTestCase

test/float8/test_dtensor.py (4 additions & 4 deletions)

@@ -27,8 +27,8 @@
 from torchao.float8 import Float8LinearConfig
 from torchao.float8.float8_linear_utils import convert_to_float8_training
 
-from torchao.float8.config import CastConfig, ScalingType
-from torchao.float8.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic
+from torchao.float8.config import CastConfig, e4m3_dtype, ScalingType
+from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -40,7 +40,7 @@
     Float8RowwiseParallel,
     PrepareFloat8ModuleInput,
 )
-from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale
+from torchao.float8.float8_utils import tensor_to_scale
 from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.tensor.parallel import parallelize_module
@@ -197,7 +197,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     )
 
     out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8)
-    out = NoopFwToFloat8E5M2BwDynamic.apply(out, LinearMMConfig())
+    out = NoopFwToFloat8BwDynamic.apply(out, LinearMMConfig(), fp8_dtype)
     assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}"
     loss = torch.sum(torch.abs(out - dist_target))
     loss.backward()

torchao/float8/config.py (28 additions & 2 deletions)

@@ -62,6 +62,7 @@ class CastConfig:
     scaling_type: ScalingType = ScalingType.DYNAMIC
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
     static_scale: Optional[torch.Tensor] = None
+    dtype: Optional[torch.dtype] = None
 
     def short_str(self):
         return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}"
@@ -75,6 +76,9 @@ def __post_init__(self):
             assert (
                 self.scaling_type is ScalingType.DYNAMIC
             ), "only dynamic scaling type is supported for axiswise scaling granularity"
+        assert self.dtype is None or (
+            self.dtype.is_floating_point and self.dtype.itemsize == 1
+        ), "must specify a 8-bit floating-point dtype"
 
 
 @dataclass(frozen=True)
@@ -124,6 +128,12 @@ def __post_init__(self):
                 self.e5m2_dtype = torch.float8_e5m2fnuz
 
 
+# User defined type for using the individual F8 type based on config
+type_config = Float8TypeConfig()
+e4m3_dtype = type_config.e4m3_dtype
+e5m2_dtype = type_config.e5m2_dtype
+
+
 @dataclass(frozen=True)
 class Float8GemmConfig:
     """
@@ -279,6 +289,20 @@ def __post_init__(self):
                 is_disabled_1 == is_disabled_2
             ), f"incompatible operand precision for {gemm_name}"
 
+        for cc1, cc2, operand_name, default_dtype in [
+            (cc_i, cc_i_gw, "input", e4m3_dtype),
+            (cc_w, cc_w_gi, "weight", e4m3_dtype),
+            (cc_go, cc_go_gw, "grad_output", e5m2_dtype),
+        ]:
+            # Override the dataclass being frozen
+            if cc1.dtype is None:
+                object.__setattr__(cc1, "dtype", default_dtype)
+            if cc2.dtype is None:
+                object.__setattr__(cc2, "dtype", default_dtype)
+            assert (
+                cc1.dtype == cc2.dtype
+            ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+
         if self.use_fp8_all_gather_only:
             assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"
 
@@ -343,12 +367,14 @@ def recipe_name_to_linear_config(
         cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
 
         # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
-        cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+        cc_go = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+        )
         cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)
 
         # grad_weight_hp = input_t_hp @ grad_output_hp
         cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
-        cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)
+        cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED, dtype=e4m3_dtype)
 
         return Float8LinearConfig(
             cast_config_input=cc_i,
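
To make the new knobs above concrete, here is a small sketch, under the assumption that CastConfig and the module-level dtype aliases behave as shown in the diff; resolve_operand_dtype is a hypothetical helper that mirrors the defaulting loop in Float8LinearConfig.__post_init__, not a torchao API:

import torch

from torchao.float8.config import CastConfig, e4m3_dtype, e5m2_dtype

# CastConfig.__post_init__ only accepts 8-bit floating-point dtypes:
# torch.float8_e4m3fn has itemsize == 1 and is_floating_point == True, so it
# passes; something like torch.bfloat16 would trip the assert.
cc_explicit = CastConfig(dtype=torch.float8_e4m3fn)
cc_default = CastConfig()  # dtype stays None until Float8LinearConfig fills it

# Hypothetical mirror of the per-operand defaulting done in
# Float8LinearConfig.__post_init__ (illustration only, not a torchao API).
def resolve_operand_dtype(cc: CastConfig, default_dtype: torch.dtype) -> torch.dtype:
    # an unset dtype falls back to the operand's historical default
    return cc.dtype if cc.dtype is not None else default_dtype

assert resolve_operand_dtype(cc_default, e4m3_dtype) is e4m3_dtype    # input / weight default
assert resolve_operand_dtype(cc_default, e5m2_dtype) is e5m2_dtype    # grad_output default
assert resolve_operand_dtype(cc_explicit, e5m2_dtype) is torch.float8_e4m3fn  # explicit wins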

torchao/float8/float8_linear.py (44 additions & 39 deletions)

@@ -14,9 +14,9 @@
 
 from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
 from torchao.float8.float8_scaling_utils import (
-    NoopFwToFloat8E5M2BwDelayed,
-    NoopFwToFloat8E5M2BwDynamic,
-    NoopFwToFloat8E5M2BwStatic,
+    NoopFwToFloat8BwDelayed,
+    NoopFwToFloat8BwDynamic,
+    NoopFwToFloat8BwStatic,
     _maybe_initialize_amaxes_scales_for_float8_cast,
     get_maybe_axiswise_dim,
     hp_tensor_to_float8_delayed,
@@ -31,8 +31,6 @@
     hp_tensor_and_scale_to_float8,
 )
 from torchao.float8.float8_utils import (
-    e4m3_dtype,
-    e5m2_dtype,
     tensor_to_amax,
     tensor_to_scale,
 )
@@ -135,7 +133,7 @@ def forward(
         else:
             input_maybe_fp8 = hp_tensor_to_float8_dynamic(
                 input_hp,
-                e4m3_dtype,
+                c.cast_config_input.dtype,
                 linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
                 scaling_granularity=c.cast_config_input.scaling_granularity,
@@ -149,7 +147,7 @@ def forward(
         else:
             weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
                 weight_hp_t,
-                e4m3_dtype,
+                c.cast_config_weight.dtype,
                 linear_mm_config,
                 gemm_input_role=GemmInputRole.WEIGHT,
                 scaling_granularity=c.cast_config_weight.scaling_granularity,
@@ -185,7 +183,7 @@ def backward(ctx, grad_output):
         else:
             grad_output_reshaped_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
                 grad_output_reshaped,
-                e5m2_dtype,
+                c.cast_config_grad_output.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.GRAD_OUTPUT,
                 scaling_granularity=c.cast_config_grad_output.scaling_granularity,
@@ -203,7 +201,7 @@ def backward(ctx, grad_output):
             # the entire tensor.
             weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
                 weight_hp_t,
-                e4m3_dtype,
+                c.cast_config_weight_for_grad_input.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.WEIGHT,
                 scaling_granularity=c.cast_config_weight_for_grad_input.scaling_granularity,
@@ -235,7 +233,7 @@ def backward(ctx, grad_output):
         else:
             grad_output_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
                 grad_output_reshaped,
-                e5m2_dtype,
+                c.cast_config_grad_output_for_grad_weight.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.GRAD_OUTPUT,
                 scaling_granularity=c.cast_config_grad_output_for_grad_weight.scaling_granularity,
@@ -249,7 +247,7 @@ def backward(ctx, grad_output):
         else:
             input_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
                 input_hp_reshaped,
-                e4m3_dtype,
+                c.cast_config_input_for_grad_weight.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
                 scaling_granularity=c.cast_config_input_for_grad_weight.scaling_granularity,
@@ -354,11 +352,9 @@ def create_buffers(self):
         # Default values for history buffers, see above TODO
         history_len = self.config.delayed_scaling_config.history_len
         device = self.weight.device
-        # TODO(future PR): dtype values below don't have the other float8
-        # flavors, fix it
-        default_input = torch.finfo(torch.float8_e4m3fn).max
-        default_weight = torch.finfo(torch.float8_e4m3fn).max
-        default_grad_output = torch.finfo(torch.float8_e5m2).max
+        default_input = torch.finfo(self.config.cast_config_input.dtype).max
+        default_weight = torch.finfo(self.config.cast_config_weight.dtype).max
+        default_grad_output = torch.finfo(self.config.cast_config_grad_output.dtype).max
 
         # Note: for now, create all the buffers if any are needed, to postpone
         # the work to make the scale and amax syncing and history calculation
@@ -445,29 +441,32 @@ def cast_input_to_float8(
                 self.fp8_amax_history_input,
                 self.fp8_scale_input,
                 scale_fn_name,
-                e4m3_dtype,
+                self.config.cast_config_input.dtype,
                 is_amax_initialized,
                 reduce_amax=True,
             )
             input_fp8 = hp_tensor_to_float8_delayed(
                 input,
                 self.fp8_scale_input,
-                e4m3_dtype,
+                self.config.cast_config_input.dtype,
                 self.fp8_amax_input,
                 linear_mm_config=self.linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
             )
         elif self.scaling_type_input is ScalingType.DYNAMIC:
             input_fp8 = hp_tensor_to_float8_dynamic(
                 input,
-                e4m3_dtype,
+                self.config.cast_config_input.dtype,
                 self.linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
             )
         else:
             assert self.scaling_type_input is ScalingType.STATIC
             input_fp8 = hp_tensor_to_float8_static(
-                input, self.fp8_static_scale_input, e4m3_dtype, self.linear_mm_config
+                input,
+                self.fp8_static_scale_input,
+                self.config.cast_config_input.dtype,
+                self.linear_mm_config,
             )
 
         return input_fp8
@@ -483,14 +482,14 @@ def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
                 self.fp8_amax_history_weight,
                 self.fp8_scale_weight,
                 scale_fn_name,
-                e4m3_dtype,
+                self.config.cast_config_weight.dtype,
                 self.is_amax_initialized,
                 reduce_amax=True,
             )
             self.fp8_amax_weight.fill_(tensor_to_amax(weight))
             return self.fp8_scale_weight
         elif self.scaling_type_weight is ScalingType.DYNAMIC:
-            return tensor_to_scale(weight, e4m3_dtype)
+            return tensor_to_scale(weight, self.config.cast_config_weight.dtype)
         else:
             assert self.scaling_type_weight is ScalingType.STATIC
             return self.fp8_static_scale_weight
@@ -506,7 +505,7 @@ def cast_weight_to_float8_t(
         weight_fp8 = hp_tensor_and_scale_to_float8(
             weight,
             weight_scale,
-            e4m3_dtype,
+            self.config.cast_config_weight.dtype,
             self.linear_mm_config,
             gemm_input_role=GemmInputRole.WEIGHT,
         )
@@ -521,23 +520,29 @@ def cast_weight_to_original_t(self, weight: torch.Tensor):
     def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
         if self.scaling_type_grad_output is ScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
-            output = NoopFwToFloat8E5M2BwDelayed.apply(
+            output = NoopFwToFloat8BwDelayed.apply(
                 output,
                 self.fp8_amax_grad_output,
                 self.fp8_amax_history_grad_output,
                 self.fp8_scale_grad_output,
                 scale_fn_name,
                 self.is_amax_initialized,
                 self.linear_mm_config,
+                self.config.cast_config_grad_output.dtype,
             )
         elif self.scaling_type_grad_output is ScalingType.DYNAMIC:
-            output = NoopFwToFloat8E5M2BwDynamic.apply(output, self.linear_mm_config)
+            output = NoopFwToFloat8BwDynamic.apply(
+                output,
+                self.linear_mm_config,
+                self.config.cast_config_grad_output.dtype,
+            )
         else:
             assert self.scaling_type_grad_output is ScalingType.STATIC
-            output = NoopFwToFloat8E5M2BwStatic.apply(
+            output = NoopFwToFloat8BwStatic.apply(
                 output,
                 self.fp8_static_scale_grad_output,
                 self.linear_mm_config,
+                self.config.cast_config_grad_output.dtype,
             )
         return output
 
@@ -563,19 +568,16 @@ def float8_post_forward(self):
         self.amax_and_scale_synced = False
 
     def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor:
-        has_any_axiswise_scaling = (
-            self.config.cast_config_input.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_weight.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_grad_output.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_input_for_grad_weight.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_weight_for_grad_input.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_grad_output_for_grad_weight.scaling_granularity
-            is ScalingGranularity.AXISWISE
+        has_any_axiswise_scaling = any(
+            cc.scaling_granularity is ScalingGranularity.AXISWISE
+            for cc in [
+                self.config.cast_config_input,
+                self.config.cast_config_weight,
+                self.config.cast_config_grad_output,
+                self.config.cast_config_input_for_grad_weight,
+                self.config.cast_config_weight_for_grad_input,
+                self.config.cast_config_grad_output_for_grad_weight,
+            ]
         )
 
         if not has_any_axiswise_scaling:
@@ -698,6 +700,7 @@ def from_float(
                     WeightWithDynamicFloat8CastTensor(
                         new_mod.weight,
                         new_mod.linear_mm_config,
+                        new_mod.config.cast_config_weight.dtype,
                     )
                 )
             elif config.cast_config_weight.scaling_type is ScalingType.DELAYED:
@@ -708,6 +711,7 @@ def from_float(
                         new_mod.fp8_amax_history_weight,
                         new_mod.fp8_scale_weight,
                         new_mod.linear_mm_config,
+                        new_mod.config.cast_config_weight.dtype,
                         new_mod.is_amax_initialized,
                     )
                 )
@@ -718,6 +722,7 @@ def from_float(
                         new_mod.weight,
                         new_mod.fp8_static_scale_weight,
                         new_mod.linear_mm_config,
+                        new_mod.config.cast_config_weight.dtype,
                     )
                 )
 
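
One consequence of the create_buffers change above: the delayed-scaling amax buffers are now seeded from torch.finfo of the configured dtypes rather than from hard-coded e4m3/e5m2 constants, so non-default fp8 flavors get sensible defaults automatically. A quick illustration of the values involved:

import torch

# Maximum representable values used to seed the amax buffers; with the default
# dtypes these match the previously hard-coded constants.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e5m2).max)      # 57344.0
# Other flavors now work without code changes, e.g. the fnuz variant:
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0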
