Commit c7bca21

fsdp working
1 parent 784d087 · commit c7bca21

2 files changed: +86 −32 lines

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 0 additions & 1 deletion
@@ -41,7 +41,6 @@ def _scaled_grouped_mm(
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
         use_triton_for_per_group_scales (bool): Whether to use custom triton kernels to compute per-group scales. Default is True.
     """
-    print("SCALED_GROUPED_MM")
     return _Float8GroupedMM.apply(
         A,
         B_t,
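
For orientation, here is a minimal call sketch of this prototype entry point. It assumes the positional argument order suggested by the _Float8GroupedMM.apply(A, B_t, ...) call and the parameter docstring above; the shapes, the transposed-weight layout, and the offsets format are illustrative assumptions, not guaranteed by this diff.

# Hedged sketch only: the exact signature and layout expectations of
# _scaled_grouped_mm are inferred from this diff, not from documentation.
import torch
from torchao.prototype.moe_training import _scaled_grouped_mm

num_experts, dim, hidden, tokens = 4, 256, 512, 1024
A = torch.randn(tokens, dim, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B = torch.randn(num_experts, hidden, dim, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B_t = B.transpose(-2, -1)  # per-expert transposed weights, matching the B_t argument name
offs = torch.arange(  # end offset of each token group, one group per expert
    tokens // num_experts, tokens + 1, tokens // num_experts, dtype=torch.int32, device="cuda"
)
out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)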
Lines changed: 86 additions & 31 deletions
@@ -1,9 +1,28 @@
+from typing import Any, Optional, Tuple
+
 import torch
-from torch.utils._pytree import tree_map
+import torch.utils._pytree as pytree
+from torch._prims_common import suggest_memory_format
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 
-
+# FSDP pads its local tensor on dim-0. The subclass should be preserved such
+# that the padded local tensor (and any transformations like copying to GPU)
+# is of the subclass as well.
+_ops_to_preserve_subclass = {
+    torch.ops.aten.empty_like.default,
+    torch.ops.aten.new_zeros.default,
+    torch.ops.aten.slice.Tensor,
+    torch.ops.aten.copy_.default,
+    torch.ops.aten.view.default,
+    torch.ops.aten.as_strided.default,
+    torch.ops.aten._to_copy.default,
+    torch.ops.aten._pin_memory.default,
+    torch.ops.aten.split.Tensor,
+    torch.ops.aten.clone.default,
+}
+
+
 class ScaledGroupedMMTensor(torch.Tensor):
     """
     ScaledGroupedMMTensor is a simple tensor subclass that wraps a regular tensor
@@ -13,22 +32,34 @@ class ScaledGroupedMMTensor(torch.Tensor):
 
     grouped_mm_func_name = "_grouped_mm"
     offs_arg_name = "offs"
-    use_triton_for_per_group_scales = True
 
-    def __init__(
-        self, data: torch.Tensor, use_triton_for_per_group_scales: bool = True
+    @staticmethod
+    def __new__(
+        cls,
+        tensor: torch.Tensor,
     ):
-        self._data = data
-        self._use_triton_for_per_group_scales = use_triton_for_per_group_scales
+        return torch.Tensor._make_wrapper_subclass(
+            cls,
+            tensor.size(),
+            strides=tensor.stride(),
+            storage_offset=tensor.storage_offset(),
+            memory_format=suggest_memory_format(tensor),
+            dtype=tensor.dtype,
+            layout=tensor.layout,
+            device=tensor.device,
+            pin_memory=tensor.is_pinned(),
+            requires_grad=tensor.requires_grad,
+        )
 
-    def __repr__(self):
-        return f"ScaledGroupedMMTensor(use_triton_for_per_group_scales={self._use_triton_for_per_group_scales}, {self._data})"
-
-    def __repr__(self):
-        return f"ScaledGroupedMMTensor(data={self._data})"
+    def __init__(
+        self,
+        tensor: torch.Tensor,
+    ):
+        self._data = tensor
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
+        # override the grouped mm op to use the differentiable _scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for
             # "2d x 3d with offsets" case (used for routed experts).
@@ -42,32 +73,56 @@ def __torch_function__(cls, func, types, args, kwargs={}):
             B_is_3d = B.dim() == 3
             has_offs = kwargs.get(cls.offs_arg_name) is not None
             if A_is_2d and B_is_3d and has_offs:
-                # prefer to use B to check use_triton, as that will be the weight/nn.Parameter
-                # that is converted to ScaledGroupedMMTensor
-                use_triton = (
-                    B._use_triton_for_per_group_scales
-                    if isinstance(B, cls)
-                    else A._use_triton_for_per_group_scales
-                )
                 return _scaled_grouped_mm(
                     *args,
-                    use_triton_for_per_group_scales=use_triton,
                     **kwargs,
                 )
 
-        # Disable torch_function by hand because we don't want
+        # Disable torch_function by hand because we don't want
         # the wrapping behavior of the super() impl, go directly to dispatch
-        with torch._C.DisableTorchFunction():
+        # wrap = lambda x: ScaledGroupedMMTensor(x)
+        # wrapped_args, wrapped_kwargs = pytree.tree_map_only(torch.Tensor, wrap, (args, kwargs))
+        with torch._C.DisableTorchFunctionSubclass():
             return func(*args, **kwargs)
 
-
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
-        unwrap = lambda x: x._data if isinstance(x, cls) else x
-        wrap = lambda x: cls(x) if isinstance(x, torch.Tensor) else x
-        unwrapped_args, unwrapped_kwargs = tree_map(unwrap, (args, kwargs))
-        output = super().__torch_dispatch__(func, types, unwrapped_args, unwrapped_kwargs)
-        wrapped_output = tree_map(wrap, output)
-        print(func.__name__)
-        print(wrapped_output)
-        return wrapped_output
+        # detach is special case
+        if func == torch.ops.aten.detach.default:
+            return ScaledGroupedMMTensor(args[0]._data)
+
+        # unwrap args and kwargs
+        unwrap = lambda tensor: tensor._data
+        args, kwargs = pytree.tree_map_only(
+            ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
+        )
+
+        # perform op
+        out = func(*args, **kwargs)
+
+        # return regular tensors for ops that don't preserve subclass
+        if func not in _ops_to_preserve_subclass:
+            return out
+
+        # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
+        return pytree.tree_map_only(
+            torch.Tensor,
+            lambda x: ScaledGroupedMMTensor(x),
+            out,
+        )
+
+    def fsdp_pre_all_gather(self, mesh):
+        return (self._data,), ()
+
+    def fsdp_post_all_gather(
+        self,
+        all_gather_outputs: Tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,
+        *,
+        out: Optional[torch.Tensor] = None,
+    ):
+        (data,) = all_gather_outputs
+        return ScaledGroupedMMTensor(
+            data,
+        ), (data,)
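
To show how the pieces above are meant to fit together, here is a hedged usage sketch. The module, parameter layout, and the fully_shard import path are assumptions for illustration; only ScaledGroupedMMTensor, its torch_function/torch_dispatch overrides, and the fsdp_pre_all_gather / fsdp_post_all_gather hooks come from the diff itself.

# Hedged sketch: wrap an expert-weight parameter so that (a) calls to the
# private torch._grouped_mm op are routed to torchao's _scaled_grouped_mm via
# __torch_function__, and (b) FSDP2 all-gathers the underlying data through
# fsdp_pre_all_gather / fsdp_post_all_gather. Module shape and import paths
# are illustrative, not part of this commit.
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # FSDP2 entry point

class RoutedExperts(nn.Module):
    def __init__(self, num_experts: int, dim: int, hidden: int):
        super().__init__()
        w = torch.randn(num_experts, dim, hidden, dtype=torch.bfloat16)
        # Wrapping the weight keeps every op on it inside the subclass.
        self.w = nn.Parameter(ScaledGroupedMMTensor(w))

    def forward(self, x: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
        # 2D activations x 3D expert weights with offsets: intercepted by
        # ScaledGroupedMMTensor.__torch_function__ and sent to _scaled_grouped_mm.
        return torch._grouped_mm(x, self.w, offs=offs)

model = RoutedExperts(num_experts=4, dim=256, hidden=512).cuda()
fully_shard(model)  # FSDP2 pads and shards the local tensor on dim-0; the aten ops
                    # in _ops_to_preserve_subclass keep the wrapper intact end to end.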
