|
7 | 7 |
|
8 | 8 | # pyre-strict |
9 | 9 |
|
10 | | -#!/usr/bin/env python3 |
11 | 10 | import copy |
12 | 11 | import itertools |
13 | 12 | import logging |
|
34 | 33 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
35 | 34 | from torch.fx.node import Node |
36 | 35 | from torch.profiler import record_function |
37 | | -from torchrec.distributed.dist_data import KJTAllToAll, KJTAllToAllTensorsAwaitable |
| 36 | +from torchrec.distributed.dist_data import KJTAllToAll |
38 | 37 | from torchrec.distributed.embedding_sharding import ( |
39 | | - KJTListAwaitable, |
| 38 | + FusedKJTListSplitsAwaitable, |
40 | 39 | KJTListSplitsAwaitable, |
| 40 | + KJTSplitsAllToAllMeta, |
41 | 41 | ) |
42 | 42 | from torchrec.distributed.model_parallel import DistributedModelParallel, ShardedModule |
43 | 43 |
|
|
59 | 59 | StageOutputWithEvent = Tuple[Optional[StageOut], Optional[torch.cuda.Event]] |
60 | 60 |
|
61 | 61 |
|
62 | | -class Tracer(torch.fx.Tracer): |
63 | | - """ |
64 | | - Disables proxying buffers during tracing. Ideally, proxying would stay enabled,
65 | | - but some models currently mutate buffer values, which causes errors during
66 | | - tracing when buffers are proxied. If those models can be rewritten to avoid
67 | | - that, we can likely remove this override.
68 | | - """ |
69 | | - |
70 | | - proxy_buffer_attributes = False |
71 | | - |
72 | | - def __init__(self, leaf_modules: Optional[List[str]] = None) -> None: |
73 | | - super().__init__() |
74 | | - self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else [] |
75 | | - |
76 | | - def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: |
77 | | - if ( |
78 | | - isinstance(m, ShardedModule) |
79 | | - or module_qualified_name in self._leaf_modules |
80 | | - or isinstance(m, FSDP) |
81 | | - ): |
82 | | - return True |
83 | | - return super().is_leaf_module(m, module_qualified_name) |
84 | | - |
85 | | - |
86 | | -# TODO: remove after packaging issue is resolved. |
87 | | -class SplitsAllToAllAwaitable(Awaitable[List[List[int]]]): |
88 | | - def __init__( |
89 | | - self, |
90 | | - input_tensors: List[torch.Tensor], |
91 | | - pg: dist.ProcessGroup, |
92 | | - ) -> None: |
93 | | - super().__init__() |
94 | | - self.num_workers: int = pg.size() |
95 | | - |
96 | | - with record_function("## all2all_data:kjt splits ##"): |
97 | | - self._output_tensor: torch.Tensor = torch.empty( |
98 | | - [self.num_workers * len(input_tensors)], |
99 | | - device=input_tensors[0].device, |
100 | | - dtype=input_tensors[0].dtype, |
101 | | - ) |
102 | | - input_tensor = torch.stack(input_tensors, dim=1).flatten() |
103 | | - self._splits_awaitable: dist.Work = dist.all_to_all_single( |
104 | | - output=self._output_tensor, |
105 | | - input=input_tensor, |
106 | | - group=pg, |
107 | | - async_op=True, |
108 | | - ) |
109 | | - |
110 | | - def _wait_impl(self) -> List[List[int]]: |
111 | | - self._splits_awaitable.wait() |
112 | | - return self._output_tensor.view(self.num_workers, -1).T.tolist() |
113 | | - |
114 | | - |
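The removed `SplitsAllToAllAwaitable` depends on a specific packing order: the per-request splits tensors (one entry per destination rank) are stacked along dim 1 and flattened so the block headed to each rank is contiguous, and the received buffer is viewed as `(num_workers, -1)` and transposed to recover one list per request. A minimal single-process sketch of that round trip, with made-up values and no process group involved:

```python
import torch

num_workers = 2  # pretend world size

# On one rank: one splits tensor per KJT request, one entry per destination rank.
input_tensors = [
    torch.tensor([10, 11]),  # request 0
    torch.tensor([20, 21]),  # request 1
    torch.tensor([30, 31]),  # request 2
]

# Sender layout: after stack(dim=1) + flatten, the block for each rank is contiguous.
input_tensor = torch.stack(input_tensors, dim=1).flatten()
print(input_tensor.tolist())  # [10, 20, 30, 11, 21, 31]

# Receiver side: dist.all_to_all_single would fill the output with one such block
# per source rank; here we fake both ranks having sent the same values.
output_tensor = torch.tensor([10, 20, 30, 10, 20, 30])
print(output_tensor.view(num_workers, -1).T.tolist())
# [[10, 10], [20, 20], [30, 30]] -> request i gets its splits from every source rank
```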
115 | | -# TODO: remove after packaging issue is resolved. |
116 | | -C = TypeVar("C", bound=Multistreamable) |
117 | | -T = TypeVar("T") |
118 | | - |
119 | | - |
120 | | -# TODO: remove after packaging issue is resolved. |
121 | | -def _set_sharding_context_intra_a2a( |
122 | | - tensors_awaitables: List[Awaitable[KeyedJaggedTensor]], |
123 | | - ctx: C, |
124 | | -) -> None: |
125 | | - for awaitable, sharding_context in zip( |
126 | | - tensors_awaitables, |
127 | | - getattr(ctx, "sharding_contexts", []), |
128 | | - ): |
129 | | - if isinstance(awaitable, KJTAllToAllTensorsAwaitable): |
130 | | - if hasattr(sharding_context, "input_splits"): |
131 | | - sharding_context.input_splits = awaitable._input_splits["values"] |
132 | | - if hasattr(sharding_context, "output_splits"): |
133 | | - sharding_context.output_splits = awaitable._output_splits["values"] |
134 | | - if hasattr(sharding_context, "sparse_features_recat"): |
135 | | - sharding_context.sparse_features_recat = awaitable._recat |
136 | | - if ( |
137 | | - hasattr(sharding_context, "batch_size_per_rank") |
138 | | - and awaitable._stride_per_rank is not None |
139 | | - ): |
140 | | - sharding_context.batch_size_per_rank = awaitable._stride_per_rank |
141 | | - |
142 | | - |
143 | | -# TODO: remove after packaging issue is resolved. |
144 | | -@dataclass |
145 | | -class KJTSplitsAllToAllMeta: |
146 | | - pg: dist.ProcessGroup |
147 | | - _input: KeyedJaggedTensor |
148 | | - splits: List[int] |
149 | | - splits_tensors: List[torch.Tensor] |
150 | | - input_splits: List[List[int]] |
151 | | - input_tensors: List[torch.Tensor] |
152 | | - labels: List[str] |
153 | | - keys: List[str] |
154 | | - device: torch.device |
155 | | - stagger: int |
156 | | - |
157 | | - |
158 | | -# TODO: remove after packaging issue is resolved. |
159 | | -def _split(flat_list: List[T], splits: List[int]) -> List[List[T]]: |
160 | | - return [ |
161 | | - flat_list[sum(splits[:i]) : sum(splits[:i]) + n] for i, n in enumerate(splits) |
162 | | - ] |
163 | | - |
164 | | - |
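For reference, the removed `_split` helper simply partitions a flat list into consecutive chunks whose sizes are given by `splits`; a quick illustration:

```python
# Chunk sizes may be zero; the sizes are consumed left to right.
assert _split(["a", "b", "c", "d", "e"], [2, 0, 3]) == [["a", "b"], [], ["c", "d", "e"]]
```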
165 | | -# TODO: remove after packaging issue is resolved. |
166 | | -class FusedKJTListSplitsAwaitable(Awaitable[List[KJTListAwaitable]]): |
167 | | - def __init__( |
168 | | - self, |
169 | | - requests: List[KJTListSplitsAwaitable[C]], |
170 | | - contexts: List[C], |
171 | | - pg: Optional[dist.ProcessGroup], |
172 | | - ) -> None: |
173 | | - super().__init__() |
174 | | - self._contexts = contexts |
175 | | - self._awaitables: List[ |
176 | | - Union[KJTSplitsAllToAllMeta, Awaitable[Awaitable[KeyedJaggedTensor]]] |
177 | | - ] = [awaitable for request in requests for awaitable in request.awaitables] |
178 | | - self._output_lengths: List[int] = [ |
179 | | - len(request.awaitables) for request in requests |
180 | | - ] |
181 | | - self._lengths: List[int] = [ |
182 | | - ( |
183 | | - len(awaitable.splits_tensors) |
184 | | - if isinstance(awaitable, KJTSplitsAllToAllMeta) |
185 | | - else 0 |
186 | | - ) |
187 | | - for awaitable in self._awaitables |
188 | | - ] |
189 | | - splits_tensors = [ |
190 | | - splits_tensor |
191 | | - for awaitable in self._awaitables |
192 | | - for splits_tensor in ( |
193 | | - awaitable.splits_tensors |
194 | | - if isinstance(awaitable, KJTSplitsAllToAllMeta) |
195 | | - else [] |
196 | | - ) |
197 | | - ] |
198 | | - self._splits_awaitable: Optional[SplitsAllToAllAwaitable] = ( |
199 | | - SplitsAllToAllAwaitable( |
200 | | - input_tensors=splits_tensors, |
201 | | - pg=pg, |
202 | | - ) |
203 | | - if splits_tensors and pg is not None |
204 | | - else None |
205 | | - ) |
206 | | - |
207 | | - def _wait_impl(self) -> List[KJTListAwaitable]: |
208 | | - if self._splits_awaitable: |
209 | | - splits_list = self._splits_awaitable.wait() |
210 | | - splits_per_awaitable = _split(splits_list, self._lengths) |
211 | | - else: |
212 | | - splits_per_awaitable = [[] for _ in range(len(self._lengths))] |
213 | | - tensors_awaitables = [] |
214 | | - for splits, awaitable in zip(splits_per_awaitable, self._awaitables): |
215 | | - if not splits: # NoWait |
216 | | - assert isinstance(awaitable, Awaitable) |
217 | | - tensors_awaitables.append(awaitable.wait()) |
218 | | - continue |
219 | | - assert isinstance(awaitable, KJTSplitsAllToAllMeta) |
220 | | - if awaitable._input.variable_stride_per_key(): |
221 | | - output_splits = splits |
222 | | - stride_per_rank = None |
223 | | - else: |
224 | | - output_splits = splits[:-1] |
225 | | - stride_per_rank = splits[-1] |
226 | | - tensors_awaitables.append( |
227 | | - KJTAllToAllTensorsAwaitable( |
228 | | - pg=awaitable.pg, |
229 | | - input=awaitable._input, |
230 | | - splits=awaitable.splits, |
231 | | - input_splits=awaitable.input_splits, |
232 | | - output_splits=output_splits, |
233 | | - input_tensors=awaitable.input_tensors, |
234 | | - labels=awaitable.labels, |
235 | | - keys=awaitable.keys, |
236 | | - device=awaitable.device, |
237 | | - stagger=awaitable.stagger, |
238 | | - stride_per_rank=stride_per_rank, |
239 | | - ) |
240 | | - ) |
241 | | - output = [] |
242 | | - awaitables_per_output = _split(tensors_awaitables, self._output_lengths) |
243 | | - for awaitables, ctx in zip(awaitables_per_output, self._contexts): |
244 | | - _set_sharding_context_intra_a2a(awaitables, ctx) |
245 | | - output.append(KJTListAwaitable(awaitables, ctx)) |
246 | | - return output |
247 | | - |
248 | | - |
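The deleted class and the `FusedKJTListSplitsAwaitable` now imported from `torchrec.distributed.embedding_sharding` play the same role: they collapse the per-module splits exchanges into a single `all_to_all_single` and, on `wait()`, hand back one `KJTListAwaitable` per original request. A rough usage sketch, with illustrative variable names rather than ones from this file:

```python
# Hypothetical inputs: one KJTListSplitsAwaitable and sharding context per sharded module.
fused = FusedKJTListSplitsAwaitable(
    requests=[module_a_splits_awaitable, module_b_splits_awaitable],
    contexts=[ctx_a, ctx_b],
    pg=process_group,
)
# The fused splits all-to-all is issued in the constructor; wait() dispatches the
# per-request tensors all-to-alls and returns one KJTListAwaitable per request.
kjt_list_awaitables = fused.wait()
```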
249 | 62 | @dataclass |
250 | 63 | class TrainPipelineContext: |
251 | 64 | """ |
@@ -462,6 +275,30 @@ def __call__(self, input: KeyedJaggedTensor) -> KJTSplitsAllToAllMeta: |
462 | 275 | ) |
463 | 276 |
|
464 | 277 |
|
| 278 | +class Tracer(torch.fx.Tracer): |
| 279 | + """ |
| 280 | + Disables proxying buffers during tracing. Ideally, proxying would stay enabled,
| 281 | + but some models currently mutate buffer values, which causes errors during
| 282 | + tracing when buffers are proxied. If those models can be rewritten to avoid
| 283 | + that, we can likely remove this override.
| 284 | + """ |
| 285 | + |
| 286 | + proxy_buffer_attributes = False |
| 287 | + |
| 288 | + def __init__(self, leaf_modules: Optional[List[str]] = None) -> None: |
| 289 | + super().__init__() |
| 290 | + self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else [] |
| 291 | + |
| 292 | + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: |
| 293 | + if ( |
| 294 | + isinstance(m, ShardedModule) |
| 295 | + or module_qualified_name in self._leaf_modules |
| 296 | + or isinstance(m, FSDP) |
| 297 | + ): |
| 298 | + return True |
| 299 | + return super().is_leaf_module(m, module_qualified_name) |
| 300 | + |
| 301 | + |
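A minimal sketch of how this Tracer is typically driven (the model variable and leaf-module name are illustrative): sharded modules, FSDP wrappers, and any explicitly listed qualified names are kept as opaque call_module nodes in the traced graph.

```python
# Hypothetical model; "sparse_preproc" stands in for a module path to leave untraced.
tracer = Tracer(leaf_modules=["sparse_preproc"])
graph = tracer.trace(model)
traced = torch.fx.GraphModule(model, graph)
```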
465 | 302 | def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In: |
466 | 303 | assert isinstance( |
467 | 304 | batch, (torch.Tensor, Pipelineable) |
|