@@ -1,9 +1,28 @@
+from typing import Any, Optional, Tuple
+
 import torch
-from torch.utils._pytree import tree_map
+import torch.utils._pytree as pytree
+from torch._prims_common import suggest_memory_format

 from torchao.prototype.moe_training import _scaled_grouped_mm

-
+# FSDP pads its local tensor on dim-0. The subclass should be preserved such
+# that the padded local tensor (and any transformations like copying to GPU)
+# is of the subclass as well.
+_ops_to_preserve_subclass = {
+    torch.ops.aten.empty_like.default,
+    torch.ops.aten.new_zeros.default,
+    torch.ops.aten.slice.Tensor,
+    torch.ops.aten.copy_.default,
+    torch.ops.aten.view.default,
+    torch.ops.aten.as_strided.default,
+    torch.ops.aten._to_copy.default,
+    torch.ops.aten._pin_memory.default,
+    torch.ops.aten.split.Tensor,
+    torch.ops.aten.clone.default,
+}
+
+
 class ScaledGroupedMMTensor(torch.Tensor):
     """
     ScaledGroupedMMTensor is a simple tensor subclass that wraps a regular tensor
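Note on the allow-list added above: it is what lets FSDP's dim-0 padding (and transformations like host-to-device copies) stay inside the subclass, while ordinary compute ops fall back to plain tensors. A minimal sketch of the intended behavior, assuming this diff is applied; shapes are illustrative:

```python
import torch

x = ScaledGroupedMMTensor(torch.randn(4, 8))

# new_zeros is in _ops_to_preserve_subclass, so the dim-0 padded local
# tensor that FSDP builds keeps the subclass...
padded = x.new_zeros(6, 8)
assert isinstance(padded, ScaledGroupedMMTensor)

# ...while an op outside the allow-list (aten.add here) returns a plain tensor.
y = x + 1
assert not isinstance(y, ScaledGroupedMMTensor)
assert isinstance(y, torch.Tensor)
```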
@@ -13,22 +32,34 @@ class ScaledGroupedMMTensor(torch.Tensor):

     grouped_mm_func_name = "_grouped_mm"
     offs_arg_name = "offs"
-    use_triton_for_per_group_scales = True

-    def __init__(
-        self, data: torch.Tensor, use_triton_for_per_group_scales: bool = True
+    @staticmethod
+    def __new__(
+        cls,
+        tensor: torch.Tensor,
     ):
-        self._data = data
-        self._use_triton_for_per_group_scales = use_triton_for_per_group_scales
+        return torch.Tensor._make_wrapper_subclass(
+            cls,
+            tensor.size(),
+            strides=tensor.stride(),
+            storage_offset=tensor.storage_offset(),
+            memory_format=suggest_memory_format(tensor),
+            dtype=tensor.dtype,
+            layout=tensor.layout,
+            device=tensor.device,
+            pin_memory=tensor.is_pinned(),
+            requires_grad=tensor.requires_grad,
+        )

-    def __repr__(self):
-        return f"ScaledGroupedMMTensor(use_triton_for_per_group_scales={self._use_triton_for_per_group_scales}, {self._data})"
-
-    def __repr__(self):
-        return f"ScaledGroupedMMTensor(data={self._data})"
+    def __init__(
+        self,
+        tensor: torch.Tensor,
+    ):
+        self._data = tensor

     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
+        # override the grouped mm op to use the differentiable _scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for
             # "2d x 3d with offsets" case (used for routed experts).
@@ -42,32 +73,54 @@ def __torch_function__(cls, func, types, args, kwargs={}):
             B_is_3d = B.dim() == 3
             has_offs = kwargs.get(cls.offs_arg_name) is not None
             if A_is_2d and B_is_3d and has_offs:
-                # prefer to use B to check use_triton, as that will be the weight/nn.Parameter
-                # that is converted to ScaledGroupedMMTensor
-                use_triton = (
-                    B._use_triton_for_per_group_scales
-                    if isinstance(B, cls)
-                    else A._use_triton_for_per_group_scales
-                )
                 return _scaled_grouped_mm(
                     *args,
-                    use_triton_for_per_group_scales=use_triton,
                     **kwargs,
                 )

-        # Disable torch_function by hand because we don't want
+        # Disable torch_function by hand because we don't want
         # the wrapping behavior of the super() impl, go directly to dispatch
-        with torch._C.DisableTorchFunction():
+        with torch._C.DisableTorchFunctionSubclass():
             return func(*args, **kwargs)

-
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
-        unwrap = lambda x: x._data if isinstance(x, cls) else x
-        wrap = lambda x: cls(x) if isinstance(x, torch.Tensor) else x
-        unwrapped_args, unwrapped_kwargs = tree_map(unwrap, (args, kwargs))
-        output = super().__torch_dispatch__(func, types, unwrapped_args, unwrapped_kwargs)
-        wrapped_output = tree_map(wrap, output)
-        print(func.__name__)
-        print(wrapped_output)
-        return wrapped_output
+        # detach is a special case
+        if func == torch.ops.aten.detach.default:
+            return ScaledGroupedMMTensor(args[0]._data)
+
+        # unwrap args and kwargs
+        unwrap = lambda tensor: tensor._data
+        args, kwargs = pytree.tree_map_only(
+            ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
+        )
+
+        # perform op
+        out = func(*args, **kwargs)
+
+        # return regular tensors for ops that don't preserve subclass
+        if func not in _ops_to_preserve_subclass:
+            return out
+
+        # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
+        return pytree.tree_map_only(
+            torch.Tensor,
+            lambda x: ScaledGroupedMMTensor(x),
+            out,
+        )
+
+    def fsdp_pre_all_gather(self, mesh):
+        return (self._data,), ()
+
+    def fsdp_post_all_gather(
+        self,
+        all_gather_outputs: Tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,
+        *,
+        out: Optional[torch.Tensor] = None,
+    ):
+        (data,) = all_gather_outputs
+        return ScaledGroupedMMTensor(
+            data,
+        ), (data,)
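End to end, the point of the subclass is the `__torch_function__` interposition: once an expert weight is wrapped, a plain `torch._grouped_mm` call on the "2d x 3d with offsets" path is routed to torchao's differentiable `_scaled_grouped_mm`. A hedged usage sketch; `torch._grouped_mm` is a private PyTorch op and `_scaled_grouped_mm` has its own dtype and hardware requirements, so the shapes and dtypes here are illustrative, not verified:

```python
import torch

device = "cuda"
num_experts, k, n, total_tokens = 4, 64, 32, 128

# A: stacked tokens for all experts (2d); B: expert weights (3d), wrapped.
A = torch.randn(total_tokens, k, dtype=torch.bfloat16, device=device, requires_grad=True)
B = ScaledGroupedMMTensor(torch.randn(num_experts, k, n, dtype=torch.bfloat16, device=device))

# offs marks the end offset of each expert's token group along dim 0 of A.
offs = torch.tensor([32, 64, 96, 128], dtype=torch.int32, device=device)

# __torch_function__ sees func.__name__ == "_grouped_mm" with a 2d A, a 3d B,
# and an "offs" kwarg, so the call routes to _scaled_grouped_mm.
out = torch._grouped_mm(A, B, offs=offs)
```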
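For context on the two hooks at the bottom: they implement the FSDP2 tensor-subclass extension point. `fsdp_pre_all_gather` hands FSDP the plain inner tensor to all-gather (with empty metadata), and `fsdp_post_all_gather` re-wraps the gathered output so the grouped-mm override still fires on the unsharded parameter (`param_dtype` and `out` are accepted to satisfy the interface but unused here). A local simulation of the round trip, as a sketch without any distributed setup:

```python
import torch

w = ScaledGroupedMMTensor(torch.randn(8, 16))

# What FSDP would all-gather for this parameter: the plain inner tensor.
shards, metadata = w.fsdp_pre_all_gather(mesh=None)  # mesh is unused here
assert not isinstance(shards[0], ScaledGroupedMMTensor)

# After the all-gather, the output is re-wrapped into the subclass.
gathered, inner_tensors = w.fsdp_post_all_gather(shards, metadata, torch.bfloat16)
assert isinstance(gathered, ScaledGroupedMMTensor)
```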