@@ -170,7 +170,6 @@ class Float8LinearConfig:
     #
     # Per-gemm configuration for gemms calculating `output`, `grad_input` and
     # `grad_weight`
-    # TODO(this PR): throw warning if fast_accum False is used with axiswise scaling
     #
     gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True)
     gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig()
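For reference, a minimal sketch of what these defaults resolve to, assuming `Float8GemmConfig().use_fast_accum` defaults to `False` (implied by the field declarations above) and that the classes are importable from `torchao.float8.config`, the file this diff touches:

```python
# Sketch under the assumptions stated above, not part of the diff itself.
from torchao.float8.config import Float8LinearConfig

config = Float8LinearConfig()
# The forward `output` gemm keeps fast accumulation enabled by default,
# while the two backward gemms fall back to the Float8GemmConfig default.
assert config.gemm_config_output.use_fast_accum is True
assert config.gemm_config_grad_input.use_fast_accum is False
assert config.gemm_config_grad_weight.use_fast_accum is False
```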
@@ -317,21 +316,10 @@ def recipe_name_to_linear_config(
         cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
         cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
 
-        # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only
-        # fast with `use_fast_accum=True`. Note that rowwise scaling is more
-        # accurate than tensorwise scaling, so the overall impact on accuracy
-        # of tensorwise vs rowwise taking this flag into account will vary.
-        gc_o = Float8GemmConfig(use_fast_accum=True)
-        gc_gi = Float8GemmConfig(use_fast_accum=True)
-        gc_gw = Float8GemmConfig(use_fast_accum=True)
-
         return Float8LinearConfig(
             cast_config_input=cc_i,
             cast_config_weight=cc_w,
             cast_config_grad_output=cc_go,
-            gemm_config_output=gc_o,
-            gemm_config_grad_input=gc_gi,
-            gemm_config_grad_weight=gc_gw,
         )
 
     elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP:
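With the overrides above deleted, the axiswise recipe now inherits the per-gemm defaults from `Float8LinearConfig` instead of forcing `use_fast_accum=True` on all three gemms. A hedged check, assuming this branch handles `Float8LinearRecipeName.ALL_AXISWISE` (the matched recipe name sits above the hunk shown here):

```python
# Sketch only; the recipe name below is an assumption, since the enum
# value this branch matches is outside the visible hunk.
from torchao.float8.config import (
    Float8LinearRecipeName,
    recipe_name_to_linear_config,
)

config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
print(config.gemm_config_output.use_fast_accum)       # True (class default)
print(config.gemm_config_grad_input.use_fast_accum)   # False
print(config.gemm_config_grad_weight.use_fast_accum)  # False
```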
@@ -359,24 +347,13 @@ def recipe_name_to_linear_config(
         cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
         cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)
 
-        # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only
-        # fast with `use_fast_accum=True`. Note that rowwise scaling is more
-        # accurate than tensorwise scaling, so the overall impact on accuracy
-        # of tensorwise vs rowwise taking this flag into account will vary.
-        gc_o = Float8GemmConfig(use_fast_accum=True)
-        gc_gi = Float8GemmConfig(use_fast_accum=True)
-        gc_gw = Float8GemmConfig(use_fast_accum=True)
-
         return Float8LinearConfig(
             cast_config_input=cc_i,
             cast_config_weight=cc_w,
             cast_config_grad_output=cc_go,
             cast_config_input_for_grad_weight=cc_i_gw,
             cast_config_weight_for_grad_input=cc_w_gi,
             cast_config_grad_output_for_grad_weight=cc_go_gw,
-            gemm_config_output=gc_o,
-            gemm_config_grad_input=gc_gi,
-            gemm_config_grad_weight=gc_gw,
         )
 
     else:
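Callers that still want fast accumulation on every gemm can opt back in explicitly; a sketch built only from the keyword arguments visible in the hunks above:

```python
# Sketch: restores the pre-change behavior (fast accum on all gemms) by hand.
from torchao.float8.config import (
    CastConfig,
    Float8GemmConfig,
    Float8LinearConfig,
    ScalingGranularity,
)

fast = Float8GemmConfig(use_fast_accum=True)
axiswise = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
config = Float8LinearConfig(
    cast_config_input=axiswise,
    cast_config_weight=axiswise,
    cast_config_grad_output=axiswise,
    gemm_config_output=fast,
    gemm_config_grad_input=fast,
    gemm_config_grad_weight=fast,
)
```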