Commit 450f140

Author: Daniel Vega-Myhre
Commit message: address comments
1 parent 2af8c14 commit 450f140

File tree: 3 files changed (+22 additions, -74 deletions)

torchao/prototype/float8nocompile/float8nocompile_linear.py

Lines changed: 9 additions & 47 deletions
@@ -41,20 +41,11 @@ def __init__(self, *args, **kwargs):
         Additional arguments on top of `torch.nn.Linear`'s arguments:
         * `config`: Float8LinearConfig
         """
-
-        # Amax scales should always be kept as float32.
-        self.always_float32_buffers = set()
         config = kwargs.pop("config")
         emulate = config.emulate
         super().__init__(*args, **kwargs)
 
-        # Defines the scaling behavior of input, weight, grad_output
-        self.scaling_type_input = config.cast_config_input.scaling_type
-        self.scaling_type_weight = config.cast_config_weight.scaling_type
-        self.scaling_type_grad_output = config.cast_config_grad_output.scaling_type
-
         self.config = config
-        self.is_amax_initialized = not self.config.enable_amax_init
 
         self.linear_mm_config = LinearMMConfig(
             # output
@@ -81,31 +72,18 @@ def __init__(self, *args, **kwargs):
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        # TODO(danielvegamyhre): modify to support for FSDP once dependencies are implemented
-        output = self.forward_fp8_matmul(input)
-        if self.bias is not None:
-            output = output + self.bias.to(output.dtype)
-        return output
-
-    def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor:
-        # perform hp to fp8 conversions
-        # TODO(danielvegamyhre): replace conversion with triton kernels
-        input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
-        weight_scale = self.get_weight_scale(self.weight)
-        weight_fp8_t = self.cast_weight_to_float8_t(
-            self.weight, self.is_amax_initialized, weight_scale
-        )
+        # TODO(danielvegamyhre): replace conversions with triton kernels
+        # TODO(danielvegamyhre): support for FSDP once dependencies are implemented
+        input_fp8 = self.cast_input_to_float8(input)
+        weight_fp8_t = self.cast_weight_to_float8_t(self.weight)
 
         # compute fp8 matmul
         output = manual_float8_matmul_with_args_in_float8.apply(input_fp8, weight_fp8_t)
 
         # cast grad_output to float8_e5m2 during backward
-        # TODO(danielvegamyhre): replace with triton kernel
         return self.cast_output_to_float8_in_bw(output)
 
-    def cast_input_to_float8(
-        self, input: torch.Tensor, is_amax_initialized: bool
-    ) -> torch.Tensor:
+    def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
         # Duplicate the autocast logic for F.linear, so that the output
         # of our module has the right original precision
         if torch.is_autocast_enabled():
@@ -122,32 +100,21 @@ def cast_input_to_float8(
             gemm_input_role=GemmInputRole.INPUT,
         )
 
-    def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
-        # TODO(danielvegamyhre): replace scale calculation with triton kernel
-        if tensor_already_casted_to_fp8(weight):
-            return None
-        return tensor_to_scale(weight, self.config.cast_config_weight.target_dtype)
-
     def cast_weight_to_float8_t(
         self,
         weight: torch.Tensor,
-        is_amax_initialized: bool,
-        weight_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if tensor_already_casted_to_fp8(weight):
-            return weight.t()
-
         # TODO(danielvegamyhre): replace conversion with triton kernel
-        weight_fp8 = hp_tensor_and_scale_to_float8(
+        weight_fp8 = hp_tensor_to_float8nocompile_dynamic(
             weight,
-            weight_scale,
             self.config.cast_config_weight.target_dtype,
             self.linear_mm_config,
             gemm_input_role=GemmInputRole.WEIGHT,
         )
         return weight_fp8.t()
 
     def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
+        # casts grad_output to float8_e5m2 for backward
         # TODO(danielvegamyhre): replace conversion with triton kernel
         return NoopFwToFloat8BwDynamic.apply(
             output,
@@ -156,20 +123,15 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
         )
 
     @classmethod
-    def from_float(
-        cls,
-        mod,
-        config: Optional[Float8LinearConfig] = None,
-    ):
+    def from_float(cls, mod):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
 
         Args:
             mod (torch.nn.Linear): nn.Linear to convert
             config (Optional[Float8LinearConfig]): configuration for conversion to float8
         """
-        if config is None:
-            config = Float8LinearConfig()
+        config = Float8LinearConfig()
         with torch.device("meta"):
             new_mod = cls(
                 mod.in_features,
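
For readers following the API change above: after this commit, Float8LinearNoCompile.from_float takes only the source module and always builds a default Float8LinearConfig internally. The snippet below is a minimal usage sketch, not part of the commit; it assumes a CUDA device with float8 support and that the prototype import path matches this revision.

import torch
from torch import nn

# assumed import path for this prototype at the time of this commit
from torchao.prototype.float8nocompile.float8nocompile_linear import Float8LinearNoCompile

# start from a regular high-precision linear layer (bias omitted, since the
# simplified forward in this diff does not add a bias term)
lin = nn.Linear(256, 512, bias=False, device="cuda", dtype=torch.bfloat16)

# config can no longer be passed in; the default Float8LinearConfig is used
fp8_lin = Float8LinearNoCompile.from_float(lin)

x = torch.randn(32, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = fp8_lin(x)       # input and weight are dynamically cast to float8 in the forward pass
y.sum().backward()   # grad_output is cast to float8_e5m2 during backward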

torchao/prototype/float8nocompile/float8nocompile_linear_utils.py

Lines changed: 1 addition & 7 deletions
@@ -24,7 +24,6 @@ def convert_to_float8_nocompile_training(
     module: nn.Module,
     *,
     module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
-    config: Float8LinearConfig = None,
 ) -> nn.Module:
     """
     Swaps `torch.nn.Linear` in `module` with `Float8LinearNoCompile`.
@@ -39,12 +38,7 @@ def convert_to_float8_nocompile_training(
     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
-    if config is None:
-        config = Float8LinearConfig()
-    from_float = lambda m: Float8LinearNoCompile.from_float(
-        m,
-        config=config,
-    )
+    from_float = lambda m: Float8LinearNoCompile.from_float(m)
     return swap_linear_layers(
         module,
         from_float,
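
The conversion helper now mirrors that simplification: no config argument, only an optional module_filter_fn. Below is a hedged usage sketch, not from the commit; it assumes a CUDA device with float8 support.

import torch
from torch import nn

from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)

model = nn.Sequential(
    nn.Linear(1024, 4096, bias=False),
    nn.ReLU(),
    nn.Linear(4096, 1024, bias=False),
).to(device="cuda", dtype=torch.bfloat16)

# module_filter_fn receives (module, fully_qualified_name) and returns True for
# layers that should be swapped; here the last projection ("2") is left in bf16
def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
    return fqn != "2"

model = convert_to_float8_nocompile_training(model, module_filter_fn=module_filter_fn)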

torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py

Lines changed: 12 additions & 20 deletions
@@ -15,23 +15,23 @@
 from torchao.float8.config import ScalingGranularity
 from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
 from torchao.float8.float8_tensor import (
+    _ToFloat8ConstrFunc,
     Float8Tensor,
     GemmInputRole,
-    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
 )
 from torchao.float8.float8_utils import tensor_to_scale
 
+# avoid division by zero when calculating scale
+# TODO: align this value with NVIDIA's assumptions (current value is a guess)
+EPS = 1e-12
+
 
 def hp_tensor_to_float8nocompile_dynamic(
     hp_tensor: torch.Tensor,
     float8_dtype: torch.dtype,
     linear_mm_config: LinearMMConfig,
-    reduce_amax: bool = False,
     gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
-    device_mesh=None,
-    scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
-    axiswise_dim: Optional[int] = None,
 ) -> Float8Tensor:
     """
     Given a high precision tensor `hp_tensor`,
@@ -42,28 +42,20 @@ def hp_tensor_to_float8nocompile_dynamic(
         float8_dtype: the float8 dtype to use
         linear_mm_config: Defines the configuration for the scaled_mm for
           the 3 fwd/bwd gemms of linear
-        reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks
         gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
           the 3 fwd/bwd gemms of linear
-        scaling_granularity: Defines the scaling granularity
-        axiswise_dim: if axiswise granularity is used, defines the dim to scale across
     """
     # TODO(danielvegamyhre): replace this torch implementation with custom triton kernel
-    if tensor_already_casted_to_fp8(hp_tensor):
-        return hp_tensor
-    scale = tensor_to_scale(
-        hp_tensor,
-        float8_dtype,
-        reduce_amax,
-        device_mesh,
-        scaling_granularity,
-        axiswise_dim,
-    )
-    return hp_tensor_and_scale_to_float8(
+    # torch.compile and eager show different numerics for 1.0 / float32,
+    # upcast to float64 to ensure same numeric between compile and eager
+    amax = torch.max(torch.abs(hp_tensor)).to(torch.float64)
+    scale = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
+    scale = scale.to(torch.float32)  # scale must be fp32
+    return _ToFloat8ConstrFunc.apply(
         hp_tensor,
         scale,
         float8_dtype,
         linear_mm_config,
         gemm_input_role,
-        axiswise_dim,
+        None,
     )
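
To see what the new scaling path computes, here is a small self-contained sketch, not from the commit, of tensorwise dynamic scaling with the same formula: amax is taken over the whole tensor and upcast to float64, the scale is finfo(float8_dtype).max / clamp(amax, min=EPS) kept in float32, and the tensor is quantized with that scale. It round-trips through float8 with plain torch ops instead of _ToFloat8ConstrFunc/Float8Tensor, so it only illustrates the numerics.

import torch

EPS = 1e-12  # same guard against division by zero as in the commit
float8_dtype = torch.float8_e4m3fn
fp8_max = torch.finfo(float8_dtype).max

x = torch.randn(64, 64, dtype=torch.bfloat16)

# tensorwise dynamic scale, mirroring hp_tensor_to_float8nocompile_dynamic
amax = torch.max(torch.abs(x)).to(torch.float64)
scale = (fp8_max / torch.clamp(amax, min=EPS)).to(torch.float32)

# quantize (with a saturating clamp to the representable range) and dequantize
x_fp8 = (x.float() * scale).clamp(min=-fp8_max, max=fp8_max).to(float8_dtype)
x_roundtrip = x_fp8.float() / scale

print("scale:", scale.item())
print("max abs rounding error:", (x.float() - x_roundtrip).abs().max().item())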
