@@ -62,6 +62,7 @@ class CastConfig:
     scaling_type: ScalingType = ScalingType.DYNAMIC
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
     static_scale: Optional[torch.Tensor] = None
+    dtype: torch.dtype = torch.uint8  # dummy dtype to satisfy typing

     def short_str(self):
         return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}"
@@ -75,6 +76,10 @@ def __post_init__(self):
             assert (
                 self.scaling_type is ScalingType.DYNAMIC
             ), "only dynamic scaling type is supported for axiswise scaling granularity"
+        if self.scaling_type is not ScalingType.DISABLED:
+            assert (
+                self.dtype.is_floating_point and self.dtype.itemsize == 1
+            ), "must specify an 8-bit floating-point dtype"


 @dataclass(frozen=True)
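
In effect, any cast config with scaling enabled must now name an 8-bit floating-point dtype; only `DISABLED` configs may keep the `torch.uint8` sentinel. A minimal sketch of the new behavior, assuming the module is importable as `torchao.float8.config` (adjust the path if it differs):

```python
import torch

# Assumed import path for illustration.
from torchao.float8.config import CastConfig, ScalingType

# An 8-bit float dtype satisfies the new __post_init__ assert.
cc = CastConfig(dtype=torch.float8_e4m3fn)

# A 16-bit dtype fails the itemsize == 1 check.
try:
    CastConfig(dtype=torch.bfloat16)
except AssertionError as e:
    print(e)  # must specify an 8-bit floating-point dtype

# The uint8 sentinel default is only legal when scaling is disabled,
# since the dtype check is skipped for ScalingType.DISABLED.
cc_disabled = CastConfig(scaling_type=ScalingType.DISABLED)
```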
@@ -124,6 +129,12 @@ def __post_init__(self):
             self.e5m2_dtype = torch.float8_e5m2fnuz


+# fp8 dtypes for each format, resolved from the user-defined Float8TypeConfig
+type_config = Float8TypeConfig()
+e4m3_dtype = type_config.e4m3_dtype
+e5m2_dtype = type_config.e5m2_dtype
+
+
 @dataclass(frozen=True)
 class Float8GemmConfig:
     """
@@ -158,13 +169,13 @@ class Float8LinearConfig:
     # 3. the same behavior holds for `cast_config_weight` and `cast_config_grad_output`.
     #
     # `input`
-    cast_config_input: CastConfig = CastConfig()
+    cast_config_input: CastConfig = CastConfig(dtype=e4m3_dtype)
     cast_config_input_for_grad_weight: Optional[CastConfig] = None
     # `weight`
-    cast_config_weight: CastConfig = CastConfig()
+    cast_config_weight: CastConfig = CastConfig(dtype=e4m3_dtype)
     cast_config_weight_for_grad_input: Optional[CastConfig] = None
     # `grad_output`
-    cast_config_grad_output: CastConfig = CastConfig()
+    cast_config_grad_output: CastConfig = CastConfig(dtype=e5m2_dtype)
     cast_config_grad_output_for_grad_weight: Optional[CastConfig] = None

     #
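
The defaults preserve the usual float8 training convention: e4m3 for the forward operands, e5m2 for gradients. A hypothetical override casting `grad_output` to e4m3 as well would be a one-liner (import path assumed as above):

```python
# Hypothetical override for illustration: cast grad_output to e4m3 instead
# of the default e5m2.
from torchao.float8.config import CastConfig, Float8LinearConfig, e4m3_dtype

config = Float8LinearConfig(
    cast_config_grad_output=CastConfig(dtype=e4m3_dtype),
)
assert config.cast_config_input.dtype == e4m3_dtype   # unchanged default
assert config.cast_config_grad_output.dtype == e4m3_dtype
```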
@@ -279,6 +290,15 @@ def __post_init__(self):
                 is_disabled_1 == is_disabled_2
             ), f"incompatible operand precision for {gemm_name}"

+        for cc1, cc2, operand_name in [
+            (cc_i, cc_i_gw, "input"),
+            (cc_w, cc_w_gi, "weight"),
+            (cc_go, cc_go_gw, "grad_output"),
+        ]:
+            assert (
+                cc1.dtype == cc2.dtype
+            ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+
         if self.use_fp8_all_gather_only:
             assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"

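
The new loop enforces that each operand (`input`, `weight`, `grad_output`) is cast to one dtype across the two matmuls it participates in. A sketch of the check firing, with the import path assumed as above:

```python
from torchao.float8.config import (
    CastConfig,
    Float8LinearConfig,
    e4m3_dtype,
    e5m2_dtype,
)

# Mismatched dtypes for the same operand across its two gemms now fail fast.
try:
    Float8LinearConfig(
        cast_config_input=CastConfig(dtype=e4m3_dtype),
        cast_config_input_for_grad_weight=CastConfig(dtype=e5m2_dtype),
    )
except AssertionError as e:
    print(e)  # input must be cast to the same dtype in both matmuls it's used in
```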
@@ -315,9 +335,15 @@ def recipe_name_to_linear_config(

     elif recipe_name is Float8LinearRecipeName.ALL_AXISWISE:
         # dynamic axiswise scaling with the CUTLASS rowwise kernel
-        cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-        cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-        cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+        cc_i = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+        )
+        cc_w = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+        )
+        cc_go = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e5m2_dtype
+        )

         return Float8LinearConfig(
             cast_config_input=cc_i,
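
Usage is unchanged on the caller side; the recipe helper simply threads explicit dtypes into each `CastConfig` now. A quick check, with the import path assumed as above:

```python
from torchao.float8.config import (
    Float8LinearRecipeName,
    recipe_name_to_linear_config,
    e4m3_dtype,
    e5m2_dtype,
)

config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
assert config.cast_config_input.dtype == e4m3_dtype
assert config.cast_config_weight.dtype == e4m3_dtype
assert config.cast_config_grad_output.dtype == e5m2_dtype
```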
@@ -339,12 +365,20 @@ def recipe_name_to_linear_config(
         # which is more amenable to fast kernels

         # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
-        cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-        cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+        cc_i = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+        )
+        cc_w = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+        )

         # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
-        cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-        cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)
+        cc_go = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+        )
+        cc_w_gi = CastConfig(
+            scaling_granularity=ScalingGranularity.TENSORWISE, dtype=e4m3_dtype
+        )

         # grad_weight_hp = input_t_hp @ grad_output_hp
         cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
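
Note that this recipe casts `grad_output` to e4m3 rather than the usual e5m2, and leaves the grad_weight matmul in high precision via `DISABLED` cast configs. That interacts cleanly with the new `__post_init__` check: a `DISABLED` config skips the 8-bit dtype assert, so the uint8 sentinel default survives there. A sketch, with the import path assumed as above:

```python
import torch

from torchao.float8.config import CastConfig, ScalingType

# DISABLED skips the 8-bit dtype assert, so the sentinel default is legal;
# this dtype is never used for an actual cast.
cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
assert cc_i_gw.dtype == torch.uint8
```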