Fix interleaved 1f1b race #1098

Merged
merged 1 commit on May 1, 2024
Update
[ghstack-poisoned]
H-Huang committed May 1, 2024
commit e50f7c55752af96fb1d2b7a6c8f118e84289a8fb
193 changes: 97 additions & 96 deletions pippy/PipelineSchedule.py
@@ -603,27 +603,22 @@ def step_microbatches(
losses: Optional[List] = None,
):
"""
# n_loop = n_stage / n_pp
# run microbatches in sequences of NPp
Operate on the microbatches for interleaved 1f1b schedule (https://arxiv.org/pdf/2104.04473.pdf).

schedule operates at the rank level

highest rank has a warmup (F only) count of [len(stages) - 1] * seq_size
each hop away from highest rank adds 2 warmup stages
Highest rank has a warmup (fwd only) count of [len(stages) - 1] * number of PP ranks,
and each rank away from the highest rank adds 2 warmup steps due to:
- one that happened before the highest rank's warmup started,
- one spent waiting for the backward result to trickle down from the highest rank
dist_from_highest = (worldsize - 1) - rank

total_steps = warmup_steps + (num_stages * num_microbatch)

Rank 0: 0F 0F 0F 0F 2F 2F 2F 2F
Rank 1: 1F 1F 1F 1F 3F3B 3F 3F 3F
TODO: Interleaved 1F1B does not support using sorted_batch_isend_irecv()
because it requires recvs and sends from different peers
to execute in the same coalesced operation. As a result, this schedule does
not support models with skip connections.
"""
arg_mbs, kwarg_mbs = self._check_inputs(
arg_mbs, kwarg_mbs, target_mbs, losses
)

# warmup steps for latest pp stage is trivial to compute
# increment warmup_steps by 2 for each hop away
warmup_steps = (self.n_local_stages - 1) * self.pp_group_size
warmup_steps += 2 * ((self.pp_group_size - 1) - self.rank)
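Reviewer note, not part of the diff: the warmup count follows directly from the docstring above. A standalone sketch of the arithmetic with assumed example values (4 pipeline ranks, 2 virtual stages per rank):

    # Sketch of the warmup-step arithmetic; values below are assumptions.
    n_local_stages = 2  # virtual stages per rank (assumed)
    pp_group_size = 4   # pipeline-parallel ranks (assumed)

    for rank in range(pp_group_size):
        # Base warmup count, i.e. what the highest rank runs (forward only).
        warmup = (n_local_stages - 1) * pp_group_size
        # Each hop away from the highest rank adds 2 warmup steps.
        warmup += 2 * ((pp_group_size - 1) - rank)
        print(f"rank {rank}: {warmup} warmup steps")

    # Prints: rank 0: 10, rank 1: 8, rank 2: 6, rank 3: 4,
    # matching the warmup_steps computation in the code above.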
@@ -641,7 +636,7 @@ def step_microbatches(
warmup_steps + fwd_bwd_steps * 2 + cooldown_steps
== self.n_local_stages * self._n_microbatches * 2
)
self.total_steps = warmup_steps + fwd_bwd_steps + cooldown_steps
total_steps = warmup_steps + fwd_bwd_steps + cooldown_steps

logger.debug(
f"""
@@ -669,104 +664,110 @@ def backward_stage_local_index(step):
# Delay send waits
sends_to_wait: List[dist.Work] = []

for step in range(self.total_steps):
# warmup, forward only
if step < warmup_steps:
logger.debug(f"{forward_stage_local_index(step)=}")
# Store ops (potentially across steps)
ops: List[dist.P2POp] = []

fwd_stage = self._stages[forward_stage_local_index(step)]
# assigns the current microbatch index and updates it for future steps
fwd_stage_mb_index[fwd_stage] = (
mb_index := fwd_stage_mb_index[fwd_stage]
) + 1
# Warmup Phase (forward only)
for step in range(warmup_steps):
fwd_stage = self._stages[forward_stage_local_index(step)]

logger.debug(
f"{self.rank}: {step=}, {fwd_stage.stage_index=}, {mb_index=}"
)
# This will assign the current microbatch index and update it for future steps
fwd_stage_mb_index[fwd_stage] = (
mb_index := fwd_stage_mb_index[fwd_stage]
) + 1

with record_function(f"Forward {step}"):
ops = fwd_stage.get_fwd_recv_ops()
works = sorted_batch_isend_irecv(ops)
for work in works.values():
work.wait()
logger.debug(
f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, {mb_index=}"
)

output = fwd_stage.forward_one_chunk(arg_mbs[mb_index], kwarg_mbs[mb_index]) # type: ignore[index]
with record_function(f"Forward {step}"):
ops.extend(fwd_stage.get_fwd_recv_ops())
if ops:
work = dist.batch_isend_irecv(ops).pop()
work.wait()
ops.clear()

ops = fwd_stage.get_fwd_send_ops()
works = sorted_batch_isend_irecv(ops)
sends_to_wait.extend(works.values())
output = fwd_stage.forward_one_chunk(arg_mbs[mb_index], kwarg_mbs[mb_index]) # type: ignore[index]

self._maybe_compute_loss(
fwd_stage, output, target_mbs, mb_index
)
ops.extend(fwd_stage.get_fwd_send_ops())
# If we are right before the fwd-bwd step, we need to delay the send to the next step.
# This is because fwd-bwd send/recvs among ranks need to be aligned to prevent a hang.
# In the edge case where there are no fwd-bwd steps and cooldown starts immediately, no delay is needed.
if ops and (step != warmup_steps - 1 or fwd_bwd_steps == 0):
work = dist.batch_isend_irecv(ops).pop()
sends_to_wait.append(work)
ops.clear()

# 1f1b
elif warmup_steps <= step < warmup_steps + fwd_bwd_steps:
logger.debug(f"{forward_stage_local_index(step)=}")
logger.debug(f"{backward_stage_local_index(step)=}")
self._maybe_compute_loss(
fwd_stage, output, target_mbs, mb_index
)

fwd_stage = self._stages[forward_stage_local_index(step)]
bwd_stage = self._stages[backward_stage_local_index(step)]
# 1F1B Phase (forward and backward)
for step in range(warmup_steps, warmup_steps + fwd_bwd_steps):
fwd_stage = self._stages[forward_stage_local_index(step)]
bwd_stage = self._stages[backward_stage_local_index(step)]

fwd_stage_mb_index[fwd_stage] = (
fwd_mb_index := fwd_stage_mb_index[fwd_stage]
) + 1
bwd_stage_mb_index[bwd_stage] = (
bwd_mb_index := bwd_stage_mb_index[bwd_stage]
) + 1
fwd_stage_mb_index[fwd_stage] = (
fwd_mb_index := fwd_stage_mb_index[fwd_stage]
) + 1
bwd_stage_mb_index[bwd_stage] = (
bwd_mb_index := bwd_stage_mb_index[bwd_stage]
) + 1

bwd_stage._configure_data_parallel_mode(
bwd_mb_index == self._n_microbatches - 1
)
logger.debug(
f"{self.rank}: {step=}, {fwd_stage.stage_index=}, {bwd_stage.stage_index=}, {fwd_mb_index=}, {bwd_mb_index=}"
)
with record_function(f"1F1B {step}"):
ops = fwd_stage.get_fwd_recv_ops()
ops.extend(bwd_stage.get_bwd_recv_ops())
works = sorted_batch_isend_irecv(ops)
for work in works.values():
work.wait()
bwd_stage._configure_data_parallel_mode(
bwd_mb_index == self._n_microbatches - 1
)
logger.debug(
f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, {bwd_stage.stage_index=}, {fwd_mb_index=}, {bwd_mb_index=}"
)
with record_function(f"1F1B {step}"):
ops.extend(fwd_stage.get_fwd_recv_ops())
ops.extend(bwd_stage.get_bwd_recv_ops())
if ops:
work = dist.batch_isend_irecv(ops).pop()
work.wait()
ops.clear()

# fwd
output = fwd_stage.forward_one_chunk(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index]
ops = fwd_stage.get_fwd_send_ops()
self._maybe_compute_loss(
fwd_stage, output, target_mbs, fwd_mb_index
)
# Forward
output = fwd_stage.forward_one_chunk(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index]
ops.extend(fwd_stage.get_fwd_send_ops())
self._maybe_compute_loss(
fwd_stage, output, target_mbs, fwd_mb_index
)

# bwd
loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
bwd_stage.backward_one_chunk(loss=loss)
ops.extend(bwd_stage.get_bwd_send_ops())
# Backward
loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
bwd_stage.backward_one_chunk(loss=loss)
ops.extend(bwd_stage.get_bwd_send_ops())

# Cooldown Phase (backward only)
for step in range(warmup_steps + fwd_bwd_steps, total_steps):
bwd_stage = self._stages[backward_stage_local_index(step)]
bwd_stage_mb_index[bwd_stage] = (
bwd_mb_index := bwd_stage_mb_index[bwd_stage]
) + 1
bwd_stage._configure_data_parallel_mode(
bwd_mb_index == self._n_microbatches - 1
)

works = sorted_batch_isend_irecv(ops)
sends_to_wait.extend(works.values())

# cooldown
else:
bwd_stage = self._stages[backward_stage_local_index(step)]
bwd_stage_mb_index[bwd_stage] = (
bwd_mb_index := bwd_stage_mb_index[bwd_stage]
) + 1
bwd_stage._configure_data_parallel_mode(
bwd_mb_index == self._n_microbatches - 1
)
logger.debug(
f"{self.rank}: {step=}, {bwd_stage.stage_index=}, {bwd_mb_index=}"
)
with record_function(f"Cooldown (backward) {step}"):
ops = bwd_stage.get_bwd_recv_ops()
works = sorted_batch_isend_irecv(ops)
for work in works.values():
work.wait()
logger.debug(
f"Rank {self.rank}: {step=}, {bwd_stage.stage_index=}, {bwd_mb_index=}"
)
with record_function(f"Cooldown {step}"):
ops.extend(bwd_stage.get_bwd_recv_ops())
if ops:
work = dist.batch_isend_irecv(ops).pop()
work.wait()
ops.clear()

loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
bwd_stage.backward_one_chunk(loss=loss)
loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
bwd_stage.backward_one_chunk(loss=loss)

ops = bwd_stage.get_bwd_send_ops()
works = sorted_batch_isend_irecv(ops)
sends_to_wait.extend(works.values())
ops.extend(bwd_stage.get_bwd_send_ops())
if ops:
work = dist.batch_isend_irecv(ops).pop()
sends_to_wait.append(work)
ops.clear()

# Make sure all sends are finished
for work in sends_to_wait:
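Reviewer note on the shape of the fix: the diff replaces per-phase sorted_batch_isend_irecv calls with a single ops list that is posted through one dist.batch_isend_irecv per step, and it delays the send at the warmup boundary so fwd-bwd send/recvs stay aligned across ranks. A minimal standalone sketch of the coalesced pattern, assuming an already-initialized 2-rank NCCL process group and hypothetical buffers on this rank's CUDA device (illustration only, not the PR's code):

    import torch
    import torch.distributed as dist

    def exchange(rank: int, send_buf: torch.Tensor, recv_buf: torch.Tensor) -> None:
        peer = 1 - rank
        # Post the send and recv as ONE coalesced batch. If each rank instead
        # issued them as separate, differently ordered batches, the ranks could
        # block waiting on each other's matching op when their timing diverges,
        # which is the kind of hang this PR guards against.
        ops = [
            dist.P2POp(dist.isend, send_buf, peer),
            dist.P2POp(dist.irecv, recv_buf, peer),
        ]
        for work in dist.batch_isend_irecv(ops):
            work.wait()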
74 changes: 74 additions & 0 deletions test/test_pipeline_schedule.py
@@ -1,5 +1,7 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

import random
import time
import unittest

import torch
@@ -91,6 +93,37 @@ def forward(self, x):
return {}


class ModelWithSleep(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
out_dim: int,
rank: int,
):
super().__init__()
self.in_layer = nn.Linear(dim, hidden_dim, bias=False)
self.middle = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim, bias=False),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim, bias=False),
nn.ReLU(),
)
self.out_layer = nn.Linear(hidden_dim, out_dim, bias=False)
self.relu = nn.ReLU()
self.rank = rank

def forward(self, x):
x = self.in_layer(x)
x = self.middle(x)
# This delay simulates timing inconsistencies between ranks
if self.rank == 0 or self.rank == 1:
time.sleep(random.uniform(0, 0.5))
x = self.out_layer(x)
x = self.relu(x)
return x


# Tests defined below
##########################
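Note on the test model above: randomly slowing a couple of ranks desynchronizes their send/recv timing, which makes the scheduling race reproduce reliably. A hypothetical standalone helper expressing the same technique (names assumed, not part of the PR):

    import random
    import time

    def inject_rank_jitter(rank: int, ranks_to_delay=(0, 1), max_seconds: float = 0.5) -> None:
        # Desynchronize a subset of ranks so latent P2P ordering races surface in tests.
        if rank in ranks_to_delay:
            time.sleep(random.uniform(0.0, max_seconds))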

@@ -400,6 +433,47 @@ def test_interleaved_1f1b_negative(self):
]
schedule.step_microbatches(microbatches)

@skip_if_lt_x_gpu(4)
def test_interleaved_1f1b_with_model_sleep(self):
device = torch.device(f"cuda:{self.rank}")
dist.init_process_group(
init_method=self.init_method,
backend="nccl",
rank=self.rank,
world_size=self.world_size,
)

num_dims = 4
model = ModelWithSleep(
dim=num_dims, hidden_dim=8, out_dim=num_dims, rank=self.rank
)
stages_per_rank = 2
num_microbatches_list = [4, 8, 16]
for num_microbatches in num_microbatches_list:
batch = torch.rand((num_microbatches, num_dims), device=device)
stages = self._create_virtual_pipeline_stages(
model,
torch.rand((1, num_dims)).to("meta"),
device,
stages_per_rank,
num_microbatches=num_microbatches,
)

schedule = ScheduleInterleaved1F1B(
stages, num_microbatches, loss_fn=nn.MSELoss()
)
if self.rank == 0:
schedule.step(batch)
elif self.rank == self.world_size - 1:
target = torch.rand((num_microbatches, num_dims), device=device)
losses = []
schedule.step(target=target, losses=losses)
else:
schedule.step()
dist.barrier()
torch.cuda.synchronize()
print(f"Finished with testing {num_microbatches} microbatches")


class UtilTest(unittest.TestCase):
def test_metadata_tensor(self):