Description
When using FSDP2 for Float8 training, an issue occurs when the number of GPUs exceeds the out_features of an nn.Linear layer. In that case FSDP2 shards the weight to shape [0, in_features] on some ranks, and the empty local shard breaks the amax computation during tensorwise FP8 training:

```
RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.
```
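The failure itself is easy to reproduce outside of FSDP; here is a minimal sketch with plain PyTorch (no distributed setup) showing that a full-tensor max over an empty shard raises the same error:

```python
import torch

# A rank's local weight shard when out_features < number of ranks: 0 rows, 16 cols.
x = torch.empty(0, 16)

# A tensor-wise max has no identity value on an empty tensor, so this raises:
# RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. ...
amax = torch.max(torch.abs(x))
```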
To address this, I modified the code as follows:
```python
if x.numel() == 0:
    amax = torch.tensor(0., device=x.device, dtype=x.dtype)
else:
    amax = torch.max(torch.abs(x))
```
However, this introduces another issue:
```
RuntimeError: setStorage: sizes [4, 16], strides [16, 1], storage offset 0, and itemsize 1 requiring a storage size of 64 are out of bounds for storage of size 0
```
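If I read the message correctly, it looks like something is trying to lay out the full unsharded [4, 16] FP8 weight (itemsize 1, so 64 bytes) on top of a rank's empty local storage. A standalone sketch that reproduces the same error class, under that assumption:

```python
import torch

# Hypothetical illustration (not the actual torchao/FSDP2 code path): viewing a
# [4, 16] fp8 tensor (itemsize 1 -> 64 bytes needed) over a 0-byte storage fails
# with the same setStorage error as above.
t = torch.empty(0, dtype=torch.float8_e4m3fn)      # 0-byte storage, like the empty shard
t.set_(t.untyped_storage(), 0, (4, 16), (16, 1))   # view the full [4, 16] weight on it
# RuntimeError: setStorage: sizes [4, 16], strides [16, 1], storage offset 0, and
# itemsize 1 requiring a storage size of 64 are out of bounds for storage of size 0
```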
Here is my complete reproducible code:
```python
import os

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed._composable.fsdp import (
    CPUOffloadPolicy,
    MixedPrecisionPolicy,
    fully_shard,
)
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh


class Float8Handler:
    def __init__(
        self,
        enable_float8_linear=True, enable_fsdp_float8_all_gather=True,
        precompute_float8_dynamic_scale_for_fsdp=True,
        scaling_type_input='dynamic', scaling_type_weight='dynamic', scaling_type_grad_output='dynamic',
        scaling_granularity_input='tensorwise', scaling_granularity_weight='tensorwise', scaling_granularity_grad_output='tensorwise',
        compile=True, pad_inner_dim=False,
    ):
        self.enabled = False
        if not enable_float8_linear:
            return

        try:
            from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType
            from torchao.float8.config import ScalingGranularity
        except ImportError as e:
            raise ImportError(
                "torchao is not installed. Please install it to use float8 linear layers."
            ) from e

        # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
        enable_fsdp_float8_all_gather = enable_fsdp_float8_all_gather
        scaling_type_input = ScalingType(scaling_type_input)
        scaling_type_weight = ScalingType(scaling_type_weight)
        scaling_type_grad_output = ScalingType(scaling_type_grad_output)
        scaling_granularity_input = ScalingGranularity(scaling_granularity_input)
        scaling_granularity_weight = ScalingGranularity(scaling_granularity_weight)
        scaling_granularity_grad_output = ScalingGranularity(scaling_granularity_grad_output)
        self.config = Float8LinearConfig(
            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
            cast_config_input=CastConfig(scaling_type=scaling_type_input, scaling_granularity=scaling_granularity_input),
            cast_config_weight=CastConfig(scaling_type=scaling_type_weight, scaling_granularity=scaling_granularity_weight),
            cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output, scaling_granularity=scaling_granularity_grad_output),
            enable_pre_and_post_forward=False,
            pad_inner_dim=pad_inner_dim,
        )
        self.enabled = True

        # for precompute_float8_dynamic_scale_for_fsdp
        self.precompute_scale = (
            enable_fsdp_float8_all_gather
            and precompute_float8_dynamic_scale_for_fsdp
        )

        # for sync_float8_amax_and_scale_history
        self.delayed_scaling = (
            scaling_type_input == "delayed"
            or scaling_type_weight == "delayed"
            or scaling_type_grad_output == "delayed"
        )
        self._sync_float8_amax_and_scale_history = None
        self.compile = compile

    def convert_to_float8_training(self, model: nn.Module):
        """
        This function converts the linear layers of `model` to `Float8Linear`.
        Note that today, only dynamic tensor scaling (the default) is supported.
        This will mutate the model inplace.
        """
        if not self.enabled:
            return
        from torchao.float8 import convert_to_float8_training
        convert_to_float8_training(
            model,
            config=self.config,
        )


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # out_features=4 is smaller than the number of ranks, so FSDP2's dim-0
        # sharding leaves some ranks with an empty [0, 16] local shard.
        self.fc = nn.Linear(16, 4, bias=True)

    def forward(self, x):
        return self.fc(x)


rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')
world_size = dist.get_world_size()
world_mesh = init_device_mesh('cuda', (world_size,), mesh_dim_names=("world",))['world']

float8_handler = Float8Handler(
    compile=True,
    enable_fsdp_float8_all_gather=True,
    pad_inner_dim=True,
)

model = Model().cuda().to(torch.bfloat16)
float8_handler.convert_to_float8_training(model)

mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16)
fully_shard(
    model,
    mesh=world_mesh,
    mp_policy=mp_policy,
    reshard_after_forward=True,
)
print(model.fc.weight.to_local().shape)  # [0, 16] on ranks whose shard is empty

x = torch.randn(16, 16, requires_grad=True, device='cuda', dtype=torch.bfloat16)
out = model(x)
out.mean().backward()
```
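For what it's worth, one mitigation I'm considering is to simply not convert such small linears to Float8, so their weights stay bf16 and never hit the FP8 cast on an empty shard. This is only a sketch, and it assumes torchao's convert_to_float8_training accepts a module_filter_fn callback; it would replace the float8_handler.convert_to_float8_training(model) call above:

```python
# Sketch of a possible workaround, assuming convert_to_float8_training takes a
# module_filter_fn(module, fqn) -> bool: skip linears whose out_features is
# smaller than the number of shards, keeping them in bf16.
from torchao.float8 import convert_to_float8_training

def keep_small_linears_in_bf16(module: nn.Module, fqn: str) -> bool:
    return not (isinstance(module, nn.Linear) and module.out_features < world_size)

convert_to_float8_training(
    model,
    config=float8_handler.config,
    module_filter_fn=keep_small_linears_in_bf16,
)
```

That sidesteps the empty-shard cast, but it doesn't answer whether FSDP2 + Float8 is expected to support local shards with zero rows.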
Versions: torch 2.6.0+cu126, torchao 0.9.0+cu126
Any suggestions?