
[Bug] FSDP2 FP8 compatibility problem with nn.Linear layers (GPU count > out_features) #1938


Description

@HIT-cwh

When using FSDP2 for Float8 training, an issue occurs when the number of GPUs exceeds the out_features of an nn.Linear layer. Specifically, FSDP2 shards the weight tensor to shape [0, in_features] on some ranks, which makes the tensor-wise amax computation fail during FP8 training:

RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.
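For clarity, here is a minimal sketch of what I believe happens on the ranks that receive an empty shard (plain PyTorch, no torchao; the [0, 16] shape comes from my nn.Linear(16, 4) example sharded over more than 4 GPUs):

import torch

# local weight shard on a rank that received no rows after dim-0 sharding
local_weight = torch.empty(0, 16, dtype=torch.bfloat16)

# tensor-wise scaling takes a global max over the whole tensor, which fails when numel() == 0
amax = torch.max(torch.abs(local_weight))
# RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. ...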


To address this, I modified the code as follows:

if x.numel() == 0:
    amax = torch.tensor(0., device=x.device, dtype=x.dtype)
else:
    amax = torch.max(torch.abs(x))

However, this introduces another issue:

RuntimeError: setStorage: sizes [4, 16], strides [16, 1], storage offset 0, and itemsize 1 requiring a storage size of 64 are out of bounds for storage of size 0

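My reading of this second error (only a guess) is that, after the amax guard, something downstream still tries to view the rank's empty FP8 storage as the full [4, 16] weight. The same message can be produced in isolation like this (uint8 stands in for the 1-byte FP8 dtype; this is just an illustration, not the actual torchao code path):

import torch

empty_storage = torch.empty(0, dtype=torch.uint8)  # itemsize 1, storage size 0
empty_storage.as_strided((4, 16), (16, 1))
# RuntimeError: setStorage: sizes [4, 16], strides [16, 1], storage offset 0, and itemsize 1
# requiring a storage size of 64 are out of bounds for storage of size 0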

Here is my complete reproducible code:

import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed._composable.fsdp import (
    CPUOffloadPolicy,
    MixedPrecisionPolicy,
    fully_shard,
)
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

class Float8Handler:
    def __init__(self, 
        enable_float8_linear=True, enable_fsdp_float8_all_gather=True,
        precompute_float8_dynamic_scale_for_fsdp=True,
        scaling_type_input='dynamic', scaling_type_weight='dynamic', scaling_type_grad_output='dynamic',
        scaling_granularity_input='tensorwise', scaling_granularity_weight='tensorwise', scaling_granularity_grad_output='tensorwise',
        compile=True, pad_inner_dim=False,
    ):
        self.enabled = False

        if not enable_float8_linear:
            return

        try:
            from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType
            from torchao.float8.config import ScalingGranularity
        except ImportError as e:
            raise ImportError(
                "torchao is not installed. Please install it to use float8 linear layers."
            ) from e

        scaling_type_input = ScalingType(scaling_type_input)
        scaling_type_weight = ScalingType(scaling_type_weight)
        scaling_type_grad_output = ScalingType(scaling_type_grad_output)
        scaling_granularity_input = ScalingGranularity(scaling_granularity_input)
        scaling_granularity_weight = ScalingGranularity(scaling_granularity_weight)
        scaling_granularity_grad_output = ScalingGranularity(scaling_granularity_grad_output)
        self.config = Float8LinearConfig(
            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
            cast_config_input=CastConfig(scaling_type=scaling_type_input, scaling_granularity=scaling_granularity_input),
            cast_config_weight=CastConfig(scaling_type=scaling_type_weight, scaling_granularity=scaling_granularity_weight),
            cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output, scaling_granularity=scaling_granularity_grad_output),
            enable_pre_and_post_forward=False,
            pad_inner_dim=pad_inner_dim
        )

        self.enabled = True

        # for precompute_float8_dynamic_scale_for_fsdp
        self.precompute_scale = (
            enable_fsdp_float8_all_gather
            and precompute_float8_dynamic_scale_for_fsdp
        )

        # for sync_float8_amax_and_scale_history; compare the underlying string values
        # because the scaling_type_* variables were converted to ScalingType enums above
        self.delayed_scaling = (
            scaling_type_input.value == "delayed"
            or scaling_type_weight.value == "delayed"
            or scaling_type_grad_output.value == "delayed"
        )
        self._sync_float8_amax_and_scale_history = None
        self.compile = compile

    def convert_to_float8_training(self, model: nn.Module):
        """
        This function converts the linear layers of `model` to `Float8Linear`.
        Note that today, only dynamic tensor scaling (the default) is supported.
        This will mutate the model inplace.
        """
        
        if not self.enabled:
            return

        from torchao.float8 import convert_to_float8_training

        convert_to_float8_training(
            model,
            config=self.config,
        )


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4, bias=True)
    
    def forward(self, x):
        return self.fc(x)

rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')

world_size = dist.get_world_size()
world_mesh = init_device_mesh('cuda', (world_size, ), mesh_dim_names=("world",))['world']
float8_handler = Float8Handler(
    compile=True,
    enable_fsdp_float8_all_gather=True,
    pad_inner_dim=True,
)

model = Model().cuda().to(torch.bfloat16)
float8_handler.convert_to_float8_training(model)
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16)
fully_shard(
    model,
    mesh=world_mesh,
    mp_policy=mp_policy,
    reshard_after_forward=True
)

print(model.fc.weight.to_local().shape)

x = torch.randn(16, 16, requires_grad=True, device='cuda', dtype=torch.bfloat16)
out = model(x)
out.mean().backward()
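
To reproduce, launch with a world size larger than out_features (4 in this model) so that some ranks end up with an empty weight shard, for example (the file name repro.py is just a placeholder):

torchrun --nproc-per-node 8 repro.py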

Environment:

torch    2.6.0+cu126
torchao  0.9.0+cu126

Any suggestions?
