
Commit afde31e

float8 moe training conversion API prototype
stack-info: PR: #2275, branch: danielvegamyhre/stack/1

migrate to quantize and add test
work on moe training test
1 parent d963a88 commit afde31e
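
In short, this prototype routes MoE training conversion through the existing quantize_ API: a MoETrainingConfig handler swaps expert weight nn.Parameters to ScaledGroupedMMTensor, which dispatches torch._grouped_mm to the differentiable float8 _scaled_grouped_mm. A minimal usage sketch, distilled from the new test in this commit (the torchtitan llama4 MoE and the "experts" FQN are just what this commit tests against, not requirements of the API):

import torch
from torch import nn

from torchao.prototype.scaled_grouped_mm.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_
from torchtitan.experiments.llama4.model.args import TransformerModelArgs
from torchtitan.experiments.llama4.model.moe import MoE

# bf16 MoE model whose routed experts call torch._grouped_mm
model_args = TransformerModelArgs(moe_enabled=True, num_experts=2)
model = MoE(model_args).to(torch.bfloat16).cuda()

def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    # only convert submodules whose FQN contains "experts"
    return "experts" in cur_fqn

# swap expert weight nn.Parameters to ScaledGroupedMMTensor for float8 grouped GEMM training
quantize_(model, config=MoETrainingConfig(), filter_fn=moe_module_filter_fn)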

5 files changed: +219 -4 lines changed
Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
import pytest
import torch
from torch import nn

from torchao.float8.float8_utils import compute_error
from torchao.prototype.scaled_grouped_mm.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_

try:
    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
    from torchtitan.experiments.llama4.model.moe import MoE
except ImportError:
    import warnings

    warnings.warn("torchtitan not installed, skipping MoE tests.")
    pytest.skip(allow_module_level=True)


@pytest.mark.parametrize(
    "target_fqns",
    [["experts"]],
)
def test_moe_float8_training(target_fqns: list[str]):
    model_args = TransformerModelArgs(moe_enabled=True, num_experts=2)
    init_std = 0.02
    device = torch.device("cuda")

    # reference bf16 MoE
    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
    torch.manual_seed(42)
    with torch.no_grad():
        ref_model.init_weights(init_std, device)

    # target MoE for testing conversion
    model = MoE(model_args).to(torch.bfloat16).cuda()
    torch.manual_seed(42)
    with torch.no_grad():
        model.init_weights(init_std, device)

    # assert starting params are identical for both models
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        assert torch.equal(param1, param2)

    # convert MoE to float8 training
    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        for target_fqn in target_fqns:
            if target_fqn in cur_fqn:
                return True
        return False

    config = MoETrainingConfig()
    quantize_(model, config=config, filter_fn=moe_module_filter_fn)

    # inputs
    batch, seq, dim = 1, 8192, 4096
    x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )
    ref_x = x.detach().clone().requires_grad_(True)

    # forward pass
    out = model(x)
    ref_out = ref_model(ref_x)

    # validate SQNR between outputs is acceptable.
    # a single fp8 gemm uses SQNR >= 25.0 for testing, so for a full MoE layer
    # we'll use a slightly lower threshold.
    out_sqnr = compute_error(out, ref_out)
    assert out_sqnr.item() >= 23.0, f"SQNR must be >= 23.0, got {out_sqnr.item()}."

    # backward pass
    ref_out.sum().backward()
    out.sum().backward()

    # validate input gradients
    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
    assert input_grad_sqnr.item() >= 23.0, (
        f"SQNR must be >= 23.0, got {input_grad_sqnr.item()}."
    )

    # validate param gradients
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        param_grad_sqnr = compute_error(param1.grad, param2.grad)
        assert param_grad_sqnr.item() >= 23.0, (
            f"SQNR must be >= 23.0, got {param_grad_sqnr.item()}."
        )
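
For reference, compute_error above returns the signal-to-quantization-noise ratio in decibels; a rough equivalent of what the 23.0 dB thresholds measure (assuming the usual torchao definition, shown here only as an illustrative helper):

import torch

def sqnr_db(ref: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # SQNR in dB: 20 * log10(||ref|| / ||ref - out||);
    # higher means `out` is closer to the bf16 reference.
    return 20 * torch.log10(torch.norm(ref) / torch.norm(ref - out))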

test/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py

Lines changed: 3 additions & 3 deletions
@@ -23,7 +23,7 @@
     Float8LinearConfig,
     Float8LinearRecipeName,
 )
-from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
+from torchao.float8.float8_linear import _matmul_with_hp_or_float8_args
 from torchao.float8.float8_tensor import LinearMMConfig
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
 from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import (
@@ -183,7 +183,7 @@ def compute_reference_forward(

     # Validate each actual result group from the _scaled_grouped_mm is equal to:
     # 1. A manual _scaled_mm for the group.
-    # 2. A matmul_with_hp_or_float8_args for the group (which is differentiable, and thus used to validate gradients).
+    # 2. A _matmul_with_hp_or_float8_args for the group (which is differentiable, and thus used to validate gradients).
     outputs = []
     list1 = list(zip(A_list_fp8, B_t_fp8, A_scale_list, B_t_scales, result_list))
     list2 = list(zip(A_list, B_t, result_list))
@@ -199,7 +199,7 @@ def compute_reference_forward(
             use_fast_accum=float8_config.gemm_config_output.use_fast_accum,
         )
         a2, b2, result2 = list2[i]
-        ref_group_result2 = matmul_with_hp_or_float8_args.apply(
+        ref_group_result2 = _matmul_with_hp_or_float8_args.apply(
            a2,
            b2,
            LinearMMConfig(),
torchao/prototype/scaled_grouped_mm/conversion_utils.py

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
from typing import Callable, Optional

from torch import nn

from torchao.core.config import AOBaseConfig
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)
from torchao.prototype.scaled_grouped_mm.tensor import ScaledGroupedMMTensor


class MoETrainingConfig(AOBaseConfig):
    pass


@register_quantize_module_handler(MoETrainingConfig)
def _moe_training_transform(
    module: nn.Module,
    config: MoETrainingConfig,
) -> nn.Module:
    """
    Swaps `torch.nn.Parameter` data tensors with ScaledGroupedMMTensor.

    Args:
        module: Module to modify.
        config: MoETrainingConfig which defines how to perform the MoE training transform.

    Returns:
        nn.Module: The modified module with swapped parameters.
    """
    out = swap_params(module)
    return out


def swap_params(
    module: nn.Module,
    *,
    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
) -> nn.Module:
    """
    Recurses through the nn.Module, swapping the data tensor of each
    nn.Parameter with a ScaledGroupedMMTensor. Only applies to modules that
    pass the module_filter_fn, if specified.

    Args:
        module: Module to modify.
        module_filter_fn: If specified, only parameters of modules that pass
            the filter function will be swapped. The inputs to the filter
            function are the module instance and its FQN.

    Returns:
        nn.Module: The modified module with swapped parameters.
    """
    if isinstance(module, nn.Parameter) and (
        module_filter_fn is None or module_filter_fn(module, "")
    ):
        if len(list(module.children())) > 0:
            raise AssertionError(
                f"Does not support a root nn.Parameter with children: {module}"
            )
        if not isinstance(module.data, ScaledGroupedMMTensor):
            new_data = ScaledGroupedMMTensor(module.data)
            return nn.Parameter(new_data, requires_grad=module.requires_grad)
        return module

    root_module = module

    def post_order_traversal(
        module: nn.Module,
        cur_fqn: Optional[str] = None,
        parent_module: Optional[nn.Module] = None,
    ):
        if cur_fqn is None:
            cur_fqn = ""

        for child_module_name, child_module in module.named_children():
            if cur_fqn == "":
                new_fqn = child_module_name
            else:
                new_fqn = f"{cur_fqn}.{child_module_name}"

            post_order_traversal(child_module, new_fqn, module)

        if module_filter_fn is None or module_filter_fn(module, cur_fqn):
            for param_name, param in module.named_parameters(recurse=False):
                if not isinstance(param.data, ScaledGroupedMMTensor):
                    new_param = nn.Parameter(
                        ScaledGroupedMMTensor(param), requires_grad=param.requires_grad
                    )
                    setattr(module, param_name, new_param)
                    print(f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor")

    post_order_traversal(root_module)
    return root_module
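
As a quick illustration of the traversal above, here is a toy sketch (not part of this commit; the ModuleDict merely stands in for a real MoE block) showing how swap_params only touches parameters of modules whose FQN matches the filter:

import torch
from torch import nn

from torchao.prototype.scaled_grouped_mm.conversion_utils import swap_params

# toy stand-in for an MoE block: only "experts" parameters should be swapped
model = nn.ModuleDict(
    {
        "router": nn.Linear(16, 4),
        "experts": nn.Linear(16, 16),
    }
)

# swaps experts.weight / experts.bias to ScaledGroupedMMTensor and prints a
# "Swapped ..." line for each; router parameters are left untouched
swap_params(model, module_filter_fn=lambda mod, fqn: "experts" in fqn)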

torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py

Lines changed: 4 additions & 1 deletion
@@ -83,7 +83,10 @@ def forward(
         assert not _is_column_major(A), "A must be row-major"

         # Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major.
-        assert _is_column_major(B_t), "B must be column-major"
+        if not _is_column_major(B_t):
+            # FSDP will complain if B_t (weights) is not contiguous; we can't require B_t to be column-major.
+            # TODO: figure out a better solution than transposing on each forward pass.
+            B_t = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)

         # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
         # A shape: (M, K)
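
The transpose-contiguous-transpose trick above keeps the tensor's shape but gives it column-major strides, which is what the scaled grouped GEMM expects for its right operand. A small standalone sketch of the idea (not from this commit; shapes are arbitrary):

import torch

B_t = torch.randn(2, 64, 128)  # row-major: innermost dim has stride 1

# change only the memory layout to column-major per expert, not the shape
B_t_col_major = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)

print(B_t.shape, B_t.stride())                      # (2, 64, 128), (8192, 128, 1)
print(B_t_col_major.shape, B_t_col_major.stride())  # (2, 64, 128), (8192, 1, 64)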
torchao/prototype/scaled_grouped_mm/tensor.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
import torch

from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm


class ScaledGroupedMMTensor(torch.Tensor):
    """
    ScaledGroupedMMTensor is a simple tensor subclass that wraps a regular tensor
    and overrides the torch._grouped_mm op by dispatching to the
    differentiable _scaled_grouped_mm autograd function.
    """

    grouped_mm_func_name = "_grouped_mm"
    offs_arg_name = "offs"

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs={}):
        if func.__name__ == cls.grouped_mm_func_name:
            # Use torchao scaled grouped mm with dynamic quant for the
            # "2d x 3d with offsets" case (used for routed experts).
            # Otherwise, fall back to regular grouped mm.
            #
            # TODO: support the "3d x 3d without offsets" case, which is
            # used for shared experts. This is basically the grouped_mm
            # kernel handling a bmm.
            A, B = args[0], args[1]
            A_is_2d = A.dim() == 2
            B_is_3d = B.dim() == 3
            has_offs = kwargs.get(cls.offs_arg_name) is not None
            if A_is_2d and B_is_3d and has_offs:
                return _scaled_grouped_mm(*args, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)
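
To make the dispatch concrete: once an expert weight is wrapped, an ordinary torch._grouped_mm call on it is intercepted by __torch_function__ and, for the "2d input x 3d weight with offsets" case, routed to the float8 _scaled_grouped_mm. A rough sketch only, assuming a CUDA GPU with float8 support and a PyTorch build that exposes torch._grouped_mm; the shapes and offsets below are illustrative, not requirements:

import torch

from torchao.prototype.scaled_grouped_mm.tensor import ScaledGroupedMMTensor

num_experts, dim, hidden = 2, 1024, 4096
x = torch.randn(128, dim, dtype=torch.bfloat16, device="cuda")                  # routed tokens (2d)
w = torch.randn(num_experts, dim, hidden, dtype=torch.bfloat16, device="cuda")  # expert weights (3d)
offs = torch.tensor([64, 128], dtype=torch.int32, device="cuda")                # per-expert token boundaries

# wrapping the weight is enough: __torch_function__ sees torch._grouped_mm
# with a 2d x 3d call plus offs, and dispatches to _scaled_grouped_mm
out = torch._grouped_mm(x, ScaledGroupedMMTensor(w), offs=offs)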

0 commit comments
