
split train_pipeline.utils - forward #2995

Open · wants to merge 1 commit into base: main

301 changes: 301 additions & 0 deletions torchrec/distributed/train_pipeline/runtime_forwards.py
@@ -0,0 +1,301 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import itertools
import logging
from typing import Dict, Generic, Iterable, List, Optional, Tuple, TypeVar, Union

import torch
from torch import distributed as dist
from torch.profiler import record_function

from torchrec.distributed.embedding_sharding import KJTSplitsAllToAllMeta
from torchrec.distributed.model_parallel import ShardedModule
from torchrec.distributed.train_pipeline.pipeline_context import (
EmbeddingTrainPipelineContext,
PrefetchTrainPipelineContext,
TrainPipelineContext,
)
from torchrec.distributed.train_pipeline.types import CallArgs
from torchrec.distributed.types import Awaitable, LazyNoWait
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
from torchrec.streamable import Multistreamable

logger: logging.Logger = logging.getLogger(__name__)

TForwardContext = TypeVar("TForwardContext", bound=TrainPipelineContext)

EmbeddingModuleRetType = Union[Dict[str, JaggedTensor], KeyedTensor]


class BaseForward(Generic[TForwardContext]):
def __init__(
self,
name: str,
args: CallArgs,
module: ShardedModule,
context: TForwardContext,
stream: Optional[torch.Stream] = None,
) -> None:
self._name = name
self._args = args
self._module = module
self._context = context
self._stream = stream
self._device: torch.device = stream.device if stream else torch.device("cuda")

@property
def name(self) -> str:
return self._name

@property
def args(self) -> CallArgs:
return self._args

def set_context(self, context: TForwardContext) -> None:
self._context = context

def get_context(self) -> TForwardContext:
return self._context


class PipelinedForward(BaseForward[TrainPipelineContext]):
"""
This pipeline is used in TrainPipelineSparseDist
"""

# pyre-ignore [2, 24]
def __call__(self, *input, **kwargs) -> Awaitable:
assert (
self._name in self._context.input_dist_tensors_requests
), "Invalid PipelinedForward usage, please do not directly call model.forward()"
request = self._context.input_dist_tensors_requests.pop(self._name)
assert isinstance(request, Awaitable)
with record_function("## wait_sparse_data_dist ##"):
# Finish waiting on the dist_stream,
# in case some delayed stream scheduling happens during the wait() call.
with torch.get_device_module(self._device).stream(self._stream):
data = request.wait()

# Make sure that both result of input_dist and context
# are properly transferred to the current stream.
ctx = self._context.module_contexts.pop(self._name)

if self._stream is not None:
torch.get_device_module(self._device).current_stream().wait_stream(
self._stream
)
cur_stream = torch.get_device_module(self._device).current_stream()

assert isinstance(
data, (torch.Tensor, Multistreamable)
), f"{type(data)} must implement Multistreamable interface"
data.record_stream(cur_stream)
ctx.record_stream(cur_stream)

return self._module.compute_and_output_dist(ctx, data)


class EmbeddingPipelinedForward(BaseForward[EmbeddingTrainPipelineContext]):
"""
This pipeline is used in TrainPipelineSemiSync
"""

def __call__(
self,
# pyre-ignore
*input,
# pyre-ignore
**kwargs,
) -> Union[
Awaitable[EmbeddingModuleRetType],
Tuple[
Awaitable[EmbeddingModuleRetType], Awaitable[Optional[KeyedJaggedTensor]]
],
]:
assert (
self._name in self._context.embedding_a2a_requests
), "Invalid EmbeddingPipelinedForward usage, please do not directly call model.forward()"

ctx = self._context.module_contexts.pop(self._name)
cur_stream = torch.get_device_module(self._device).current_stream()

if self._stream is not None:
torch.get_device_module(self._device).current_stream().wait_stream(
self._stream
)
ctx.record_stream(cur_stream)

awaitable = self._context.embedding_a2a_requests.pop(self._name)
# in case of MC modules
is_mc_module: bool = isinstance(awaitable, Iterable)
remapped_kjts: Optional[KeyedJaggedTensor] = None

if is_mc_module:
embeddings = awaitable[0].wait()
remapped_kjts = awaitable[1].wait()
else:
assert isinstance(awaitable, Awaitable)
embeddings = (
awaitable.wait()
) # trigger awaitable manually for type checking

self.detach_embeddings(embeddings=embeddings, cur_stream=cur_stream)

if is_mc_module:
return (LazyNoWait(embeddings), LazyNoWait(remapped_kjts))
else:
return LazyNoWait(embeddings)

def detach_embeddings(
self,
embeddings: Union[Dict[str, JaggedTensor], KeyedTensor],
cur_stream: torch.Stream,
) -> None:
"""
detach the grad from embeddings so that the backward/opt of the embeddings
won't be invoked by loss.backward(). Instead, there is a dedicated embedding_backward
call in semi-sync pipeline progress.
"""
tensors = []
detached_tensors = []
# in case of EC, embeddings are Dict[str, JaggedTensor]
if isinstance(embeddings, Dict):
for jt in embeddings.values():
assert isinstance(jt, JaggedTensor)
tensor = jt.values()
detached_tensor = tensor.detach().requires_grad_()
detached_tensor.retain_grad()
jt._values = detached_tensor
tensors.append(tensor)
detached_tensors.append(detached_tensor)
self._context.embedding_tensors.append(tensors)
self._context.embedding_features.append(list(embeddings.keys()))
self._context.detached_embedding_tensors.append(detached_tensors)
else:
# in case of EBC, embeddings are KeyedTensor
assert isinstance(embeddings, KeyedTensor)
embeddings.record_stream(cur_stream)
tensor = embeddings.values()
detached_tensor = tensor.detach().requires_grad_()
detached_tensor.retain_grad()
embeddings._values = detached_tensor
tensors.append(tensor)
detached_tensors.append(detached_tensor)
self._context.embedding_tensors.append(tensors)
"""
KeyedTensor is returned by EmbeddingBagCollections and its variants
KeyedTensor holds dense data from multiple features and .values()
returns a single concatenated dense tensor. To ensure that
context.embedding_tensors[i] has the same length as
context.embedding_features[i], we pass in a list with a single item:
a list containing all the embedding feature names.
"""
self._context.embedding_features.append([list(embeddings.keys())])
self._context.detached_embedding_tensors.append(detached_tensors)


class InSyncEmbeddingPipelinedForward(EmbeddingPipelinedForward):
"""
This pipeline is used in TrainPipelineFusedSparseDist
"""

def detach_embeddings(
self,
embeddings: Union[Dict[str, JaggedTensor], KeyedTensor],
cur_stream: torch.Stream,
) -> None:
# doing nothing
pass


class PrefetchPipelinedForward(BaseForward[PrefetchTrainPipelineContext]):
"""
This pipeline is used in PrefetchTrainPipelineSparseDist
"""

def __init__(
self,
name: str,
args: CallArgs,
module: ShardedModule,
context: PrefetchTrainPipelineContext,
prefetch_stream: Optional[torch.Stream] = None,
) -> None:
super().__init__(
name=name,
args=args,
module=module,
context=context,
stream=prefetch_stream,
)

# pyre-ignore [2, 24]
def __call__(self, *input, **kwargs) -> Awaitable:
assert (
self._name in self._context.module_input_post_prefetch
), "Invalid PrefetchPipelinedForward usage, please do not directly call model.forward()"
data = self._context.module_input_post_prefetch.pop(self._name)
ctx = self._context.module_contexts_post_prefetch.pop(self._name)

# Make sure that both result of input_dist and context
# are properly transferred to the current stream.
if self._stream is not None:
torch.get_device_module(self._device).current_stream().wait_stream(
self._stream
)
cur_stream = torch.get_device_module(self._device).current_stream()

assert isinstance(
data, (torch.Tensor, Multistreamable)
), f"{type(data)} must implement Multistreamable interface"
data.record_stream(cur_stream)

ctx.record_stream(cur_stream)

return self._module.compute_and_output_dist(ctx, data)


class KJTAllToAllForward:
def __init__(
self, pg: dist.ProcessGroup, splits: List[int], stagger: int = 1
) -> None:
self._pg = pg
self._splits = splits
self._stagger = stagger
self._splits_cumsum: List[int] = [0] + list(itertools.accumulate(splits))

def __call__(self, input: KeyedJaggedTensor) -> KJTSplitsAllToAllMeta:
with torch.no_grad():
assert len(input.keys()) == sum(self._splits)
rank = dist.get_rank(self._pg)
local_keys = input.keys()[
self._splits_cumsum[rank] : self._splits_cumsum[rank + 1]
]
input_splits = input.dist_splits(self._splits)
device = input.values().device
splits_tensors = [
torch.tensor(splits, device=device) for splits in input_splits
]
if not input.variable_stride_per_key():
splits_tensors.append(
torch.tensor([input.stride()] * self._pg.size(), device=device)
)
return KJTSplitsAllToAllMeta(
pg=self._pg,
_input=input,
splits=self._splits,
splits_tensors=splits_tensors,
input_splits=input_splits,
input_tensors=input.dist_tensors(),
labels=input.dist_labels(),
keys=local_keys,
device=device,
stagger=self._stagger,
)
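
For readers outside the diff: the classes above implement a common pipelining trick, replacing a sharded module's forward with a callable that consumes an input-dist result started for this batch in an earlier pipeline stage. Below is a minimal, framework-free sketch of that pattern; every name in it (ToyContext, ToyShardedModule, ToyPipelinedForward) is an illustrative stand-in, not a TorchRec API.

# Minimal sketch of the "pipelined forward" pattern (illustrative only).
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ToyContext:
    # Input-dist results pre-computed for the current batch, keyed by module name.
    input_dist_requests: Dict[str, Any] = field(default_factory=dict)


class ToyShardedModule:
    def compute(self, data: Any) -> str:
        return f"embeddings({data})"


class ToyPipelinedForward:
    """Swapped in for module.forward; pops the pre-started input-dist result."""

    def __init__(self, name: str, module: ToyShardedModule, context: ToyContext) -> None:
        self._name = name
        self._module = module
        self._context = context

    def __call__(self, *args: Any, **kwargs: Any) -> str:
        assert (
            self._name in self._context.input_dist_requests
        ), "input dist was not started for this batch; do not call forward() directly"
        data = self._context.input_dist_requests.pop(self._name)
        return self._module.compute(data)


# Usage: an earlier pipeline stage "starts the input dist" for module "ebc",
# then the model's forward (now a ToyPipelinedForward) consumes it.
ctx = ToyContext()
ctx.input_dist_requests["ebc"] = "kjt_batch_0"
forward = ToyPipelinedForward("ebc", ToyShardedModule(), ctx)
print(forward())  # -> embeddings(kjt_batch_0)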
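
Also for context, not part of the diff: detach_embeddings above relies on the detach().requires_grad_() plus retain_grad() idiom so that loss.backward() stops at the detached copy and the gradient can later be pushed through the embedding graph in a separate step. A self-contained plain-PyTorch illustration of that idiom, with toy tensors standing in for the embedding output:

import torch

# Toy stand-ins: `weight` plays the role of an embedding table, `embedding`
# the (non-leaf) output of the sharded embedding module.
weight = torch.randn(4, 8, requires_grad=True)
embedding = weight * 3.0

# Cut the graph so loss.backward() stops here, but keep a grad slot on the
# detached copy to capture the incoming gradient.
detached = embedding.detach().requires_grad_()
detached.retain_grad()

loss = (detached * 2.0).sum()
loss.backward()  # fills detached.grad; weight.grad stays None

assert weight.grad is None
assert detached.grad is not None

# A dedicated "embedding backward" step can later push the captured gradient
# through the embedding graph, analogous to the semi-sync pipeline's separate
# embedding backward pass.
embedding.backward(detached.grad)
assert weight.grad is not None
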
40 changes: 19 additions & 21 deletions torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -49,6 +49,10 @@
)
from torchrec.distributed.train_pipeline.pipeline_context import TrainPipelineContext
from torchrec.distributed.train_pipeline.postproc import PipelinedPostproc
from torchrec.distributed.train_pipeline.runtime_forwards import (
EmbeddingPipelinedForward,
PipelinedForward,
)
from torchrec.distributed.train_pipeline.tests.test_train_pipelines_base import (
TrainPipelineSparseDistTestBase,
)
@@ -68,11 +72,10 @@
TrainPipelineSparseDist,
TrainPipelineSparseDistCompAutograd,
)
from torchrec.distributed.train_pipeline.types import CallArgs
from torchrec.distributed.train_pipeline.utils import (
DataLoadingThread,
EmbeddingPipelinedForward,
get_h2d_func,
PipelinedForward,
PipelineStage,
SparseDataDistUtil,
StageOut,
@@ -1284,25 +1287,20 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None:
# postproc args
self.assertEqual(len(pipeline._pipelined_postprocs), 2)
# postprocs can be added in any order, so we can't assert on exact steps structures
self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.args), 1)
self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.kwargs), 0)
self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.args[0].steps), 2)
self.assertEqual(
pipeline._pipelined_postprocs[0]._args.args[0].steps[0], NoopArgInfoStep()
)
self.assertIsInstance(
pipeline._pipelined_postprocs[0]._args.args[0].steps[1], GetAttrArgInfoStep
)

self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.args), 1)
self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.kwargs), 0)
self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.args[0].steps), 2)
self.assertEqual(
pipeline._pipelined_postprocs[1]._args.args[0].steps[0], NoopArgInfoStep()
)
self.assertIsInstance(
pipeline._pipelined_postprocs[1]._args.args[0].steps[1], GetAttrArgInfoStep
)
# TODO: find way not to inspect private parts
postproc1_args: CallArgs = pipeline._pipelined_postprocs[0]._args
self.assertEqual(len(postproc1_args.args), 1)
self.assertEqual(len(postproc1_args.kwargs), 0)
self.assertEqual(len(postproc1_args.args[0].steps), 2)
self.assertEqual(postproc1_args.args[0].steps[0], NoopArgInfoStep())
self.assertIsInstance(postproc1_args.args[0].steps[1], GetAttrArgInfoStep)

postproc2_args: CallArgs = pipeline._pipelined_postprocs[1]._args
self.assertEqual(len(postproc2_args.args), 1)
self.assertEqual(len(postproc2_args.kwargs), 0)
self.assertEqual(len(postproc2_args.args[0].steps), 2)
self.assertEqual(postproc2_args.args[0].steps[0], NoopArgInfoStep())
self.assertIsInstance(postproc2_args.args[0].steps[1], GetAttrArgInfoStep)

get_arg_infos = {
# pyre-fixme[16]: assertions above ensure that steps[1] is a GetAttrArgInfoStep
@@ -19,6 +19,7 @@
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.test_utils.test_model import ModelInput, TestNegSamplingModule
from torchrec.distributed.train_pipeline.pipeline_context import TrainPipelineContext
from torchrec.distributed.train_pipeline.runtime_forwards import PipelinedForward

from torchrec.distributed.train_pipeline.tests.test_train_pipelines_base import (
TrainPipelineSparseDistTestBase,
@@ -30,7 +31,7 @@
NodeArgsHelper,
PipelinedPostproc,
)
from torchrec.distributed.train_pipeline.utils import _rewrite_model, PipelinedForward
from torchrec.distributed.train_pipeline.utils import _rewrite_model
from torchrec.distributed.types import ShardingType
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
