91 changes: 59 additions & 32 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -10,6 +10,7 @@
import abc
import contextlib
import logging
import threading
from collections import deque
from dataclasses import dataclass
from typing import (
@@ -469,6 +470,9 @@ def __init__(
dmp_collection_sync_interval_batches: Optional[int] = 1,
enqueue_batch_after_forward: bool = False,
inplace_copy_batch_to_gpu: bool = False,
# To overcome a host bottleneck, enable the 'pipeline_thread' option for pipeline overlapping.
# Enabling this option starts an additional helper thread.
pipeline_thread: bool = False,
) -> None:
self._model = model
self._optimizer = optimizer
@@ -549,6 +553,17 @@ def __init__(
self._batch_ip2: Optional[In] = None
self._context: TrainPipelineContext = context_type(version=0)

# Parallel pipeline: a daemon helper thread prepares the next batch while the
# main thread runs the backward pass and optimizer step. The two events form
# a go/done handshake between the threads.
self.pipeline_thread = pipeline_thread
if self.pipeline_thread:
    self.helper_thread = threading.Thread(
        target=self.progress_helper, daemon=True
    )
    self.helper_go = threading.Event()
    self.helper_done = threading.Event()
    self._cur_dliter: Optional[Iterator[In]] = None
    self.helper_thread.start()
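
The two events above form a simple go/done handshake between the main thread
and the helper thread. As a minimal standalone sketch of the same pattern
(do_prep_work is an illustrative stand-in for pipeline_prepare, not a name
from this PR):

import threading

go = threading.Event()
done = threading.Event()

def do_prep_work() -> None:
    pass  # stand-in for pipeline_prepare(): input_dist comms, next H2D copy

def worker() -> None:
    while True:
        go.wait()       # block until the main thread signals
        go.clear()      # consume the signal
        do_prep_work()  # overlapped work
        done.set()      # tell the main thread the work is finished

threading.Thread(target=worker, daemon=True).start()

# main thread, once per iteration:
go.set()     # kick off the helper (right after forward)
# ... backward pass and optimizer step run here, overlapped with the helper ...
done.wait()  # join point before touching shared pipeline state
done.clear()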

def detach(self) -> torch.nn.Module:
"""
Detaches the model from sparse data dist (SDD) pipeline. A user might want to get
@@ -692,6 +707,30 @@ def _backward(self, losses: torch.Tensor) -> None:
with record_function(f"## backward {batch_id} ##"):
torch.sum(losses, dim=0).backward()

def progress_helper(self) -> None:
    """Helper-thread loop: on each go signal, prepare the next batch, then signal done."""
    while True:
        self.helper_go.wait()
        self.helper_go.clear()
        # preprocess the next context (input_dist comms, next H2D copy)
        self.pipeline_prepare()
        self.helper_done.set()

def pipeline_prepare(self) -> bool:
    """Prepare upcoming batches; returns False once the dataloader is exhausted."""
    self.fill_pipeline(self._cur_dliter)
    if not self.batches:
        return False
    # wait for batches[0] to be available on device; its input_dist was
    # invoked in the previous iteration
    self._wait_for_batch()
    if len(self.batches) >= 2:
        # invoke splits all_to_all comms (first part of input_dist)
        self.start_sparse_data_dist(self.batches[1], self.contexts[1])
    self.enqueue_batch(self._cur_dliter)

    if len(self.batches) >= 2:
        # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
        self.wait_sparse_data_dist(self.contexts[1])
    return True
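
Read together with progress() below, the per-iteration flow with
pipeline_thread=True reduces to roughly the following. This is a simplified
paraphrase of the diff for review purposes, not the literal implementation
(profiling scopes, zero_grad, and training-only branches are omitted):

def progress_sketch(pipeline, dataloader_iter):
    pipeline._cur_dliter = dataloader_iter
    # first iteration, or helper disabled: prepare inline on the main thread
    if len(pipeline.contexts) == 0 or not pipeline.pipeline_thread:
        if not pipeline.pipeline_prepare():
            raise StopIteration
    cur_batch = pipeline.batches.popleft()
    cur_context = pipeline.contexts.popleft()
    pipeline._set_module_context(cur_context)
    losses, output = pipeline._model_fwd(cur_batch)
    if pipeline.pipeline_thread:
        pipeline.helper_go.set()     # helper prepares the next batch ...
    pipeline._backward(losses)       # ... while backward and the
    pipeline._optimizer.step()       # optimizer step run on the main thread
    if pipeline.pipeline_thread:
        pipeline.helper_done.wait()  # re-join before the next iteration
        pipeline.helper_done.clear()
        if len(pipeline.batches) >= 1:
            pipeline._set_module_context(pipeline.contexts[0])
    return output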

def progress(self, dataloader_iter: Iterator[In]) -> Out:
"""
For TrainPipelineSparseDist, we assume the max pipelined batches == 3 (capacity):
@@ -706,46 +745,29 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
if not self._model_attached:
self.attach(self._model)

# filling the pipeline is only needed at the beginning, when the pipeline (batches) is empty
self.fill_pipeline(dataloader_iter)

# here is the expected stop after exhausting all batches
if not self.batches:
    raise StopIteration
self._cur_dliter = dataloader_iter
# get the first context on the main thread (first iteration, or whenever the
# helper thread is disabled)
if len(self.contexts) == 0 or not self.pipeline_thread:
    if not self.pipeline_prepare():
        raise StopIteration

# TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
self._set_module_context(self.contexts[0])
cur_batch = self.batches.popleft()
cur_context = self.contexts.popleft()
self._set_module_context(cur_context)

if self._model.training:
with record_function("## zero_grad ##"):
self._optimizer.zero_grad()

# wait for batches[0] to be available on device; this should always be complete since
# the input_dist of batches[0] has been invoked in the previous iter. TODO: fact check
self._wait_for_batch()

if len(self.batches) >= 2:
# invoke splits all_to_all comms (first part of input_dist)
self.start_sparse_data_dist(self.batches[1], self.contexts[1])

if not self._enqueue_batch_after_forward:
    # batch i+2: load data and copy to GPU; the dataloader iter will be exhausted here first
    self.enqueue_batch(dataloader_iter)

# forward
with record_function(f"## forward {self.contexts[0].index} ##"):
with record_function(f"## forward {cur_context.index} ##"):
self._state = PipelineState.CALL_FWD
losses, output = self._model_fwd(self.batches[0])

if self._enqueue_batch_after_forward:
    # batch i+2: load data and copy to GPU; the dataloader iter will be exhausted here first.
    # Start this step after the forward of batch i, so that the H2D copy doesn't compete
    # for PCIe bandwidth with the embedding lookup from UVM/UVM_CACHING.
    self.enqueue_batch(dataloader_iter)
losses, output = self._model_fwd(cur_batch)

if len(self.batches) >= 2:
# invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
self.wait_sparse_data_dist(self.contexts[1])
# signal the helper thread once the forward pass has been issued
if self.pipeline_thread:
    self.helper_go.set()
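# From here the helper thread runs pipeline_prepare() (input_dist comms and
# the next H2D copy) concurrently with the backward pass and optimizer step
# below; the main thread re-joins via helper_done before touching
# self.batches / self.contexts again.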

if self._model.training:
# backward
@@ -755,14 +777,19 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
self.sync_embeddings(
self._model,
self._dmp_collection_sync_interval_batches,
self.contexts[0],
cur_context,
)

# update
with record_function(f"## optimizer {self.contexts[0].index} ##"):
self._optimizer.step()

self.dequeue_batch()
if self.pipeline_thread:
    self.helper_done.wait()
    self.helper_done.clear()
    # update the PipelinedForward context to match the next forward pass
    if len(self.batches) >= 1:
        self._set_module_context(self.contexts[0])
return output

def _create_context(self) -> TrainPipelineContext:
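
For reviewers, a minimal usage sketch of the new flag, assuming the existing
TrainPipelineSparseDist constructor signature; the model, optimizer, device,
and dataloader are assumed to come from the usual TorchRec sharded-model
setup and are not shown here:

from torchrec.distributed.train_pipeline import TrainPipelineSparseDist

pipeline = TrainPipelineSparseDist(
    model=model,           # sharded TorchRec model (assumed, not shown)
    optimizer=optimizer,   # its optimizer (assumed, not shown)
    device=device,         # target device (assumed, not shown)
    pipeline_thread=True,  # new in this PR: prep overlaps backward/optimizer
)

dataloader_iter = iter(dataloader)  # dataloader assumed, not shown
while True:
    try:
        output = pipeline.progress(dataloader_iter)
    except StopIteration:
        break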