
Commit 0d31d5a

lwamdfaa authored and committed

[float8] Allow specifying arbitrary dtype for each tensor (#1378)

1 parent a85c2bb · commit 0d31d5a
11 files changed: +229 −133 lines
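
At a glance, the change adds a target_dtype field to CastConfig so that input, weight, and grad_output can each be cast to their own float8 dtype; when left unset, the defaults remain e4m3 for input/weight and e5m2 for grad_output. Below is a minimal usage sketch, not taken from this commit: it assumes Float8LinearConfig also exposes cast_config_grad_output alongside the cast_config_input argument visible in this diff, and that convert_to_float8_training accepts a config= keyword as in the existing torchao float8 workflow.

    # Hedged sketch: requires a float8-capable GPU; cast_config_grad_output and
    # convert_to_float8_training(..., config=...) are assumed names not shown
    # verbatim in this diff.
    import torch
    import torch.nn as nn

    from torchao.float8 import Float8LinearConfig
    from torchao.float8.config import CastConfig, e4m3_dtype
    from torchao.float8.float8_linear_utils import convert_to_float8_training

    # Cast grad_output to e4m3 instead of its e5m2 default; input and weight
    # keep their e4m3 default because their target_dtype is left as None.
    config = Float8LinearConfig(
        cast_config_grad_output=CastConfig(target_dtype=e4m3_dtype),
    )

    m = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64)).cuda()
    convert_to_float8_training(m, config=config)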

test/float8/test_base.py  (+3 −3)

@@ -30,6 +30,8 @@
     Float8LinearRecipeName,
     ScalingGranularity,
     ScalingType,
+    e4m3_dtype,
+    e5m2_dtype,
     recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear import Float8Linear
@@ -53,8 +55,6 @@
 from torchao.float8.float8_utils import (
     FP8_TYPES,
     compute_error,
-    e4m3_dtype,
-    e5m2_dtype,
     fp8_tensor_statistics,
     tensor_to_scale,
 )
@@ -546,7 +546,7 @@ def test_repr(self):
             config=config,
         )
         s = m.__repr__()
-        assert "i:dyn_ten,w:del_ten,go:dyn_ten" in s
+        assert "i:dyn_ten_e4m3,w:del_ten_e4m3,go:dyn_ten_e5m2" in s

     @unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available")
     def test_inference_mode(self):
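
For reference, the short string checked by test_repr now encodes the target dtype as a suffix. A small sketch of the new format, assuming target_dtype is set explicitly (the None default is only filled in later by Float8LinearConfig.__post_init__):

    from torchao.float8.config import CastConfig, e4m3_dtype, e5m2_dtype

    # Default dynamic tensorwise scaling, now with an explicit float8 dtype.
    print(CastConfig(target_dtype=e4m3_dtype).short_str())  # expected: "dyn_ten_e4m3"
    print(CastConfig(target_dtype=e5m2_dtype).short_str())  # expected: "dyn_ten_e5m2"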

test/float8/test_compile.py  (+1 −1)

@@ -30,6 +30,7 @@
     Float8LinearConfig,
     Float8LinearRecipeName,
     ScalingType,
+    e4m3_dtype,
     recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear import Float8Linear
@@ -47,7 +48,6 @@
     LinearMMConfig,
     ScaledMMConfig,
 )
-from torchao.float8.float8_utils import e4m3_dtype
 from torchao.testing.float8.test_utils import get_test_float8_linear_config


test/float8/test_dtensor.py  (+4 −4)

@@ -31,9 +31,9 @@
 from tqdm import tqdm

 from torchao.float8 import Float8LinearConfig
-from torchao.float8.config import CastConfig, ScalingType
+from torchao.float8.config import CastConfig, ScalingType, e4m3_dtype
 from torchao.float8.float8_linear_utils import convert_to_float8_training
-from torchao.float8.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic
+from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -45,7 +45,7 @@
     Float8RowwiseParallel,
     PrepareFloat8ModuleInput,
 )
-from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale
+from torchao.float8.float8_utils import tensor_to_scale
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from torchao.testing.float8.dtensor_utils import ToyModel

@@ -173,7 +173,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     )

     out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8)
-    out = NoopFwToFloat8E5M2BwDynamic.apply(out, LinearMMConfig())
+    out = NoopFwToFloat8BwDynamic.apply(out, LinearMMConfig(), fp8_dtype)
     assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}"
     loss = torch.sum(torch.abs(out - dist_target))
     loss.backward()
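
The rename from NoopFwToFloat8E5M2BwDynamic to NoopFwToFloat8BwDynamic reflects that the backward cast is no longer hard-coded to e5m2: the target dtype is now the third argument to apply. A rough sketch of the call pattern, mirroring the test above but using a standalone leaf tensor for illustration (assumes a float8-capable CUDA device; this setup is not taken from the test):

    import torch

    from torchao.float8.config import e5m2_dtype
    from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
    from torchao.float8.float8_tensor import LinearMMConfig

    x = torch.randn(16, 16, device="cuda", requires_grad=True)
    # Forward is a pass-through; backward dynamically scales and casts the
    # incoming gradient to the float8 dtype passed as the third argument.
    out = NoopFwToFloat8BwDynamic.apply(x, LinearMMConfig(), e5m2_dtype)
    torch.sum(torch.abs(out)).backward()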

torchao/float8/config.py  (+56 −26)

@@ -53,6 +53,35 @@ def short_str(self):
         return "axs"


+@dataclass
+class Float8TypeConfig:
+    """
+    Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.
+
+    Currently, ROCm only supports fnuz variants.
+    """
+
+    # The preferred e4m3 type.
+    e4m3_dtype = torch.float8_e4m3fn
+
+    # The preferred e5m2 type.
+    e5m2_dtype = torch.float8_e5m2
+
+    def __post_init__(self):
+        if torch.version.hip and torch.cuda.is_available():
+            prop = torch.cuda.get_device_properties(0)
+            MI300_ARCH = ("gfx940", "gfx941", "gfx942")
+            if prop.gcnArchName.split(":")[0] in MI300_ARCH:
+                self.e4m3_dtype = torch.float8_e4m3fnuz
+                self.e5m2_dtype = torch.float8_e5m2fnuz
+
+
+# User defined type for using the individual F8 type based on config
+type_config = Float8TypeConfig()
+e4m3_dtype = type_config.e4m3_dtype
+e5m2_dtype = type_config.e5m2_dtype
+
+
 @dataclass(frozen=True)
 class CastConfig:
     """
@@ -62,9 +91,11 @@ class CastConfig:
     scaling_type: ScalingType = ScalingType.DYNAMIC
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
     static_scale: Optional[torch.Tensor] = None
+    target_dtype: Optional[torch.dtype] = None

     def short_str(self):
-        return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}"
+        dtype = {e4m3_dtype: "e4m3", e5m2_dtype: "e5m2"}[self.target_dtype]
+        return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}_{dtype}"

     def __post_init__(self):
         if self.scaling_type is ScalingType.STATIC:
@@ -75,6 +106,9 @@ def __post_init__(self):
             assert (
                 self.scaling_type is ScalingType.DYNAMIC
             ), "only dynamic scaling type is supported for axiswise scaling granularity"
+        assert self.target_dtype is None or (
+            self.target_dtype.is_floating_point and self.target_dtype.itemsize == 1
+        ), "must specify a 8-bit floating-point dtype"


 @dataclass(frozen=True)
@@ -101,29 +135,6 @@ def __post_init__(self):
         ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."


-@dataclass
-class Float8TypeConfig:
-    """
-    Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.
-
-    Currently, ROCm only supports fnuz variants.
-    """
-
-    # The preferred e4m3 type.
-    e4m3_dtype = torch.float8_e4m3fn
-
-    # The preferred e5m2 type.
-    e5m2_dtype = torch.float8_e5m2
-
-    def __post_init__(self):
-        if torch.version.hip and torch.cuda.is_available():
-            prop = torch.cuda.get_device_properties(0)
-            MI300_ARCH = ("gfx940", "gfx941", "gfx942")
-            if prop.gcnArchName.split(":")[0] in MI300_ARCH:
-                self.e4m3_dtype = torch.float8_e4m3fnuz
-                self.e5m2_dtype = torch.float8_e5m2fnuz
-
-
 @dataclass(frozen=True)
 class Float8GemmConfig:
     """
@@ -276,6 +287,20 @@ def __post_init__(self):
                 is_disabled_1 == is_disabled_2
             ), f"incompatible operand precision for {gemm_name}"

+        for cc1, cc2, operand_name, default_dtype in [
+            (cc_i, cc_i_gw, "input", e4m3_dtype),
+            (cc_w, cc_w_gi, "weight", e4m3_dtype),
+            (cc_go, cc_go_gw, "grad_output", e5m2_dtype),
+        ]:
+            # Override the dataclass being frozen
+            if cc1.target_dtype is None:
+                object.__setattr__(cc1, "target_dtype", default_dtype)
+            if cc2.target_dtype is None:
+                object.__setattr__(cc2, "target_dtype", default_dtype)
+            assert (
+                cc1.target_dtype == cc2.target_dtype
+            ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+
         if self.use_fp8_all_gather_only:
             assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"

@@ -334,18 +359,23 @@ def recipe_name_to_linear_config(
         # * `input`, `weight` and `grad_output` now only need to be scaled
         #   axiswise across a single dim compared to vanilla all-axiswise,
         #   which is more amenable to fast kernels
+        # * the e4m3 dtype is used across the board, including for gradients

         # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
         cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
         cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

         # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
-        cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+        cc_go = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
+        )
         cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)

         # grad_weight_hp = input_t_hp @ grad_output_hp
         cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
-        cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)
+        cc_go_gw = CastConfig(
+            scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype
+        )

         return Float8LinearConfig(
             cast_config_input=cc_i,
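
The new loop in Float8LinearConfig.__post_init__ fills in per-operand defaults (e4m3 for input and weight, e5m2 for grad_output) whenever target_dtype is left as None, and asserts that each operand is cast to the same dtype in both matmuls that consume it. A small sketch of the resulting defaults, assuming the config fields are named cast_config_weight and cast_config_grad_output in line with the cast_config_input argument shown above:

    from torchao.float8 import Float8LinearConfig
    from torchao.float8.config import e4m3_dtype, e5m2_dtype

    config = Float8LinearConfig()  # no target_dtype specified anywhere
    assert config.cast_config_input.target_dtype == e4m3_dtype
    assert config.cast_config_weight.target_dtype == e4m3_dtype       # assumed field name
    assert config.cast_config_grad_output.target_dtype == e5m2_dtype  # assumed field name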
