enable backward pass computation and communication overlap by prefetching all gather (pytorch#70235)

Summary:
Pull Request resolved: pytorch#70235

Addresses comments in pytorch#69282:
Fixed a few corner cases for prefetching full parameters in the post-backward hook.

After benchmarking, prefetching full parameters in the pre-backward hook has the best and most stable performance, at the cost of increased memory; prefetching full parameters in the post-backward hook did not deliver the expected performance and also failed in a few corner cases (now fixed), although it does not increase memory. The main issue is that the firing order of post-backward hooks is not consistently the reverse of the forward computation order, so an incorrectly prefetched all gather can delay the all gather that is actually needed in the single NCCL stream and stall some layers' computation.

So for now both approaches are exposed as configurable, experimental algorithms.

Prefetching full parameters in the pre-backward hook:

Past traces show that all gather ops are not triggered until the current layer's backward computation starts, and that for some models the previous layer's reduce scatter is scheduled before the next layer's all gather. Since all gather and reduce scatter share the same NCCL stream, the backward pass can end up with no overlap between communication and computation.

To explicitly schedule the next layer's all gather while the previous layer's backward computation is running, we prefetch the next layer's full parameters. This 1) deterministically overlaps both all gather and reduce scatter with computation, and 2) prefetches only one layer's full parameters at a time, to avoid increasing memory too much.

The implementation borrows the idea from facebookresearch/fairscale#865, where the forward graph order is recorded during the forward pass.

In the backward pass, this PR prefetches the all gather of full parameters in the current layer's pre-backward hook, instead of in the current layer's post-backward hook as in facebookresearch/fairscale#865. It also makes sure the all gather streams are synced properly.
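
Roughly, the mechanism looks like the sketch below. This is illustrative only, not the actual FSDP internals; the PrefetchCoordinator class and the rebuild_full_params helper are hypothetical names.

    # Illustrative sketch of the prefetching idea; class and method names here
    # are hypothetical and do not mirror the real FSDP implementation.
    class PrefetchCoordinator:
        def __init__(self):
            # Wrapped modules, in the order their forward passes ran.
            self.forward_order = []

        def record_forward(self, module):
            # Called at the end of each wrapped module's forward pass.
            self.forward_order.append(module)

        def prefetch_in_pre_backward(self, module):
            # Called from the current module's pre-backward hook: kick off the
            # all gather for the module that runs next in the backward pass
            # (i.e. the previous module in forward order), so that communication
            # overlaps with the current module's backward computation.
            idx = self.forward_order.index(module)
            if idx > 0:
                prev_module = self.forward_order[idx - 1]
                prev_module.rebuild_full_params()  # hypothetical all-gather helper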

Experiments showed a 10% memory increase and a 20% latency speedup for a 1GB RoBERTa model in a slow network environment.
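
For reference, a hypothetical usage sketch: the diff below only shows the new BackwardPrefetch_ enum and the backward_prefetch test parameter, so the exact FSDP constructor signature and import path used here are assumptions.

    # Hypothetical usage; assumes FullyShardedDataParallel accepts a
    # backward_prefetch keyword mirroring the test parameter added in this PR.
    import torch.nn as nn
    from torch.distributed._fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed._fsdp.fully_sharded_data_parallel import BackwardPrefetch_

    model = nn.Linear(1024, 1024).cuda()
    # BACKWARD_PRE: prefetch the next layer's all gather in the pre-backward hook.
    fsdp_model = FSDP(model, backward_prefetch=BackwardPrefetch_.BACKWARD_PRE)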

Test Plan: unit tests

Reviewed By: rohan-varma

Differential Revision: D33252795

fbshipit-source-id: 4e2f47225ba223e7429b0dcaa89df3634bb70050
zhaojuanmao authored and facebook-github-bot committed Dec 23, 2021
1 parent 1d09458 commit b15212c
Showing 4 changed files with 156 additions and 97 deletions.
60 changes: 48 additions & 12 deletions test/distributed/fsdp/test_fsdp_core.py
@@ -27,6 +27,7 @@
)

from torch.distributed._fsdp import CPUOffload
from torch.distributed._fsdp.fully_sharded_data_parallel import BackwardPrefetch_


if not dist.is_available():
@@ -66,53 +67,72 @@ def _get_init_modes_for_test(self, cpu_offload):
"cpu_offload",
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
)
def test_nested_wrapped_model(self, cpu_offload):
@parametrize(
"backward_prefetch",
[BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
)
def test_nested_wrapped_model(self, cpu_offload, backward_prefetch):
init_modes = self._get_init_modes_for_test(cpu_offload)
for fsdp_init_mode in init_modes:
with self.subTest(fsdp_init_mode=fsdp_init_mode):
self._test_identical_outputs(
NestedWrappedModule,
fsdp_init_mode=fsdp_init_mode,
cpu_offload=cpu_offload
cpu_offload=cpu_offload,
backward_prefetch=backward_prefetch,
)

@skip_if_lt_x_gpu(2)
@parametrize(
"cpu_offload",
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
)
def test_nested_all_wrapped_model(self, cpu_offload):
@parametrize(
"backward_prefetch",
[BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
)
def test_nested_all_wrapped_model(self, cpu_offload, backward_prefetch):
init_modes = self._get_init_modes_for_test(cpu_offload)
for fsdp_init_mode in init_modes:
with self.subTest(fsdp_init_mode=fsdp_init_mode):
model_fn = functools.partial(NestedWrappedModule, wrap_everything=True)
self._test_identical_outputs(
model_fn,
fsdp_init_mode=fsdp_init_mode,
cpu_offload=cpu_offload
cpu_offload=cpu_offload,
backward_prefetch=backward_prefetch,
)

@skip_if_lt_x_gpu(2)
@parametrize(
"cpu_offload",
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
)
def test_transformer_parameterized(self, cpu_offload):
@parametrize(
"backward_prefetch",
[BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
)
def test_transformer_parameterized(self, cpu_offload, backward_prefetch):
init_modes = self._get_init_modes_for_test(cpu_offload)
for fsdp_init_mode in init_modes:
with self.subTest(fsdp_init_mode=fsdp_init_mode):
self._test_identical_outputs(
TransformerWithSharedParams,
fsdp_init_mode=fsdp_init_mode,
cpu_offload=cpu_offload
cpu_offload=cpu_offload,
backward_prefetch=backward_prefetch,
)

@skip_if_lt_x_gpu(2)
@parametrize(
"cpu_offload",
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
)
def test_delayed_optim_step(self, cpu_offload):
@parametrize(
"backward_prefetch",
[BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
)
def test_delayed_optim_step(self, cpu_offload, backward_prefetch):
# We use a model with a long CUDA delay right before the optimizer step.
# This tests our streams logic, and that we don't start the allgather
# until after the optimization step completes.
@@ -125,15 +145,20 @@ def test_delayed_optim_step(self, cpu_offload):
self._test_identical_outputs(
model_fn,
fsdp_init_mode=fsdp_init_mode,
cpu_offload=cpu_offload
cpu_offload=cpu_offload,
backward_prefetch=backward_prefetch,
)

@skip_if_lt_x_gpu(2)
@parametrize(
"cpu_offload",
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
)
def test_delayed_reduce_scatter(self, cpu_offload):
@parametrize(
"backward_prefetch",
[BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
)
def test_delayed_reduce_scatter(self, cpu_offload, backward_prefetch):
# We insert a delay in the torch.distributed._reduce_scatter_base op, so that
# the post_backward_stream takes much longer than the backward pass.
# This tests that we properly block at the end of the backward pass for
@@ -147,7 +172,8 @@ def test_delayed_reduce_scatter(self, cpu_offload):
self._test_identical_outputs(
model_fn,
fsdp_init_mode=fsdp_init_mode,
cpu_offload=cpu_offload
cpu_offload=cpu_offload,
backward_prefetch=backward_prefetch,
)

def _dummy_ddp_fn(self, model):
@@ -158,7 +184,11 @@ def _dummy_ddp_fn(self, model):
"cpu_offload",
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
)
def test_mixture_of_experts(self, cpu_offload):
@parametrize(
"backward_prefetch",
[BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
)
def test_mixture_of_experts(self, cpu_offload, backward_prefetch):
init_modes = self._get_init_modes_for_test(cpu_offload)
for fsdp_init_mode in init_modes:
with self.subTest(fsdp_init_mode=fsdp_init_mode):
@@ -169,14 +199,19 @@ def test_mixture_of_experts(self, cpu_offload):
ref_ddp_fn=self._dummy_ddp_fn,
fsdp_init_mode=fsdp_init_mode,
cpu_offload=cpu_offload,
backward_prefetch=backward_prefetch,
)

@skip_if_lt_x_gpu(2)
@parametrize(
"cpu_offload",
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
)
def test_mixture_of_experts_with_delay_before_free(self, cpu_offload):
@parametrize(
"backward_prefetch",
[BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None]
)
def test_mixture_of_experts_with_delay_before_free(self, cpu_offload, backward_prefetch):
init_modes = self._get_init_modes_for_test(cpu_offload)
for fsdp_init_mode in init_modes:
with self.subTest(fsdp_init_mode=fsdp_init_mode):
@@ -186,6 +221,7 @@ def test_mixture_of_experts_with_delay_before_free(self, cpu_offload):
ref_ddp_fn=self._dummy_ddp_fn,
fsdp_init_mode=fsdp_init_mode,
cpu_offload=cpu_offload,
backward_prefetch=backward_prefetch,
)


75 changes: 0 additions & 75 deletions test/distributed/fsdp/test_fsdp_param_mutation.py

This file was deleted.
