
Commit a749a5f
Parent: 8e36b11

[float8] Allow specifying arbitrary dtype for each tensor

ghstack-source-id: aa9f551
Pull Request resolved: #1326

11 files changed: +208 −111 lines

test/float8/test_base.py (2 additions, 2 deletions)

@@ -25,6 +25,8 @@

 from torchao.float8.config import (
     CastConfig,
+    e4m3_dtype,
+    e5m2_dtype,
     Float8LinearConfig,
     Float8LinearRecipeName,
     recipe_name_to_linear_config,
@@ -51,8 +53,6 @@
 )
 from torchao.float8.float8_utils import (
     compute_error,
-    e4m3_dtype,
-    e5m2_dtype,
     fp8_tensor_statistics,
     FP8_TYPES,
     tensor_to_scale,
test/float8/test_compile.py (1 addition, 1 deletion)

@@ -21,6 +21,7 @@
 import torch.nn as nn
 from torchao.float8.config import (
     CastConfig,
+    e4m3_dtype,
     Float8LinearConfig,
     ScalingType,
     Float8LinearRecipeName,
@@ -41,7 +42,6 @@
     GemmInputRole,
     ScaledMMConfig,
 )
-from torchao.float8.float8_utils import e4m3_dtype
 from torchao.testing.float8.test_utils import get_test_float8_linear_config

 from torch._dynamo.test_case import TestCase as DynamoTestCase

test/float8/test_dtensor.py (4 additions, 4 deletions)

@@ -27,8 +27,8 @@
 from torchao.float8 import Float8LinearConfig
 from torchao.float8.float8_linear_utils import convert_to_float8_training

-from torchao.float8.config import CastConfig, ScalingType
-from torchao.float8.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic
+from torchao.float8.config import CastConfig, e4m3_dtype, ScalingType
+from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -40,7 +40,7 @@
     Float8RowwiseParallel,
     PrepareFloat8ModuleInput,
 )
-from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale
+from torchao.float8.float8_utils import tensor_to_scale
 from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.tensor.parallel import parallelize_module
@@ -197,7 +197,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     )

     out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8)
-    out = NoopFwToFloat8E5M2BwDynamic.apply(out, LinearMMConfig())
+    out = NoopFwToFloat8BwDynamic.apply(out, LinearMMConfig(), fp8_dtype)
     assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}"
     loss = torch.sum(torch.abs(out - dist_target))
     loss.backward()
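
The last hunk above reflects the renamed autograd function: NoopFwToFloat8BwDynamic now receives the backward-cast dtype as an explicit third argument instead of hard-coding e5m2. A minimal sketch of the new call pattern; the wrapper function name and the choice of e5m2_dtype are illustrative, and the LinearMMConfig import path is assumed to be torchao.float8.float8_tensor:

import torch

from torchao.float8.config import e5m2_dtype
from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
from torchao.float8.float8_tensor import LinearMMConfig  # assumed location


def cast_grad_to_fp8_in_backward(out: torch.Tensor) -> torch.Tensor:
    # Forward pass is a no-op; in the backward pass the incoming gradient is
    # dynamically scaled and cast to the dtype passed as the third argument.
    return NoopFwToFloat8BwDynamic.apply(out, LinearMMConfig(), e5m2_dtype)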

torchao/float8/config.py (44 additions, 10 deletions)

@@ -62,6 +62,7 @@ class CastConfig:
     scaling_type: ScalingType = ScalingType.DYNAMIC
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
     static_scale: Optional[torch.Tensor] = None
+    dtype: torch.dtype = torch.uint8  # dummy dtype to satisfy typing

     def short_str(self):
         return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}"
@@ -75,6 +76,10 @@ def __post_init__(self):
             assert (
                 self.scaling_type is ScalingType.DYNAMIC
             ), "only dynamic scaling type is supported for axiswise scaling granularity"
+        if self.scaling_type is not ScalingType.DISABLED:
+            assert (
+                self.dtype.is_floating_point and self.dtype.itemsize == 1
+            ), "must specify a 8-bit floating-point dtype"


 @dataclass(frozen=True)
@@ -124,6 +129,12 @@ def __post_init__(self):
             self.e5m2_dtype = torch.float8_e5m2fnuz


+# User defined type for using the individual F8 type based on config
+type_config = Float8TypeConfig()
+e4m3_dtype = type_config.e4m3_dtype
+e5m2_dtype = type_config.e5m2_dtype
+
+
 @dataclass(frozen=True)
 class Float8GemmConfig:
     """
@@ -158,13 +169,13 @@ class Float8LinearConfig:
     # 3. the same behavior holds for `cast_config_weight` and `cast_config_grad_output`.
     #
     # `input`
-    cast_config_input: CastConfig = CastConfig()
+    cast_config_input: CastConfig = CastConfig(dtype=e4m3_dtype)
     cast_config_input_for_grad_weight: Optional[CastConfig] = None
     # `weight`
-    cast_config_weight: CastConfig = CastConfig()
+    cast_config_weight: CastConfig = CastConfig(dtype=e4m3_dtype)
     cast_config_weight_for_grad_input: Optional[CastConfig] = None
     # `grad_output`
-    cast_config_grad_output: CastConfig = CastConfig()
+    cast_config_grad_output: CastConfig = CastConfig(dtype=e5m2_dtype)
     cast_config_grad_output_for_grad_weight: Optional[CastConfig] = None

     #
@@ -279,6 +290,15 @@ def __post_init__(self):
                 is_disabled_1 == is_disabled_2
             ), f"incompatible operand precision for {gemm_name}"

+        for cc1, cc2, operand_name in [
+            (cc_i, cc_i_gw, "input"),
+            (cc_w, cc_w_gi, "weight"),
+            (cc_go, cc_go_gw, "grad_output"),
+        ]:
+            assert (
+                cc1.dtype == cc2.dtype
+            ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+
         if self.use_fp8_all_gather_only:
             assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"

@@ -315,9 +335,15 @@ def recipe_name_to_linear_config(

     elif recipe_name is Float8LinearRecipeName.ALL_AXISWISE:
         # dynamic axiswise scaling with the CUTLASS rowwise kernel
-        cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-        cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-        cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+        cc_i = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+        )
+        cc_w = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+        )
+        cc_go = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e5m2_dtype
+        )

         return Float8LinearConfig(
             cast_config_input=cc_i,
@@ -339,12 +365,20 @@ def recipe_name_to_linear_config(
         # which is more amenable to fast kernels

         # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
-        cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-        cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+        cc_i = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+        )
+        cc_w = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+        )

         # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
-        cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-        cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)
+        cc_go = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+        )
+        cc_w_gi = CastConfig(
+            scaling_granularity=ScalingGranularity.TENSORWISE, dtype=e4m3_dtype
+        )

         # grad_weight_hp = input_t_hp @ grad_output_hp
         cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
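
Taken together, the config.py changes above add a dtype field to CastConfig, move the e4m3_dtype / e5m2_dtype aliases into the config module, default the per-operand cast configs to e4m3 for input/weight and e5m2 for grad_output, and assert that each operand is cast to the same dtype in both matmuls it feeds. A minimal sketch of how the new knob might be used, based only on the hunks shown here; the particular dtype choices are illustrative:

from torchao.float8.config import (
    CastConfig,
    Float8LinearConfig,
    e4m3_dtype,
    e5m2_dtype,
)

# Explicitly pick the fp8 dtype for each operand. These mirror the new defaults
# (input/weight in e4m3, grad_output in e5m2); any 8-bit floating-point dtype
# passes the new CastConfig.__post_init__ check.
config = Float8LinearConfig(
    cast_config_input=CastConfig(dtype=e4m3_dtype),
    cast_config_weight=CastConfig(dtype=e4m3_dtype),
    cast_config_grad_output=CastConfig(dtype=e5m2_dtype),
)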
