This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Add option for recomputing the casted weight during backwards #186

Open · wants to merge 12 commits into base: main
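Summary of the change: today the casted fp8 weight is saved for the backward pass; this PR adds a recompute_weight_cast flag that instead saves only the original weight parameter and re-casts it to fp8 during backward, trading a small amount of recompute for lower memory. Per the comments in float8_linear.py below, this matters most for FSDP, where keeping the un-sharded casted weight around is expensive. The following is a minimal, self-contained sketch of the pattern, using a simplified emulated cast; the names emulated_cast and RecomputableLinear are illustrative and are not the PR's actual float8_linear op, which also handles scaling, amax bookkeeping, and the e5m2 gradient cast.

import torch


def emulated_cast(w: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real fp8 cast: scale into range, saturate, rescale back.
    # The real op quantizes to float8_e4m3fn; numerics are not the point here.
    max_pos = 448.0  # largest value representable in float8_e4m3fn
    return (w * scale).clamp(-max_pos, max_pos) / scale


class RecomputableLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w, scale, recompute_weight_cast):
        w_cast = emulated_cast(w, scale)
        ctx.recompute_weight_cast = recompute_weight_cast
        if recompute_weight_cast:
            # Save only the original weight parameter; re-cast it in backward.
            ctx.save_for_backward(x, w, scale)
        else:
            # Save the casted weight as well, trading extra memory for no recompute.
            ctx.save_for_backward(x, w_cast, scale)
        return x @ w_cast.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, w_saved, scale = ctx.saved_tensors
        if ctx.recompute_weight_cast:
            w_saved = emulated_cast(w_saved, scale)
        # Treat the cast as identity for gradients, as fp8 training recipes do.
        grad_x = grad_out @ w_saved
        grad_w = grad_out.t() @ x
        return grad_x, grad_w, None, None


A quick check that both modes run:

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
y = RecomputableLinear.apply(x, w, torch.tensor(1.0), True)
y.sum().backward()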
18 changes: 15 additions & 3 deletions benchmarks/bench_linear_float8.py
@@ -56,6 +56,7 @@ class Experiment:
dtype: torch.dtype
compiled: bool = False
float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn
recompute_weight_cast: bool = False

# 3 times since we are calculating forward and backward
@property
@@ -95,9 +96,14 @@ def main(
}
input_bias = False
ref_dtypes = [torch.bfloat16, torch.float16]
recompute_weight_casts = [True, False]
experiment_list: List[Experiment] = []
for idx, (dtype, (name, (K, N))) in enumerate(
tqdm(list(product(ref_dtypes, name_to_shapes_70b.items())))
for idx, (dtype, (name, (K, N)), recompute_weight_cast) in enumerate(
tqdm(
list(
product(ref_dtypes, name_to_shapes_70b.items(), recompute_weight_casts)
)
)
):
if n_limit is not None and idx >= n_limit:
break
@@ -106,7 +112,9 @@ def main(
)

linear_float8 = Float8Linear.from_float(
copy.deepcopy(linear_ref), emulate=False
copy.deepcopy(linear_ref),
emulate=False,
recompute_weight_cast=recompute_weight_cast,
)

bsz, seq_len = 4, 4096
@@ -155,6 +163,7 @@ def wrapper(*args, **kwargs):
float8_time,
dtype,
compile,
recompute_weight_cast=recompute_weight_cast,
)
print(experiment)
print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
@@ -169,6 +178,7 @@ def wrapper(*args, **kwargs):
"ref_dtype",
"compiled",
"fp8_dtype",
"recompute_weight_cast",
"ref_time_sec",
"pt_fp8_time_sec",
"ref_tops_sec",
@@ -187,6 +197,7 @@ def wrapper(*args, **kwargs):
experiment.dtype,
experiment.compiled,
experiment.float_8_dtype,
experiment.recompute_weight_cast,
experiment.ref_time_sec,
experiment.float8_time_sec,
experiment.ref_tops_sec,
@@ -214,6 +225,7 @@ def wrapper(*args, **kwargs):
"shape",
"ref_dtype",
"compiled",
"recompute_weight_cast",
"ref_time_sec",
"pt_fp8_time_sec",
"pt_fp8_speedup",
8 changes: 6 additions & 2 deletions benchmarks/profile_linear_float8.py
@@ -87,7 +87,9 @@ class LinearParams:
torch_compile: Optional[bool] = False


def main(profile_path: Path, compile: bool, linear_type: str):
def main(
profile_path: Path, compile: bool, linear_type: str, recompute_weight_cast: bool
):
profile_path = Path(profile_path)
assert profile_path.is_dir(), f"Path {profile_path} must be a directory"
params = LinearParams(
@@ -110,7 +112,9 @@ def main(profile_path: Path, compile: bool, linear_type: str):
dtype=params.ref_dtype,
)
linear_type = LinearType[linear_type.upper()]
linear_float8 = get_float8_linear(linear_type, linear_ref)
linear_float8 = get_float8_linear(
linear_type, linear_ref, recompute_weight_cast=recompute_weight_cast
)

input_tensor = torch.randn(
params.M, params.K, device="cuda", dtype=params.ref_dtype, requires_grad=True
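With the new parameter, calling the profiling script's main directly would look roughly like this. A sketch only: it assumes the benchmarks directory is importable as a package and that the profile directory already exists; both the path and the "dynamic" choice are illustrative.

from pathlib import Path

from benchmarks.profile_linear_float8 import main

# "dynamic" selects Float8DynamicLinear, "delayed" selects Float8Linear (via LinearType).
main(
    Path("./profiles"),  # must be an existing directory
    compile=True,
    linear_type="dynamic",
    recompute_weight_cast=True,
)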
41 changes: 34 additions & 7 deletions float8_experimental/float8_dynamic_linear.py
@@ -8,9 +8,14 @@
"""

import torch
from float8_experimental.float8_ops import float8_linear

from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
from float8_experimental.float8_utils import (
get_maybe_autocast_inputs,
tensor_to_scale,
to_fp8_saturated,
)


@torch._dynamo.allow_in_graph
@@ -47,14 +52,31 @@ class Float8DynamicLinear(torch.nn.Linear):
"""

def forward(self, x):
x_fp8 = self.cast_to_float8(x)
w_fp8 = self.cast_to_float8(self.weight)

y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)

# Tried to do this with @custom_fwd/bwd but it didn't work
temp_x, temp_weight, temp_bias = get_maybe_autocast_inputs(
x, self.weight, self.bias
)
x_fp8 = self.cast_to_float8(temp_x)
weight_scale = tensor_to_scale(temp_weight, torch.float8_e4m3fn)
y = float8_linear(
x_fp8,
temp_weight,
None, # bias
weight_scale,
None,
self.emulate,
self.recompute_weight_cast,
)
# Cast gradY to float8_e5m2 during backward
y = self.cast_to_float8e5m2_bw(y)

# TODO We should use addmm above but this fails the single fsdp test:
# FAILED: _orig_mod.0.fp8_amax_w, 0.2197265625, 0.21875
# Not immediately clear why the bias being fused in would only affect the numerics
# for the weight....
if temp_bias is not None:
y = y + temp_bias

return y

def cast_to_float8(self, inpt_tensor):
@@ -67,17 +89,22 @@ def cast_to_float8e5m2_bw(self, gradY):
return NoopFwToFloat8E5M2Bw.apply(gradY, self.emulate)

@classmethod
def from_float(cls, mod, emulate: bool = False):
def from_float(
cls, mod, emulate: bool = False, recompute_weight_cast: bool = False
):
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear

Args:
mod (torch.nn.Linear): nn.Linear to convert
emulate (bool): whether to emulate fp8 matmul logic in float32
recompute_weight_cast (bool): whether to recompute the weight cast on every
backwards pass
"""
with torch.device("meta"):
new_mod = cls(mod.in_features, mod.out_features, bias=False)
new_mod.weight = mod.weight
new_mod.bias = mod.bias
new_mod.emulate = emulate
new_mod.recompute_weight_cast = recompute_weight_cast
return new_mod
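Based on the from_float signature above, converting a single layer to dynamic scaling would look roughly like this; a sketch assuming this branch is installed, with illustrative shapes and dtypes.

import torch

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

lin = torch.nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
# recompute_weight_cast=True saves only the original weight for backward and
# re-casts it to fp8 during the backward pass.
fp8_lin = Float8DynamicLinear.from_float(lin, emulate=False, recompute_weight_cast=True)

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = fp8_lin(x)
y.sum().backward()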
62 changes: 48 additions & 14 deletions float8_experimental/float8_linear.py
@@ -19,13 +19,15 @@
import float8_experimental.config as config

import torch
from float8_experimental.float8_ops import float8_linear

from float8_experimental.float8_tensor import Float8Tensor

from float8_experimental.float8_utils import (
amax_history_to_scale,
E4M3_MAX_POS,
E5M2_MAX_POS,
get_maybe_autocast_inputs,
tensor_to_amax,
to_fp8_saturated,
)
@@ -172,6 +174,13 @@ def __init__(self, *args, **kwargs):
# and torch.compile, this option can disable them
self.enable_pre_and_post_forward = config.enable_pre_and_post_forward

# This flag controls what gets saved for the backward pass. The default, False,
# saves the casted weight for backward, which typically increases memory usage
# because both the weight parameter and the casted weight are kept on device.
# If set to True, only the weight parameter is saved, and the weight is re-cast
# to fp8 during the backward pass. For traditional FSDP this should be set to
# True so that the un-sharded weight is not saved for backward.
self.recompute_weight_cast = False

def register_always_float32_buffer(
self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True
) -> None:
@@ -191,14 +200,6 @@ def convert_amax_buffer_to_float32(self):
def cast_x_to_float8(
self, x: torch.Tensor, is_amax_initialized: bool
) -> torch.Tensor:
# Duplicate the autocast logic for F.linear, so that the output
# of our module has the right original precision
if torch.is_autocast_enabled():
# For now, hardcode to GPU's autocast dtype
# if we need CPU support in the future, we can add it
autocast_dtype = torch.get_autocast_gpu_dtype()
x = x.to(autocast_dtype)

scale_fn_name = self.recipe.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
x,
@@ -214,6 +215,20 @@
)
return x_fp8

def _maybe_init_amaxes_scales_weight(
self, w: torch.Tensor, is_amax_initialized: bool
):
scale_fn_name = self.recipe.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
w,
self.fp8_amax_w,
self.fp8_amax_history_w,
self.fp8_scale_w,
scale_fn_name,
torch.float8_e4m3fn,
is_amax_initialized,
)

def cast_w_to_float8(
self, w: torch.Tensor, is_amax_initialized: bool
) -> torch.Tensor:
@@ -281,30 +296,48 @@ class Float8Linear(Float8LinearMixin, torch.nn.Linear):
"""

def forward(self, x):
temp_x, temp_weight, temp_bias = get_maybe_autocast_inputs(
x, self.weight, self.bias
)
self.float8_pre_forward(x)

x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized)
w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
x_fp8 = self.cast_x_to_float8(temp_x, self.is_amax_initialized)
self._maybe_init_amaxes_scales_weight(self.weight, self.is_amax_initialized)

y = torch.matmul(x_fp8, w_fp8.t())
y = float8_linear(
x_fp8,
temp_weight,
None, # bias
self.fp8_scale_w,
self.fp8_amax_w,
self.emulate,
self.recompute_weight_cast,
)

# Cast gradY to float8_e5m2 during backward
y = self.cast_y_to_float8_in_bw(y, self.emulate)

if self.bias is not None:
y = y + self.bias.to(y.dtype)
# TODO We should use addmm above but this fails the single fsdp test:
# FAILED: _orig_mod.0.fp8_amax_w, 0.2197265625, 0.21875
# Not immediately clear why the bias being fused in would only affect the numerics
# for the weight....
if temp_bias is not None:
y = y + temp_bias

self.float8_post_forward()
return y

@classmethod
def from_float(cls, mod, emulate: bool = False):
def from_float(
cls, mod, emulate: bool = False, recompute_weight_cast: bool = False
):
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear

Args:
mod (torch.nn.Linear): nn.Linear to convert
emulate (bool): whether to emulate fp8 matmul logic in float32
recompute_weight_cast (bool): whether to recompute the casted weight for backwards
"""
# TODO Follow up! This is a great idea but we need the mixin base to create real
# Tensors and the Linear base to create empty params
@@ -313,6 +346,7 @@ def from_float(cls, mod, emulate: bool = False):
new_mod.weight = mod.weight
new_mod.bias = mod.bias
new_mod.emulate = emulate
new_mod.recompute_weight_cast = recompute_weight_cast
# I think it's okay to send all params and buffers to device
new_mod.to(mod.weight.device)
return new_mod
28 changes: 19 additions & 9 deletions float8_experimental/float8_linear_utils.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import copy
from enum import auto, Enum
from typing import List, Optional, Type, Union

import torch
import torch.distributed as dist
@@ -23,13 +24,17 @@ class LinearType(Enum):


def get_float8_linear(
linear_type: LinearType, linear_ref: torch.nn.Linear, emulate: bool = False
linear_type: LinearType,
linear_ref: torch.nn.Linear,
emulate: bool = False,
recompute_weight_cast: bool = False,
):
"""Returns a Float8Linear module of the given type, initialized from linear_ref.
Args:
linear_type: The type of Float8Linear to return.
linear_ref: The linear module to initialize from.
emulate: Whether to emulate the fp8 matmul logic in float32.
recompute_weight_cast: Whether to recompute the weight cast in the backwards pass.
"""
LINEAR_TYPE_MAP = {
LinearType.DELAYED: Float8Linear,
@@ -39,7 +44,9 @@
raise ValueError(f"linear_type must be one of {LINEAR_TYPE_MAP.keys()}")

return LINEAR_TYPE_MAP[linear_type].from_float(
copy.deepcopy(linear_ref), emulate=emulate
copy.deepcopy(linear_ref),
emulate=emulate,
recompute_weight_cast=recompute_weight_cast,
)
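For reference, the benchmark and profiling scripts route through get_float8_linear; with the updated signature the call would look roughly like this. A sketch assuming the PR branch; the shape and dtype are illustrative.

import torch

from float8_experimental.float8_linear_utils import LinearType, get_float8_linear

ref = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
fp8 = get_float8_linear(
    LinearType.DELAYED,  # or LinearType.DYNAMIC for Float8DynamicLinear
    ref,
    emulate=False,
    recompute_weight_cast=True,
)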


@@ -59,11 +66,12 @@ def _update_history_with_new_amax(new_amax, amax_history):


def swap_linear_with_float8_linear(
model,
module,
emulate=False,
skip_fqn_list=None,
cur_fqn="",
model: torch.nn.Module,
module: Type[Union[Float8Linear, Float8DynamicLinear]],
emulate: bool = False,
skip_fqn_list: Optional[List[str]] = None,
cur_fqn: str = "",
recompute_weight_cast: bool = False,
):
"""
Replaces all instances of torch.nn.Linear in the given model with module.
@@ -74,17 +82,19 @@ def swap_linear_with_float8_linear(
emulate (bool, optional): Whether to emulate the fp8 matmul logic in float32.
skip_fqn_list (List[str], optional): If specified, a list of FQNs to skip
cur_fqn (str, optional): Current fqn, used to implement skip_fqn_list
recompute_weight_cast (bool, optional): Whether to recompute the weight cast in the backwards pass.
"""
args = (module, emulate, skip_fqn_list, cur_fqn, recompute_weight_cast)
name_to_child = dict(model.named_children())
for name, child in name_to_child.items():
new_fqn = name if cur_fqn == "" else f"{cur_fqn}.{name}"
if ((skip_fqn_list is None) or (new_fqn not in skip_fqn_list)) and isinstance(
child, torch.nn.Linear
):
new_child = module.from_float(child, emulate)
new_child = module.from_float(child, emulate, recompute_weight_cast)
setattr(model, name, new_child)
else:
swap_linear_with_float8_linear(child, module, emulate)
swap_linear_with_float8_linear(child, *args)


def get_float8_layers(model: torch.nn.Module, fp8_classes=None):
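Finally, swap_linear_with_float8_linear threads the new flag through the recursive module swap. A usage sketch assuming the PR branch, with an illustrative toy model:

import torch

from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.ReLU(),
    torch.nn.Linear(4096, 1024),
).to(device="cuda", dtype=torch.bfloat16)

# Swaps every nn.Linear in place for Float8Linear; the casted weight is
# recomputed in backward rather than saved, per the flag added in this PR.
swap_linear_with_float8_linear(model, Float8Linear, recompute_weight_cast=True)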