
Commit 0553999

joshuadeng authored and facebook-github-bot committed
remove duplicate modules from train pipeline and import from embedding_sharding (#1754)
Summary:
Pull Request resolved: #1754

Due to packaging issues with the train pipeline, we previously needed to copy modules directly into train_pipeline. Since those changes have now propagated through packaging, we can safely remove the duplicate modules/functions and import them from embedding_sharding instead.

Reviewed By: zainhuda

Differential Revision: D54497596

fbshipit-source-id: a6f532de653469dc4ed52a89ccda8e76533e4929
1 parent eb9c6d8 commit 0553999
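As a rough sketch of the resulting import pattern in train_pipeline/utils.py, assuming a TorchRec build where packaging already exposes these symbols (the situation the summary above describes):

# Sketch only: the import surface utils.py relies on after this commit,
# rather than keeping private copies of these classes.
from torchrec.distributed.dist_data import KJTAllToAll
from torchrec.distributed.embedding_sharding import (
    FusedKJTListSplitsAwaitable,
    KJTListSplitsAwaitable,
    KJTSplitsAllToAllMeta,
)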

File tree

2 files changed: +27 −195 lines changed


torchrec/distributed/train_pipeline/train_pipeline.py

Lines changed: 0 additions & 5 deletions
@@ -7,11 +7,6 @@
 
 # pyre-strict
 
-"""
-NOTE: Due to an internal packaging issue, `train_pipeline.py` must be compatible with
-older versions of TorchRec. Importing new modules from other files may break model
-publishing flows.
-"""
 import abc
 import logging
 from typing import cast, Generic, Iterator, List, Optional, Tuple

torchrec/distributed/train_pipeline/utils.py

Lines changed: 27 additions & 190 deletions
@@ -7,7 +7,6 @@
 
 # pyre-strict
 
-#!/usr/bin/env python3
 import copy
 import itertools
 import logging
@@ -34,10 +33,11 @@
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.fx.node import Node
 from torch.profiler import record_function
-from torchrec.distributed.dist_data import KJTAllToAll, KJTAllToAllTensorsAwaitable
+from torchrec.distributed.dist_data import KJTAllToAll
 from torchrec.distributed.embedding_sharding import (
-    KJTListAwaitable,
+    FusedKJTListSplitsAwaitable,
     KJTListSplitsAwaitable,
+    KJTSplitsAllToAllMeta,
 )
 from torchrec.distributed.model_parallel import DistributedModelParallel, ShardedModule
 
@@ -59,193 +59,6 @@
 StageOutputWithEvent = Tuple[Optional[StageOut], Optional[torch.cuda.Event]]
 
 
-class Tracer(torch.fx.Tracer):
-    """
-    Disables proxying buffers during tracing. Ideally, proxying buffers would be
-    disabled, but some models are currently mutating buffer values, which causes errors
-    during tracing. If those models can be rewritten to not do that, we can likely
-    remove this line.
-    """
-
-    proxy_buffer_attributes = False
-
-    def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
-        super().__init__()
-        self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else []
-
-    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
-        if (
-            isinstance(m, ShardedModule)
-            or module_qualified_name in self._leaf_modules
-            or isinstance(m, FSDP)
-        ):
-            return True
-        return super().is_leaf_module(m, module_qualified_name)
-
-
-# TODO: remove after packaging issue is resolved.
-class SplitsAllToAllAwaitable(Awaitable[List[List[int]]]):
-    def __init__(
-        self,
-        input_tensors: List[torch.Tensor],
-        pg: dist.ProcessGroup,
-    ) -> None:
-        super().__init__()
-        self.num_workers: int = pg.size()
-
-        with record_function("## all2all_data:kjt splits ##"):
-            self._output_tensor: torch.Tensor = torch.empty(
-                [self.num_workers * len(input_tensors)],
-                device=input_tensors[0].device,
-                dtype=input_tensors[0].dtype,
-            )
-            input_tensor = torch.stack(input_tensors, dim=1).flatten()
-            self._splits_awaitable: dist.Work = dist.all_to_all_single(
-                output=self._output_tensor,
-                input=input_tensor,
-                group=pg,
-                async_op=True,
-            )
-
-    def _wait_impl(self) -> List[List[int]]:
-        self._splits_awaitable.wait()
-        return self._output_tensor.view(self.num_workers, -1).T.tolist()
-
-
-# TODO: remove after packaging issue is resolved.
-C = TypeVar("C", bound=Multistreamable)
-T = TypeVar("T")
-
-
-# TODO: remove after packaging issue is resolved.
-def _set_sharding_context_intra_a2a(
-    tensors_awaitables: List[Awaitable[KeyedJaggedTensor]],
-    ctx: C,
-) -> None:
-    for awaitable, sharding_context in zip(
-        tensors_awaitables,
-        getattr(ctx, "sharding_contexts", []),
-    ):
-        if isinstance(awaitable, KJTAllToAllTensorsAwaitable):
-            if hasattr(sharding_context, "input_splits"):
-                sharding_context.input_splits = awaitable._input_splits["values"]
-            if hasattr(sharding_context, "output_splits"):
-                sharding_context.output_splits = awaitable._output_splits["values"]
-            if hasattr(sharding_context, "sparse_features_recat"):
-                sharding_context.sparse_features_recat = awaitable._recat
-            if (
-                hasattr(sharding_context, "batch_size_per_rank")
-                and awaitable._stride_per_rank is not None
-            ):
-                sharding_context.batch_size_per_rank = awaitable._stride_per_rank
-
-
-# TODO: remove after packaging issue is resolved.
-@dataclass
-class KJTSplitsAllToAllMeta:
-    pg: dist.ProcessGroup
-    _input: KeyedJaggedTensor
-    splits: List[int]
-    splits_tensors: List[torch.Tensor]
-    input_splits: List[List[int]]
-    input_tensors: List[torch.Tensor]
-    labels: List[str]
-    keys: List[str]
-    device: torch.device
-    stagger: int
-
-
-# TODO: remove after packaging issue is resolved.
-def _split(flat_list: List[T], splits: List[int]) -> List[List[T]]:
-    return [
-        flat_list[sum(splits[:i]) : sum(splits[:i]) + n] for i, n in enumerate(splits)
-    ]
-
-
-# TODO: remove after packaging issue is resolved.
-class FusedKJTListSplitsAwaitable(Awaitable[List[KJTListAwaitable]]):
-    def __init__(
-        self,
-        requests: List[KJTListSplitsAwaitable[C]],
-        contexts: List[C],
-        pg: Optional[dist.ProcessGroup],
-    ) -> None:
-        super().__init__()
-        self._contexts = contexts
-        self._awaitables: List[
-            Union[KJTSplitsAllToAllMeta, Awaitable[Awaitable[KeyedJaggedTensor]]]
-        ] = [awaitable for request in requests for awaitable in request.awaitables]
-        self._output_lengths: List[int] = [
-            len(request.awaitables) for request in requests
-        ]
-        self._lengths: List[int] = [
-            (
-                len(awaitable.splits_tensors)
-                if isinstance(awaitable, KJTSplitsAllToAllMeta)
-                else 0
-            )
-            for awaitable in self._awaitables
-        ]
-        splits_tensors = [
-            splits_tensor
-            for awaitable in self._awaitables
-            for splits_tensor in (
-                awaitable.splits_tensors
-                if isinstance(awaitable, KJTSplitsAllToAllMeta)
-                else []
-            )
-        ]
-        self._splits_awaitable: Optional[SplitsAllToAllAwaitable] = (
-            SplitsAllToAllAwaitable(
-                input_tensors=splits_tensors,
-                pg=pg,
-            )
-            if splits_tensors and pg is not None
-            else None
-        )
-
-    def _wait_impl(self) -> List[KJTListAwaitable]:
-        if self._splits_awaitable:
-            splits_list = self._splits_awaitable.wait()
-            splits_per_awaitable = _split(splits_list, self._lengths)
-        else:
-            splits_per_awaitable = [[] for _ in range(len(self._lengths))]
-        tensors_awaitables = []
-        for splits, awaitable in zip(splits_per_awaitable, self._awaitables):
-            if not splits:  # NoWait
-                assert isinstance(awaitable, Awaitable)
-                tensors_awaitables.append(awaitable.wait())
-                continue
-            assert isinstance(awaitable, KJTSplitsAllToAllMeta)
-            if awaitable._input.variable_stride_per_key():
-                output_splits = splits
-                stride_per_rank = None
-            else:
-                output_splits = splits[:-1]
-                stride_per_rank = splits[-1]
-            tensors_awaitables.append(
-                KJTAllToAllTensorsAwaitable(
-                    pg=awaitable.pg,
-                    input=awaitable._input,
-                    splits=awaitable.splits,
-                    input_splits=awaitable.input_splits,
-                    output_splits=output_splits,
-                    input_tensors=awaitable.input_tensors,
-                    labels=awaitable.labels,
-                    keys=awaitable.keys,
-                    device=awaitable.device,
-                    stagger=awaitable.stagger,
-                    stride_per_rank=stride_per_rank,
-                )
-            )
-        output = []
-        awaitables_per_output = _split(tensors_awaitables, self._output_lengths)
-        for awaitables, ctx in zip(awaitables_per_output, self._contexts):
-            _set_sharding_context_intra_a2a(awaitables, ctx)
-            output.append(KJTListAwaitable(awaitables, ctx))
-        return output
-
-
 @dataclass
 class TrainPipelineContext:
     """
@@ -462,6 +275,30 @@ def __call__(self, input: KeyedJaggedTensor) -> KJTSplitsAllToAllMeta:
         )
 
 
+class Tracer(torch.fx.Tracer):
+    """
+    Disables proxying buffers during tracing. Ideally, proxying buffers would be
+    disabled, but some models are currently mutating buffer values, which causes errors
+    during tracing. If those models can be rewritten to not do that, we can likely
+    remove this line.
+    """
+
+    proxy_buffer_attributes = False
+
+    def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
+        super().__init__()
+        self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else []
+
+    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
+        if (
+            isinstance(m, ShardedModule)
+            or module_qualified_name in self._leaf_modules
+            or isinstance(m, FSDP)
+        ):
+            return True
+        return super().is_leaf_module(m, module_qualified_name)
+
+
 def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
     assert isinstance(
         batch, (torch.Tensor, Pipelineable)
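The Tracer class relocated into utils.py above is a torch.fx.Tracer that skips buffer proxying and treats ShardedModule instances, FSDP wrappers, and explicitly named modules as leaves. A minimal usage sketch, not part of the commit; the model and the leaf-module name are hypothetical:

import torch
import torch.fx

# Import path per the file layout touched by this commit.
from torchrec.distributed.train_pipeline.utils import Tracer

# Any nn.Module works here; "sparse_arch" is a made-up qualified module name
# the tracer should leave un-traced if it existed in the model.
model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU())
tracer = Tracer(leaf_modules=["sparse_arch"])
graph = tracer.trace(model)
gm = torch.fx.GraphModule(model, graph)
print(gm.graph)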

0 commit comments
