
Commit 12396c6

[cleanup][3/x] unify dynamic input and grad_output casting (#1480)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent 2ec9bc1 commit 12396c6

File tree

2 files changed: +10 -62 lines changed

benchmarks/float8/profile_linear_float8.py

Lines changed: 1 addition & 1 deletion
@@ -355,7 +355,7 @@ def main(
             1, 2048, 4096, device=device, dtype=ref_dtype
         ).requires_grad_()
     else:
-        M, K, N = 4096, 4096, 4096
+        M, K, N = 2048, 4096, 8192
         m_ref = torch.nn.Sequential(
             torch.nn.Linear(K, N, bias=False),
         )
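
For context on the new default shape: assuming the profiled op here is a single (M, K) x (K, N) GEMM (an assumption about what this part of the benchmark measures), the change keeps the amount of work identical while changing the aspect ratio, since a forward GEMM costs roughly 2 * M * K * N FLOPs and 2 * 2048 * 4096 * 8192 equals 2 * 4096^3. A quick check:

# Rough sanity check of the new default benchmark shape. Assumes the profiled
# op is one (M, K) x (K, N) matmul; 2 * M * K * N is the usual forward GEMM
# FLOP count.
def gemm_fwd_flops(M: int, K: int, N: int) -> int:
    return 2 * M * K * N

old_flops = gemm_fwd_flops(4096, 4096, 4096)  # previous square shape
new_flops = gemm_fwd_flops(2048, 4096, 8192)  # new default shape
assert old_flops == new_flops                 # same work, different aspect ratio
print(f"{old_flops:,} FLOPs per forward GEMM")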

torchao/float8/float8_linear.py

Lines changed: 9 additions & 61 deletions
@@ -15,7 +15,6 @@
 from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
 from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
 from torchao.float8.float8_scaling_utils import (
-    NoopFwToFloat8BwDynamic,
     get_maybe_axiswise_dim,
     hp_tensor_to_float8_dynamic,
 )
@@ -29,33 +28,6 @@
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor


-def _cast_input_to_float8(
-    input: torch.Tensor,
-    scaling_type_input: ScalingType,
-    config: Float8LinearConfig,
-    linear_mm_config: LinearMMConfig,
-) -> torch.Tensor:
-    # Duplicate the autocast logic for F.linear, so that the output
-    # of our module has the right original precision
-    if torch.is_autocast_enabled():
-        # For now, hardcode to GPU's autocast dtype
-        # if we need CPU support in the future, we can add it
-        autocast_dtype = torch.get_autocast_gpu_dtype()
-        input = input.to(autocast_dtype)
-
-    if tensor_already_casted_to_fp8(input):
-        input_fp8 = input
-    else:
-        assert scaling_type_input is ScalingType.DYNAMIC
-        input_fp8 = hp_tensor_to_float8_dynamic(
-            input,
-            config.cast_config_input.target_dtype,
-            linear_mm_config,
-            gemm_input_role=GemmInputRole.INPUT,
-        )
-    return input_fp8
-
-
 def _get_weight_scale(
     weight: torch.Tensor,
     scaling_type_weight: ScalingType,
@@ -85,21 +57,6 @@ def _cast_weight_to_float8_t(
     return weight_fp8.t()


-def _cast_output_to_float8_in_bw(
-    output: torch.Tensor,
-    scaling_type_grad_output,
-    linear_mm_config: LinearMMConfig,
-    config: Float8LinearConfig,
-) -> torch.Tensor:
-    assert scaling_type_grad_output is ScalingType.DYNAMIC
-    output = NoopFwToFloat8BwDynamic.apply(
-        output,
-        linear_mm_config,
-        config.cast_config_grad_output.target_dtype,
-    )
-    return output
-
-
 @torch._dynamo.allow_in_graph
 class matmul_with_hp_or_float8_args(torch.autograd.Function):
     """
@@ -329,6 +286,14 @@ def __init__(self, *args, **kwargs):
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # Duplicate the autocast logic for F.linear, so that the output
+        # of our module has the right original precision
+        if torch.is_autocast_enabled():
+            # For now, hardcode to GPU's autocast dtype
+            # if we need CPU support in the future, we can add it
+            autocast_dtype = torch.get_autocast_gpu_dtype()
+            input = input.to(autocast_dtype)
+
         has_any_axiswise_scaling = any(
             cc.scaling_granularity is ScalingGranularity.AXISWISE
             for cc in [
@@ -341,18 +306,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             ]
         )

-        input_maybe_fp8 = input
         weight_maybe_fp8_t = self.weight.t()

         # TODO(future PR): check for axiswise scaling for input, weight,
         # grad_output separately instead of together
         if not has_any_axiswise_scaling:
-            input_fp8 = _cast_input_to_float8(
-                input,
-                self.scaling_type_input,
-                self.config,
-                self.linear_mm_config,
-            )
             # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
             # weight_scale should be saved.
             weight_scale = _get_weight_scale(
@@ -375,25 +333,15 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
                 weight_scale,
             )

-            input_maybe_fp8 = input_fp8
             weight_maybe_fp8_t = weight_fp8_t

         output = matmul_with_hp_or_float8_args.apply(
-            input_maybe_fp8,
+            input,
             weight_maybe_fp8_t,
             self.linear_mm_config,
             self.config,
         )

-        if not has_any_axiswise_scaling:
-            # Cast grad_output to float8_e5m2 during backward
-            output = _cast_output_to_float8_in_bw(
-                output,
-                self.scaling_type_grad_output,
-                self.linear_mm_config,
-                self.config,
-            )
-
         if self.bias is not None:
             output = output + self.bias.to(output.dtype)
         return output
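
Net effect of this diff: the _cast_input_to_float8 and _cast_output_to_float8_in_bw wrappers (and the NoopFwToFloat8BwDynamic import) disappear from float8_linear.py. Only the autocast handling stays in Float8Linear.forward; the high-precision input is passed straight to matmul_with_hp_or_float8_args, which is now responsible for dynamically casting both the input (forward) and grad_output (backward). The sketch below is a minimal, hypothetical illustration of that pattern, not the torchao implementation: it fake-quantizes through the float8 dtypes (cast down, then back up) so plain torch.mm still works, it skips the amax-based scale computation that real dynamic casting performs, and the name _fp8_matmul_sketch plus the shapes are made up. It requires a PyTorch build with float8 dtypes available.

import torch

class _fp8_matmul_sketch(torch.autograd.Function):
    """Toy autograd.Function that owns both casts: the input in forward and
    grad_output in backward (fake-quantized so plain torch.mm still works)."""

    @staticmethod
    def forward(ctx, input, weight_t):
        # dynamic cast of the activation to e4m3, then back to the original
        # dtype so a plain matmul can consume it in this toy example
        input_fp8 = input.to(torch.float8_e4m3fn).to(input.dtype)
        ctx.save_for_backward(input_fp8, weight_t)
        return torch.mm(input_fp8, weight_t)

    @staticmethod
    def backward(ctx, grad_output):
        input_fp8, weight_t = ctx.saved_tensors
        # dynamic cast of grad_output to e5m2, the dtype typically used for
        # gradients, again round-tripped for the toy matmuls below
        grad_output_fp8 = grad_output.to(torch.float8_e5m2).to(grad_output.dtype)
        grad_input = torch.mm(grad_output_fp8, weight_t.t())
        grad_weight_t = torch.mm(input_fp8.t(), grad_output_fp8)
        return grad_input, grad_weight_t

x = torch.randn(16, 32, requires_grad=True)
w_t = torch.randn(32, 8, requires_grad=True)
_fp8_matmul_sketch.apply(x, w_t).sum().backward()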
