@@ -53,6 +53,35 @@ def short_str(self):
             return "axs"


+@dataclass
+class Float8TypeConfig:
+    """
+    Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.
+
+    Currently, ROCm only supports fnuz variants.
+    """
+
+    # The preferred e4m3 type.
+    e4m3_dtype = torch.float8_e4m3fn
+
+    # The preferred e5m2 type.
+    e5m2_dtype = torch.float8_e5m2
+
+    def __post_init__(self):
+        if torch.version.hip and torch.cuda.is_available():
+            prop = torch.cuda.get_device_properties(0)
+            MI300_ARCH = ("gfx940", "gfx941", "gfx942")
+            if prop.gcnArchName.split(":")[0] in MI300_ARCH:
+                self.e4m3_dtype = torch.float8_e4m3fnuz
+                self.e5m2_dtype = torch.float8_e5m2fnuz
+
+
+# User defined type for using the individual F8 type based on config
+type_config = Float8TypeConfig()
+e4m3_dtype = type_config.e4m3_dtype
+e5m2_dtype = type_config.e5m2_dtype
+
+
 @dataclass(frozen=True)
 class CastConfig:
     """
@@ -62,9 +91,11 @@ class CastConfig:
     scaling_type: ScalingType = ScalingType.DYNAMIC
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
     static_scale: Optional[torch.Tensor] = None
+    target_dtype: Optional[torch.dtype] = None

     def short_str(self):
-        return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}"
+        dtype = {e4m3_dtype: "e4m3", e5m2_dtype: "e5m2"}[self.target_dtype]
+        return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}_{dtype}"

     def __post_init__(self):
         if self.scaling_type is ScalingType.STATIC:
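With the new field, each `CastConfig` names the float8 dtype it casts to, and `short_str()` now encodes that choice. A sketch under the same assumed import path; note that `short_str()` expects `target_dtype` to already be one of the two module-level dtypes (the config-level defaulting shown further down fills it in when it is left as `None`):

```python
from torchao.float8.config import (  # assumed import path
    CastConfig,
    ScalingGranularity,
    e5m2_dtype,
)

cc = CastConfig(
    scaling_granularity=ScalingGranularity.AXISWISE,
    target_dtype=e5m2_dtype,
)
# Produces something like "dyn_axs_e5m2" (dynamic scaling is the default).
print(cc.short_str())
```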
@@ -75,6 +106,9 @@ def __post_init__(self):
             assert (
                 self.scaling_type is ScalingType.DYNAMIC
             ), "only dynamic scaling type is supported for axiswise scaling granularity"
+        assert self.target_dtype is None or (
+            self.target_dtype.is_floating_point and self.target_dtype.itemsize == 1
+        ), "must specify a 8-bit floating-point dtype"


 @dataclass(frozen=True)
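The new assert only accepts 1-byte floating-point dtypes. The predicate can be checked in isolation with plain PyTorch; `_is_float8` below is an illustrative helper, not part of the diff:

```python
import torch

def _is_float8(dtype: torch.dtype) -> bool:
    # Same condition as the assert above: floating point and exactly one byte wide.
    return dtype.is_floating_point and dtype.itemsize == 1

assert _is_float8(torch.float8_e4m3fn)
assert _is_float8(torch.float8_e5m2fnuz)
assert not _is_float8(torch.float16)  # 2 bytes wide, rejected
assert not _is_float8(torch.int8)     # 1 byte but not floating point, rejected
```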
@@ -101,29 +135,6 @@ def __post_init__(self):
         ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."


-@dataclass
-class Float8TypeConfig:
-    """
-    Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.
-
-    Currently, ROCm only supports fnuz variants.
-    """
-
-    # The preferred e4m3 type.
-    e4m3_dtype = torch.float8_e4m3fn
-
-    # The preferred e5m2 type.
-    e5m2_dtype = torch.float8_e5m2
-
-    def __post_init__(self):
-        if torch.version.hip and torch.cuda.is_available():
-            prop = torch.cuda.get_device_properties(0)
-            MI300_ARCH = ("gfx940", "gfx941", "gfx942")
-            if prop.gcnArchName.split(":")[0] in MI300_ARCH:
-                self.e4m3_dtype = torch.float8_e4m3fnuz
-                self.e5m2_dtype = torch.float8_e5m2fnuz
-
-
 @dataclass(frozen=True)
 class Float8GemmConfig:
     """
@@ -276,6 +287,20 @@ def __post_init__(self):
                 is_disabled_1 == is_disabled_2
             ), f"incompatible operand precision for {gemm_name}"

+        for cc1, cc2, operand_name, default_dtype in [
+            (cc_i, cc_i_gw, "input", e4m3_dtype),
+            (cc_w, cc_w_gi, "weight", e4m3_dtype),
+            (cc_go, cc_go_gw, "grad_output", e5m2_dtype),
+        ]:
+            # Override the dataclass being frozen
+            if cc1.target_dtype is None:
+                object.__setattr__(cc1, "target_dtype", default_dtype)
+            if cc2.target_dtype is None:
+                object.__setattr__(cc2, "target_dtype", default_dtype)
+            assert (
+                cc1.target_dtype == cc2.target_dtype
+            ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+
         if self.use_fp8_all_gather_only:
             assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"

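Because `CastConfig` is a frozen dataclass, the defaulting loop has to go through `object.__setattr__` to write `target_dtype` after construction. A self-contained sketch of that idiom with a made-up `Cfg` class (not the real `CastConfig`):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class Cfg:
    target_dtype: Optional[str] = None

cfg = Cfg()
# cfg.target_dtype = "e4m3"  # would raise dataclasses.FrozenInstanceError
object.__setattr__(cfg, "target_dtype", "e4m3")  # bypasses the frozen __setattr__
assert cfg.target_dtype == "e4m3"
```

The defaults mirror the usual convention in the diff: e4m3 for `input` and `weight`, e5m2 for `grad_output`, with an explicitly set `target_dtype` always taking precedence.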
@@ -334,18 +359,23 @@ def recipe_name_to_linear_config(
         # * `input`, `weight` and `grad_output` now only need to be scaled
         #   axiswise across a single dim compared to vanilla all-axiswise,
         #   which is more amenable to fast kernels
+        # * the e4m3 dtype is used across the board, including for gradients

         # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
         cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
         cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

         # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
-        cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+        cc_go = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
+        )
         cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)

         # grad_weight_hp = input_t_hp @ grad_output_hp
         cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
-        cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)
+        cc_go_gw = CastConfig(
+            scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype
+        )

         return Float8LinearConfig(
             cast_config_input=cc_i,
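The added comment in this recipe notes that e4m3 is now used for gradients as well. The range/precision trade-off behind that choice can be seen directly from the dtype limits; presumably the axiswise scaling keeps gradient values inside the smaller e4m3 range, though that rationale is an inference and is not stated in the diff:

```python
import torch

# e4m3 has more mantissa bits (precision) but far less dynamic range than e5m2.
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
print(torch.finfo(torch.float8_e5m2).max)    # 57344.0
```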