Skip to content

Commit e50f7c5

Browse files
committed
Update
[ghstack-poisoned]
1 parent 3a7792c commit e50f7c5

File tree

2 files changed

+171
-96
lines changed

2 files changed

+171
-96
lines changed

pippy/PipelineSchedule.py

+97-96
Original file line numberDiff line numberDiff line change
@@ -603,27 +603,22 @@ def step_microbatches(
603603
losses: Optional[List] = None,
604604
):
605605
"""
606-
# n_loop = n_stage / n_pp
607-
# run microbatches in sequences of NPp
606+
Operate on the microbatches for interleaved 1f1b schedule (https://arxiv.org/pdf/2104.04473.pdf).
608607
609-
schedule operates at the rank level
610-
611-
highest rank has a warmup (F only) count of [len(stages) - 1] * seq_size
612-
each hop away from highest rank adds 2 warmup stages
608+
Highest rank has a warmup (fwd only) count of [len(stages) - 1] * number of PP ranks
609+
and each rank away from highest rank adds 2 warmup steps due to:
613610
- one happened before highest rank's warmup started,
614611
- one waiting for backward result to trickle down from highest rank
615-
dist_from_highest = (worldsize - 1) - rank
616-
617-
total_steps = warmup_steps + (num_stages * num_microbatch)
618612
619-
Rank 0: 0F 0F 0F 0F 2F 2F 2F 2F
620-
Rank 1: 1F 1F 1F 1F 3F3B 3F 3F 3F
613+
TODO: Interleaved 1F1B does not support using sorted_batch_isend_irecv()
614+
because it requires recvs and sends from different peers
615+
to execute in the same coalesced operation. As a result, this schedule does
616+
not support models with skip connections.
621617
"""
622618
arg_mbs, kwarg_mbs = self._check_inputs(
623619
arg_mbs, kwarg_mbs, target_mbs, losses
624620
)
625621

626-
# warmup steps for latest pp stage is trivial to compute
627622
# increment warmup_steps by 2 for each hop away
628623
warmup_steps = (self.n_local_stages - 1) * self.pp_group_size
629624
warmup_steps += 2 * ((self.pp_group_size - 1) - self.rank)
@@ -641,7 +636,7 @@ def step_microbatches(
641636
warmup_steps + fwd_bwd_steps * 2 + cooldown_steps
642637
== self.n_local_stages * self._n_microbatches * 2
643638
)
644-
self.total_steps = warmup_steps + fwd_bwd_steps + cooldown_steps
639+
total_steps = warmup_steps + fwd_bwd_steps + cooldown_steps
645640

646641
logger.debug(
647642
f"""
@@ -669,104 +664,110 @@ def backward_stage_local_index(step):
669664
# Delay send waits
670665
sends_to_wait: List[dist.Work] = []
671666

672-
for step in range(self.total_steps):
673-
# warmup, forward only
674-
if step < warmup_steps:
675-
logger.debug(f"{forward_stage_local_index(step)=}")
667+
# Store ops (potentially across steps)
668+
ops: List[dist.P2POp] = []
676669

677-
fwd_stage = self._stages[forward_stage_local_index(step)]
678-
# assigns the current microbatch index and updates it for future steps
679-
fwd_stage_mb_index[fwd_stage] = (
680-
mb_index := fwd_stage_mb_index[fwd_stage]
681-
) + 1
670+
# Warmup Phase (forward only)
671+
for step in range(warmup_steps):
672+
fwd_stage = self._stages[forward_stage_local_index(step)]
682673

683-
logger.debug(
684-
f"{self.rank}: {step=}, {fwd_stage.stage_index=}, {mb_index=}"
685-
)
674+
# This will assign the current microbatch index and update it for future steps
675+
fwd_stage_mb_index[fwd_stage] = (
676+
mb_index := fwd_stage_mb_index[fwd_stage]
677+
) + 1
686678

687-
with record_function(f"Forward {step}"):
688-
ops = fwd_stage.get_fwd_recv_ops()
689-
works = sorted_batch_isend_irecv(ops)
690-
for work in works.values():
691-
work.wait()
679+
logger.debug(
680+
f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, {mb_index=}"
681+
)
692682

693-
output = fwd_stage.forward_one_chunk(arg_mbs[mb_index], kwarg_mbs[mb_index]) # type: ignore[index]
683+
with record_function(f"Forward {step}"):
684+
ops.extend(fwd_stage.get_fwd_recv_ops())
685+
if ops:
686+
work = dist.batch_isend_irecv(ops).pop()
687+
work.wait()
688+
ops.clear()
694689

695-
ops = fwd_stage.get_fwd_send_ops()
696-
works = sorted_batch_isend_irecv(ops)
697-
sends_to_wait.extend(works.values())
690+
output = fwd_stage.forward_one_chunk(arg_mbs[mb_index], kwarg_mbs[mb_index]) # type: ignore[index]
698691

699-
self._maybe_compute_loss(
700-
fwd_stage, output, target_mbs, mb_index
701-
)
692+
ops.extend(fwd_stage.get_fwd_send_ops())
693+
# If we are right before the fwd-bwd step, then we need to delay the send to the next step,
694+
# This is because fwd-bwd send/recvs among ranks need to be aligned to prevent a hang.
695+
# In the edge cases where there are no fwd_bwds and cooldown is immediate, then no delay is needed
696+
if ops and (step != warmup_steps - 1 or fwd_bwd_steps == 0):
697+
work = dist.batch_isend_irecv(ops).pop()
698+
sends_to_wait.append(work)
699+
ops.clear()
702700

703-
# 1f1b
704-
elif warmup_steps <= step < warmup_steps + fwd_bwd_steps:
705-
logger.debug(f"{forward_stage_local_index(step)=}")
706-
logger.debug(f"{backward_stage_local_index(step)=}")
701+
self._maybe_compute_loss(
702+
fwd_stage, output, target_mbs, mb_index
703+
)
707704

708-
fwd_stage = self._stages[forward_stage_local_index(step)]
709-
bwd_stage = self._stages[backward_stage_local_index(step)]
705+
# 1F1B Phase (forward and backward)
706+
for step in range(warmup_steps, warmup_steps + fwd_bwd_steps):
707+
fwd_stage = self._stages[forward_stage_local_index(step)]
708+
bwd_stage = self._stages[backward_stage_local_index(step)]
710709

711-
fwd_stage_mb_index[fwd_stage] = (
712-
fwd_mb_index := fwd_stage_mb_index[fwd_stage]
713-
) + 1
714-
bwd_stage_mb_index[bwd_stage] = (
715-
bwd_mb_index := bwd_stage_mb_index[bwd_stage]
716-
) + 1
710+
fwd_stage_mb_index[fwd_stage] = (
711+
fwd_mb_index := fwd_stage_mb_index[fwd_stage]
712+
) + 1
713+
bwd_stage_mb_index[bwd_stage] = (
714+
bwd_mb_index := bwd_stage_mb_index[bwd_stage]
715+
) + 1
717716

718-
bwd_stage._configure_data_parallel_mode(
719-
bwd_mb_index == self._n_microbatches - 1
720-
)
721-
logger.debug(
722-
f"{self.rank}: {step=}, {fwd_stage.stage_index=}, {bwd_stage.stage_index=}, {fwd_mb_index=}, {bwd_mb_index=}"
723-
)
724-
with record_function(f"1F1B {step}"):
725-
ops = fwd_stage.get_fwd_recv_ops()
726-
ops.extend(bwd_stage.get_bwd_recv_ops())
727-
works = sorted_batch_isend_irecv(ops)
728-
for work in works.values():
729-
work.wait()
717+
bwd_stage._configure_data_parallel_mode(
718+
bwd_mb_index == self._n_microbatches - 1
719+
)
720+
logger.debug(
721+
f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, {bwd_stage.stage_index=}, {fwd_mb_index=}, {bwd_mb_index=}"
722+
)
723+
with record_function(f"1F1B {step}"):
724+
ops.extend(fwd_stage.get_fwd_recv_ops())
725+
ops.extend(bwd_stage.get_bwd_recv_ops())
726+
if ops:
727+
work = dist.batch_isend_irecv(ops).pop()
728+
work.wait()
729+
ops.clear()
730730

731-
# fwd
732-
output = fwd_stage.forward_one_chunk(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index]
733-
ops = fwd_stage.get_fwd_send_ops()
734-
self._maybe_compute_loss(
735-
fwd_stage, output, target_mbs, fwd_mb_index
736-
)
731+
# Forward
732+
output = fwd_stage.forward_one_chunk(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index]
733+
ops.extend(fwd_stage.get_fwd_send_ops())
734+
self._maybe_compute_loss(
735+
fwd_stage, output, target_mbs, fwd_mb_index
736+
)
737737

738-
# bwd
739-
loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
740-
bwd_stage.backward_one_chunk(loss=loss)
741-
ops.extend(bwd_stage.get_bwd_send_ops())
738+
# Backward
739+
loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
740+
bwd_stage.backward_one_chunk(loss=loss)
741+
ops.extend(bwd_stage.get_bwd_send_ops())
742+
743+
# Cooldown Phase (backward only)
744+
for step in range(warmup_steps + fwd_bwd_steps, total_steps):
745+
bwd_stage = self._stages[backward_stage_local_index(step)]
746+
bwd_stage_mb_index[bwd_stage] = (
747+
bwd_mb_index := bwd_stage_mb_index[bwd_stage]
748+
) + 1
749+
bwd_stage._configure_data_parallel_mode(
750+
bwd_mb_index == self._n_microbatches - 1
751+
)
742752

743-
works = sorted_batch_isend_irecv(ops)
744-
sends_to_wait.extend(works.values())
745-
746-
# cooldown
747-
else:
748-
bwd_stage = self._stages[backward_stage_local_index(step)]
749-
bwd_stage_mb_index[bwd_stage] = (
750-
bwd_mb_index := bwd_stage_mb_index[bwd_stage]
751-
) + 1
752-
bwd_stage._configure_data_parallel_mode(
753-
bwd_mb_index == self._n_microbatches - 1
754-
)
755-
logger.debug(
756-
f"{self.rank}: {step=}, {bwd_stage.stage_index=}, {bwd_mb_index=}"
757-
)
758-
with record_function(f"Cooldown (backward) {step}"):
759-
ops = bwd_stage.get_bwd_recv_ops()
760-
works = sorted_batch_isend_irecv(ops)
761-
for work in works.values():
762-
work.wait()
753+
logger.debug(
754+
f"Rank {self.rank}: {step=}, {bwd_stage.stage_index=}, {bwd_mb_index=}"
755+
)
756+
with record_function(f"Cooldown {step}"):
757+
ops.extend(bwd_stage.get_bwd_recv_ops())
758+
if ops:
759+
work = dist.batch_isend_irecv(ops).pop()
760+
work.wait()
761+
ops.clear()
763762

764-
loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
765-
bwd_stage.backward_one_chunk(loss=loss)
763+
loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
764+
bwd_stage.backward_one_chunk(loss=loss)
766765

767-
ops = bwd_stage.get_bwd_send_ops()
768-
works = sorted_batch_isend_irecv(ops)
769-
sends_to_wait.extend(works.values())
766+
ops.extend(bwd_stage.get_bwd_send_ops())
767+
if ops:
768+
work = dist.batch_isend_irecv(ops).pop()
769+
sends_to_wait.append(work)
770+
ops.clear()
770771

771772
# Make sure all sends are finished
772773
for work in sends_to_wait:

test/test_pipeline_schedule.py

+74
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
22

3+
import random
4+
import time
35
import unittest
46

57
import torch
@@ -91,6 +93,37 @@ def forward(self, x):
9193
return {}
9294

9395

96+
class ModelWithSleep(nn.Module):
97+
def __init__(
98+
self,
99+
dim: int,
100+
hidden_dim: int,
101+
out_dim: int,
102+
rank: int,
103+
):
104+
super().__init__()
105+
self.in_layer = nn.Linear(dim, hidden_dim, bias=False)
106+
self.middle = nn.Sequential(
107+
nn.Linear(hidden_dim, hidden_dim, bias=False),
108+
nn.ReLU(),
109+
nn.Linear(hidden_dim, hidden_dim, bias=False),
110+
nn.ReLU(),
111+
)
112+
self.out_layer = nn.Linear(hidden_dim, out_dim, bias=False)
113+
self.relu = nn.ReLU()
114+
self.rank = rank
115+
116+
def forward(self, x):
117+
x = self.in_layer(x)
118+
x = self.middle(x)
119+
# this delay helps to simulate inconsistencies in timing between ranks
120+
if self.rank == 0 or self.rank == 1:
121+
time.sleep(random.uniform(0, 0.5))
122+
x = self.out_layer(x)
123+
x = self.relu(x)
124+
return x
125+
126+
94127
# Tests defined below
95128
##########################
96129

@@ -400,6 +433,47 @@ def test_interleaved_1f1b_negative(self):
400433
]
401434
schedule.step_microbatches(microbatches)
402435

436+
@skip_if_lt_x_gpu(4)
437+
def test_interleaved_1f1b_with_model_sleep(self):
438+
device = torch.device(f"cuda:{self.rank}")
439+
dist.init_process_group(
440+
init_method=self.init_method,
441+
backend="nccl",
442+
rank=self.rank,
443+
world_size=self.world_size,
444+
)
445+
446+
num_dims = 4
447+
model = ModelWithSleep(
448+
dim=num_dims, hidden_dim=8, out_dim=num_dims, rank=self.rank
449+
)
450+
stages_per_rank = 2
451+
num_microbatches_list = [4, 8, 16]
452+
for num_microbatches in num_microbatches_list:
453+
batch = torch.rand((num_microbatches, num_dims), device=device)
454+
stages = self._create_virtual_pipeline_stages(
455+
model,
456+
torch.rand((1, num_dims)).to("meta"),
457+
device,
458+
stages_per_rank,
459+
num_microbatches=num_microbatches,
460+
)
461+
462+
schedule = ScheduleInterleaved1F1B(
463+
stages, num_microbatches, loss_fn=nn.MSELoss()
464+
)
465+
if self.rank == 0:
466+
schedule.step(batch)
467+
elif self.rank == self.world_size - 1:
468+
target = torch.rand((num_microbatches, num_dims), device=device)
469+
losses = []
470+
schedule.step(target=target, losses=losses)
471+
else:
472+
schedule.step()
473+
dist.barrier()
474+
torch.cuda.synchronize()
475+
print(f"Finished with testing {num_microbatches} microbatches")
476+
403477

404478
class UtilTest(unittest.TestCase):
405479
def test_metadata_tensor(self):

0 commit comments

Comments
 (0)