[float8 moe training] fix bug affecting mixed precision training #2451

Merged: 1 commit, Jun 27, 2025

1 change: 1 addition & 0 deletions test/prototype/moe_training/test_fsdp.sh
@@ -0,0 +1 @@
torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp.py
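
Note (not part of the diff): torchrun --nproc_per_node=2 launches the FSDP test on two worker processes, --local-ranks-filter=0 restricts log output to local rank 0 so the pytest results are printed once, and -m pytest runs the test module inside each worker.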
17 changes: 14 additions & 3 deletions torchao/prototype/moe_training/conversion_utils.py
@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Callable, Optional

from torch import nn
@@ -8,6 +14,8 @@
register_quantize_module_handler,
)

logger: logging.Logger = logging.getLogger(__name__)


class MoETrainingConfig(AOBaseConfig):
"""
@@ -76,7 +84,7 @@ def _swap_params(
f"Does not support a root nn.Parameter with children: {module}"
)
if not isinstance(module.data, ScaledGroupedMMTensor):
- new_data = ScaledGroupedMMTensor(module.data)
+ new_data = ScaledGroupedMMTensor(module.data, module.data.dtype)
return nn.Parameter(new_data, requires_grad=module.requires_grad)
return module

@@ -102,10 +110,13 @@ def post_order_traversal(
for param_name, param in module.named_parameters(recurse=False):
if not isinstance(param.data, ScaledGroupedMMTensor):
new_param = nn.Parameter(
- ScaledGroupedMMTensor(param), requires_grad=param.requires_grad
+ ScaledGroupedMMTensor(param.data, param.data.dtype),
+ requires_grad=param.requires_grad,
)
setattr(module, param_name, new_param)
print(f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor")
logger.info(
f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor"
)

post_order_traversal(root_module)
return root_module
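
For orientation, here is a hedged usage sketch of how the conversion path above is typically driven; it is not part of the PR. It assumes MoETrainingConfig can be constructed with defaults and that the handler registered via register_quantize_module_handler is picked up by torchao's quantize_ entry point; the ToyExperts module, the is_expert filter, and the shapes are made up for illustration.

```python
# Hedged sketch: converting a model's expert parameters through quantize_.
import torch
from torch import nn

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization import quantize_


class ToyExperts(nn.Module):
    """Hypothetical stand-in for an MoE expert layer with a 3D grouped weight."""

    def __init__(self, num_experts: int = 4, dim: int = 16):
        super().__init__()
        self.w = nn.Parameter(torch.randn(num_experts, dim, dim))


model = nn.Sequential(ToyExperts())


def is_expert(module: nn.Module, fqn: str) -> bool:
    # Only convert modules we treat as experts; everything else is untouched.
    return isinstance(module, ToyExperts)


quantize_(model, MoETrainingConfig(), filter_fn=is_expert)

# After conversion the expert weight's .data is a ScaledGroupedMMTensor that
# still reports the original high-precision dtype (float32 here), which is
# what the dtype argument added in this PR preserves.
print(type(model[0].w.data), model[0].w.dtype)
```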
54 changes: 46 additions & 8 deletions torchao/prototype/moe_training/tensor.py
@@ -1,3 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Optional, Tuple

import torch
@@ -6,6 +13,9 @@

from torchao.prototype.moe_training import _scaled_grouped_mm

logger: logging.Logger = logging.getLogger(__name__)


_ops_to_preserve_subclass = {
torch.ops.aten.empty_like.default,
torch.ops.aten.new_zeros.default,
@@ -34,14 +44,15 @@ class ScaledGroupedMMTensor(torch.Tensor):
def __new__(
cls,
tensor: torch.Tensor,
dtype: torch.dtype,
):
return torch.Tensor._make_wrapper_subclass(
cls,
tensor.size(),
strides=tensor.stride(),
storage_offset=tensor.storage_offset(),
memory_format=suggest_memory_format(tensor),
- dtype=tensor.dtype,
+ dtype=dtype,
layout=tensor.layout,
device=tensor.device,
pin_memory=tensor.is_pinned(),
@@ -51,11 +62,14 @@ def __new__(
def __init__(
self,
tensor: torch.Tensor,
dtype: torch.dtype,
):
self._data = tensor
self._dtype = dtype

@classmethod
def __torch_function__(cls, func, types, args, kwargs={}):
logger.info(f"{func.__name__}, args: {args}, kwargs: {kwargs}")
# override the grouped mm op to use the differentiable _scaled_grouped_mm
if func.__name__ == cls.grouped_mm_func_name:
# Use torchao scaled grouped mm with dynamic quant for
@@ -84,10 +98,19 @@ def __torch_function__(cls, func, types, args, kwargs={}):
def __torch_dispatch__(cls, func, types, args, kwargs={}):
# detach is special case
if func == torch.ops.aten.detach.default:
- return ScaledGroupedMMTensor(args[0]._data)
+ return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)

# unwrap args and kwargs
- unwrap = lambda tensor: tensor._data
+ dtype: Optional[torch.dtype] = None
+
+ def unwrap(t):
+     nonlocal dtype
+     if dtype is None:
+         dtype = t._dtype
+     else:
+         assert t._dtype == dtype
+     return t._data

args, kwargs = pytree.tree_map_only(
ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
)
@@ -102,12 +125,27 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
# wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
return pytree.tree_map_only(
torch.Tensor,
- lambda x: ScaledGroupedMMTensor(x),
+ lambda x: ScaledGroupedMMTensor(x, dtype),
out,
)

def __repr__(self):
return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"

def __tensor_flatten__(self):
return ["_data"], {"_dtype": self._dtype}

@staticmethod
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
return ScaledGroupedMMTensor(
inner_tensors["_data"],
flatten_spec["_dtype"],
)

def fsdp_pre_all_gather(self, mesh):
- return (self._data,), ()
+ all_gather_inputs = (self._data,)
+ all_gather_metadata = ()
+ return all_gather_inputs, all_gather_metadata

def fsdp_post_all_gather(
self,
@@ -118,6 +156,6 @@ def fsdp_post_all_gather(
out: Optional[torch.Tensor] = None,
):
(data,) = all_gather_outputs
- return ScaledGroupedMMTensor(
-     data,
- ), (data,)
+ output = ScaledGroupedMMTensor(data, param_dtype)
+ inner_tensors = (data,)
+ return output, inner_tensors
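
To make the dtype plumbing above concrete, here is a minimal standalone sketch, not taken from the PR: torch.Tensor._make_wrapper_subclass lets a wrapper subclass advertise a dtype that is decoupled from the dtype of the tensor it stores. That is the mechanism ScaledGroupedMMTensor now uses to report the original parameter dtype at conversion time and the FSDP-provided param_dtype after an all-gather, regardless of the dtype the wrapped data is in. The DtypeAdvertisingTensor class below, including its toy __torch_dispatch__, is hypothetical.

```python
# Hedged sketch: advertised dtype vs. storage dtype in a wrapper subclass.
import torch
from torch.utils._pytree import tree_map_only


class DtypeAdvertisingTensor(torch.Tensor):
    """Toy wrapper whose advertised dtype is independent of the storage dtype."""

    def __new__(cls, data: torch.Tensor, dtype: torch.dtype):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.size(),
            strides=data.stride(),
            dtype=dtype,  # advertised dtype (e.g. the original fp32 param dtype)
            device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data: torch.Tensor, dtype: torch.dtype):
        self._data = data  # actual storage, possibly lower precision

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Run every op on the inner tensor, i.e. in the storage dtype.
        args, kwargs = tree_map_only(cls, lambda t: t._data, (args, kwargs or {}))
        return func(*args, **kwargs)


# bf16 storage with an advertised fp32 dtype: the two no longer have to agree.
t = DtypeAdvertisingTensor(torch.randn(4, 4, dtype=torch.bfloat16), torch.float32)
print(t.dtype)        # torch.float32 (advertised)
print(t._data.dtype)  # torch.bfloat16 (wrapped storage)
```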