
Commit 28949f8

[float8] Allow specifying arbitrary dtype for each tensor

ghstack-source-id: 7dabc91
Pull Request resolved: #1326

1 parent 670b7da, commit 28949f8
9 files changed: +161 / -101 lines

test/float8/test_dtensor.py

Lines changed: 2 additions & 2 deletions

@@ -28,7 +28,7 @@
 from torchao.float8.float8_linear_utils import convert_to_float8_training

 from torchao.float8.config import CastConfig, ScalingType
-from torchao.float8.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic
+from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -197,7 +197,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     )

     out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8)
-    out = NoopFwToFloat8E5M2BwDynamic.apply(out, LinearMMConfig())
+    out = NoopFwToFloat8BwDynamic.apply(out, LinearMMConfig(), fp8_dtype)
     assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}"
     loss = torch.sum(torch.abs(out - dist_target))
     loss.backward()

torchao/float8/config.py

Lines changed: 30 additions & 10 deletions

@@ -62,6 +62,7 @@ class CastConfig:
     scaling_type: ScalingType = ScalingType.DYNAMIC
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
     static_scale: Optional[torch.Tensor] = None
+    dtype: torch.dtype = torch.uint8  # dummy dtype to satisfy typing

     def short_str(self):
         return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}"
@@ -75,6 +76,10 @@ def __post_init__(self):
             assert (
                 self.scaling_type is ScalingType.DYNAMIC
             ), "only dynamic scaling type is supported for axiswise scaling granularity"
+        if self.scaling_type is not ScalingType.DISABLED:
+            assert (
+                self.dtype.is_floating_point and self.dtype.itemsize == 1
+            ), "must specify a 8-bit floating-point dtype"


 @dataclass(frozen=True)
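The new field is validated in __post_init__: whenever the cast is not disabled, the configured dtype must be an 8-bit floating-point type. A minimal illustration of how that check behaves, assuming only the CastConfig API shown above; the dtypes are illustrative:

    import torch
    from torchao.float8.config import CastConfig, ScalingType

    CastConfig(dtype=torch.float8_e4m3fn)          # passes: floating point, itemsize == 1
    CastConfig(dtype=torch.float8_e5m2)            # passes
    CastConfig(scaling_type=ScalingType.DISABLED)  # passes: the check is skipped when disabled
    CastConfig(dtype=torch.bfloat16)               # raises AssertionError: not an 8-bit float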
@@ -124,6 +129,12 @@ def __post_init__(self):
             self.e5m2_dtype = torch.float8_e5m2fnuz


+# User defined type for using the individual F8 type based on config
+type_config = Float8TypeConfig()
+e4m3_dtype = type_config.e4m3_dtype
+e5m2_dtype = type_config.e5m2_dtype
+
+
 @dataclass(frozen=True)
 class Float8GemmConfig:
     """
@@ -158,13 +169,13 @@ class Float8LinearConfig:
     # 3. the same behavior holds for `cast_config_weight` and `cast_config_grad_output`.
     #
     # `input`
-    cast_config_input: CastConfig = CastConfig()
+    cast_config_input: CastConfig = CastConfig(dtype=e4m3_dtype)
     cast_config_input_for_grad_weight: Optional[CastConfig] = None
     # `weight`
-    cast_config_weight: CastConfig = CastConfig()
+    cast_config_weight: CastConfig = CastConfig(dtype=e4m3_dtype)
     cast_config_weight_for_grad_input: Optional[CastConfig] = None
     # `grad_output`
-    cast_config_grad_output: CastConfig = CastConfig()
+    cast_config_grad_output: CastConfig = CastConfig(dtype=e5m2_dtype)
     cast_config_grad_output_for_grad_weight: Optional[CastConfig] = None

     #
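These defaults preserve the previous behavior (e4m3 for input/weight, e5m2 for grad_output), but each operand can now be overridden independently. A hedged sketch of such an override, assuming the fields above; whether a given dtype combination is efficient or supported by the underlying scaled-mm kernels is hardware dependent:

    import torch
    from torchao.float8.config import CastConfig, Float8LinearConfig

    # Illustrative: cast grad_output to e4m3 as well, instead of the default e5m2.
    config = Float8LinearConfig(
        cast_config_input=CastConfig(dtype=torch.float8_e4m3fn),
        cast_config_weight=CastConfig(dtype=torch.float8_e4m3fn),
        cast_config_grad_output=CastConfig(dtype=torch.float8_e4m3fn),
    )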
@@ -279,6 +290,15 @@ def __post_init__(self):
                 is_disabled_1 == is_disabled_2
             ), f"incompatible operand precision for {gemm_name}"

+        for cc1, cc2, operand_name in [
+            (cc_i, cc_i_gw, "input"),
+            (cc_w, cc_w_gi, "weight"),
+            (cc_go, cc_go_gw, "grad_output"),
+        ]:
+            assert (
+                cc1.dtype == cc2.dtype
+            ), f"{operand_name} must be cast to the same dtype in both the matmuls it's used in"
+
         if self.use_fp8_all_gather_only:
             assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"
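An operand that feeds two matmuls (e.g. input feeds both the output and the grad_weight gemm) now has to use one dtype in both places. A sketch of a configuration this new assertion is intended to reject, assuming the *_for_grad_* fields fall back to their main counterparts when left unset, as elsewhere in __post_init__:

    import torch
    from torchao.float8.config import CastConfig, Float8LinearConfig

    # Expected to raise: `input` would be cast to e4m3 for the output matmul
    # but to e5m2 for the grad_weight matmul.
    Float8LinearConfig(
        cast_config_input=CastConfig(dtype=torch.float8_e4m3fn),
        cast_config_input_for_grad_weight=CastConfig(dtype=torch.float8_e5m2),
    )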

@@ -315,9 +335,9 @@ def recipe_name_to_linear_config(

     elif recipe_name is Float8LinearRecipeName.ALL_AXISWISE:
         # dynamic axiswise scaling with the CUTLASS rowwise kernel
-        cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-        cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-        cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+        cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype)
+        cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype)
+        cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE, dtype=e5m2_dtype)

         return Float8LinearConfig(
             cast_config_input=cc_i,
@@ -339,12 +359,12 @@ def recipe_name_to_linear_config(
         # which is more amenable to fast kernels

         # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
-        cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-        cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+        cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype)
+        cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype)

         # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
-        cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-        cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)
+        cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype)
+        cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE, dtype=e4m3_dtype)

         # grad_weight_hp = input_t_hp @ grad_output_hp
         cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
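The prebuilt recipes now pin their dtypes explicitly as well. A short usage sketch, assuming recipe_name_to_linear_config and Float8LinearRecipeName are importable from torchao.float8.config as this hunk suggests:

    from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config

    config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
    # All three operands use axiswise scaling; input/weight are cast to e4m3,
    # grad_output to e5m2, per the CastConfigs constructed above.
    print(config.cast_config_input.dtype, config.cast_config_grad_output.dtype)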

torchao/float8/float8_linear.py

Lines changed: 39 additions & 39 deletions

@@ -14,9 +14,9 @@

 from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
 from torchao.float8.float8_scaling_utils import (
-    NoopFwToFloat8E5M2BwDelayed,
-    NoopFwToFloat8E5M2BwDynamic,
-    NoopFwToFloat8E5M2BwStatic,
+    NoopFwToFloat8BwDelayed,
+    NoopFwToFloat8BwDynamic,
+    NoopFwToFloat8BwStatic,
     _maybe_initialize_amaxes_scales_for_float8_cast,
     get_maybe_axiswise_dim,
     hp_tensor_to_float8_delayed,
@@ -31,8 +31,6 @@
     hp_tensor_and_scale_to_float8,
 )
 from torchao.float8.float8_utils import (
-    e4m3_dtype,
-    e5m2_dtype,
     tensor_to_amax,
     tensor_to_scale,
 )
@@ -135,7 +133,7 @@ def forward(
         else:
             input_maybe_fp8 = hp_tensor_to_float8_dynamic(
                 input_hp,
-                e4m3_dtype,
+                c.cast_config_input.dtype,
                 linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
                 scaling_granularity=c.cast_config_input.scaling_granularity,
@@ -149,7 +147,7 @@ def forward(
         else:
             weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
                 weight_hp_t,
-                e4m3_dtype,
+                c.cast_config_weight.dtype,
                 linear_mm_config,
                 gemm_input_role=GemmInputRole.WEIGHT,
                 scaling_granularity=c.cast_config_weight.scaling_granularity,
@@ -185,7 +183,7 @@ def backward(ctx, grad_output):
         else:
             grad_output_reshaped_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
                 grad_output_reshaped,
-                e5m2_dtype,
+                c.cast_config_grad_output.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.GRAD_OUTPUT,
                 scaling_granularity=c.cast_config_grad_output.scaling_granularity,
@@ -203,7 +201,7 @@ def backward(ctx, grad_output):
             # the entire tensor.
             weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
                 weight_hp_t,
-                e4m3_dtype,
+                c.cast_config_weight_for_grad_input.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.WEIGHT,
                 scaling_granularity=c.cast_config_weight_for_grad_input.scaling_granularity,
@@ -235,7 +233,7 @@ def backward(ctx, grad_output):
         else:
             grad_output_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
                 grad_output_reshaped,
-                e5m2_dtype,
+                c.cast_config_grad_output_for_grad_weight.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.GRAD_OUTPUT,
                 scaling_granularity=c.cast_config_grad_output_for_grad_weight.scaling_granularity,
@@ -249,7 +247,7 @@ def backward(ctx, grad_output):
         else:
             input_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
                 input_hp_reshaped,
-                e4m3_dtype,
+                c.cast_config_input_for_grad_weight.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
                 scaling_granularity=c.cast_config_input_for_grad_weight.scaling_granularity,
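Each call site now reads the target dtype from its CastConfig instead of a hard-coded e4m3/e5m2 constant. As a standalone illustration of what such a dynamic cast does (scale from the tensor's absolute maximum, then saturate to the target format), here is a simplified sketch using only plain torch ops; it is not the torchao implementation, and hp_tensor_to_float8_dynamic additionally threads through scaling granularity and gemm roles that are omitted here:

    import torch

    def simple_dynamic_cast(hp_tensor: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
        """Toy tensorwise dynamic cast: scale by amax, clamp to the format's range, convert."""
        fp8_max = torch.finfo(float8_dtype).max
        amax = hp_tensor.abs().max().clamp(min=1e-12)
        scale = fp8_max / amax
        return (hp_tensor.float() * scale).clamp(-fp8_max, fp8_max).to(float8_dtype)

    x = torch.randn(16, 32)
    x_e4m3 = simple_dynamic_cast(x, torch.float8_e4m3fn)  # the per-tensor dtype is now a parameter
    x_e5m2 = simple_dynamic_cast(x, torch.float8_e5m2)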
@@ -354,11 +352,9 @@ def create_buffers(self):
         # Default values for history buffers, see above TODO
         history_len = self.config.delayed_scaling_config.history_len
         device = self.weight.device
-        # TODO(future PR): dtype values below don't have the other float8
-        # flavors, fix it
-        default_input = torch.finfo(torch.float8_e4m3fn).max
-        default_weight = torch.finfo(torch.float8_e4m3fn).max
-        default_grad_output = torch.finfo(torch.float8_e5m2).max
+        default_input = torch.finfo(config.cast_config_input.dtype).max
+        default_weight = torch.finfo(config.cast_config_weight.dtype).max
+        default_grad_output = torch.finfo(config.cast_config_grad_output.dtype).max

         # Note: for now, create all the buffers if any are needed, to postpone
         # the work to make the scale and amax syncing and history calculation
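The delayed-scaling buffer defaults now take the representable maximum from whatever dtype is configured, rather than assuming e4m3 for input/weight and e5m2 for grad_output. For reference, the maxima differ substantially between the two common formats; a quick check with plain torch:

    import torch

    print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
    print(torch.finfo(torch.float8_e5m2).max)    # 57344.0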
@@ -445,29 +441,32 @@ def cast_input_to_float8(
                 self.fp8_amax_history_input,
                 self.fp8_scale_input,
                 scale_fn_name,
-                e4m3_dtype,
+                self.config.cast_config_input.dtype,
                 is_amax_initialized,
                 reduce_amax=True,
             )
             input_fp8 = hp_tensor_to_float8_delayed(
                 input,
                 self.fp8_scale_input,
-                e4m3_dtype,
+                self.config.cast_config_input.dtype,
                 self.fp8_amax_input,
                 linear_mm_config=self.linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
             )
         elif self.scaling_type_input is ScalingType.DYNAMIC:
             input_fp8 = hp_tensor_to_float8_dynamic(
                 input,
-                e4m3_dtype,
+                self.config.cast_config_input.dtype,
                 self.linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
             )
         else:
             assert self.scaling_type_input is ScalingType.STATIC
             input_fp8 = hp_tensor_to_float8_static(
-                input, self.fp8_static_scale_input, e4m3_dtype, self.linear_mm_config
+                input,
+                self.fp8_static_scale_input,
+                self.config.cast_config_input.dtype,
+                self.linear_mm_config,
             )

         return input_fp8
@@ -483,14 +482,14 @@ def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
                 self.fp8_amax_history_weight,
                 self.fp8_scale_weight,
                 scale_fn_name,
-                e4m3_dtype,
+                self.config.cast_config_weight.dtype,
                 self.is_amax_initialized,
                 reduce_amax=True,
             )
             self.fp8_amax_weight.fill_(tensor_to_amax(weight))
             return self.fp8_scale_weight
         elif self.scaling_type_weight is ScalingType.DYNAMIC:
-            return tensor_to_scale(weight, e4m3_dtype)
+            return tensor_to_scale(weight, self.config.cast_config_weight.dtype)
         else:
             assert self.scaling_type_weight is ScalingType.STATIC
             return self.fp8_static_scale_weight
@@ -506,7 +505,7 @@ def cast_weight_to_float8_t(
         weight_fp8 = hp_tensor_and_scale_to_float8(
             weight,
             weight_scale,
-            e4m3_dtype,
+            self.config.cast_config_weight.dtype,
             self.linear_mm_config,
             gemm_input_role=GemmInputRole.WEIGHT,
         )
@@ -521,23 +520,25 @@ def cast_weight_to_original_t(self, weight: torch.Tensor):
     def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
         if self.scaling_type_grad_output is ScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
-            output = NoopFwToFloat8E5M2BwDelayed.apply(
+            output = NoopFwToFloat8BwDelayed.apply(
                 output,
                 self.fp8_amax_grad_output,
                 self.fp8_amax_history_grad_output,
                 self.fp8_scale_grad_output,
                 scale_fn_name,
                 self.is_amax_initialized,
                 self.linear_mm_config,
+                self.config.cast_config_grad_output.dtype,
             )
         elif self.scaling_type_grad_output is ScalingType.DYNAMIC:
-            output = NoopFwToFloat8E5M2BwDynamic.apply(output, self.linear_mm_config)
+            output = NoopFwToFloat8BwDynamic.apply(output, self.linear_mm_config, self.config.cast_config_grad_output.dtype)
         else:
             assert self.scaling_type_grad_output is ScalingType.STATIC
-            output = NoopFwToFloat8E5M2BwStatic.apply(
+            output = NoopFwToFloat8BwStatic.apply(
                 output,
                 self.fp8_static_scale_grad_output,
                 self.linear_mm_config,
+                self.config.cast_config_grad_output.dtype,
             )
         return output
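The renamed NoopFwToFloat8Bw{Delayed,Dynamic,Static} functions now take the grad_output dtype as an explicit argument instead of baking in e5m2 (hence the "E5M2" dropped from their names). A simplified, standalone sketch of the dynamic variant's shape using only plain torch; the real torchao version wraps the result in a Float8Tensor with a LinearMMConfig and gemm role, which this toy omits:

    import torch

    class ToyNoopFwToFloat8BwDynamic(torch.autograd.Function):
        """Forward is a no-op; backward casts the incoming gradient to `float8_dtype`."""

        @staticmethod
        def forward(ctx, tensor: torch.Tensor, float8_dtype: torch.dtype):
            ctx.float8_dtype = float8_dtype
            return tensor

        @staticmethod
        def backward(ctx, grad_output: torch.Tensor):
            fp8_max = torch.finfo(ctx.float8_dtype).max
            scale = fp8_max / grad_output.abs().max().clamp(min=1e-12)
            grad_fp8 = (grad_output * scale).clamp(-fp8_max, fp8_max).to(ctx.float8_dtype)
            # Return the gradient in high precision so this toy composes with ordinary
            # autograd; the real implementation returns a scaled Float8Tensor instead.
            return grad_fp8.to(grad_output.dtype) / scale, None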

@@ -563,19 +564,15 @@ def float8_post_forward(self):
         self.amax_and_scale_synced = False

     def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor:
-        has_any_axiswise_scaling = (
-            self.config.cast_config_input.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_weight.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_grad_output.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_input_for_grad_weight.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_weight_for_grad_input.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_grad_output_for_grad_weight.scaling_granularity
-            is ScalingGranularity.AXISWISE
+        has_any_axiswise_scaling = any(
+            cc.scaling_granularity is ScalingGranularity.AXISWISE for cc in [
+                self.config.cast_config_input,
+                self.config.cast_config_weight,
+                self.config.cast_config_grad_output,
+                self.config.cast_config_input_for_grad_weight,
+                self.config.cast_config_weight_for_grad_input,
+                self.config.cast_config_grad_output_for_grad_weight,
+            ]
         )

         if not has_any_axiswise_scaling:
@@ -698,6 +695,7 @@ def from_float(
                 WeightWithDynamicFloat8CastTensor(
                     new_mod.weight,
                     new_mod.linear_mm_config,
+                    new_mod.config.cast_config_weight.dtype,
                 )
             )
         elif config.cast_config_weight.scaling_type is ScalingType.DELAYED:
@@ -708,6 +706,7 @@ def from_float(
                     new_mod.fp8_amax_history_weight,
                     new_mod.fp8_scale_weight,
                     new_mod.linear_mm_config,
+                    new_mod.config.cast_config_weight.dtype,
                     new_mod.is_amax_initialized,
                 )
             )
@@ -718,6 +717,7 @@ def from_float(
                     new_mod.weight,
                     new_mod.fp8_static_scale_weight,
                     new_mod.linear_mm_config,
+                    new_mod.config.cast_config_weight.dtype,
                 )
             )
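Putting it together, the per-tensor dtypes flow from the config through module conversion. A hedged end-to-end sketch, assuming convert_to_float8_training accepts a config keyword (it is imported in the test file above) and hardware that supports float8 matmuls:

    import torch
    from torchao.float8.config import Float8LinearConfig
    from torchao.float8.float8_linear_utils import convert_to_float8_training

    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 1024)).cuda()

    # The default config reproduces the old behavior (e4m3 forward, e5m2 for grad_output);
    # per-tensor overrides go through the CastConfig.dtype field shown in config.py above.
    convert_to_float8_training(model, config=Float8LinearConfig())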
