Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions deepspeed/sequence/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .layer import DistributedAttention, _SeqAllToAll # noqa: F401
33 changes: 26 additions & 7 deletions deepspeed/sequence/fpdt_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
_flash_attn_backward = None

from einops import rearrange
from .layer import single_all_to_all, apply_rotary_pos_emb
from .layer import single_all_to_all, apply_rotary_pos_emb, DistributedAttention
from deepspeed.utils.groups import _get_max_local_seq_len


def _rotate_half_backward(x):
Expand Down Expand Up @@ -81,19 +82,37 @@ class FPDT_InputConstruct(torch.nn.Module):
def __init__(self, tokens, labels, loss_mask, attention_mask, position_ids, args, sp_size, sp_rank) -> None:

super(FPDT_InputConstruct, self).__init__()
self.tokens = tokens
self.labels = labels
self.loss_mask = loss_mask
self.attention_mask = attention_mask
self.position_ids = position_ids
global_seq_len = tokens.shape[1]
batch_size = tokens.shape[0]
raw_global_seq_len = tokens.shape[1]

# Pad global sequence length to the nearest multiple of sp_size so that
# each rank gets an equal-sized slice and the all-to-all splits are uniform.
pad_to = (raw_global_seq_len + sp_size - 1) // sp_size * sp_size
if pad_to != raw_global_seq_len:
import torch.nn.functional as _F
pad_amount = pad_to - raw_global_seq_len
tokens = _F.pad(tokens, (0, pad_amount))
if labels is not None:
labels = _F.pad(labels, (0, pad_amount))
if loss_mask is not None:
loss_mask = _F.pad(loss_mask, (0, pad_amount))
if position_ids is not None:
position_ids = _F.pad(position_ids, (0, pad_amount))
global_seq_len = tokens.shape[1]
self.raw_global_seq_len = raw_global_seq_len # kept for optional post-processing

assert global_seq_len % sp_size == 0
assert global_seq_len % args.ds_sequence_parallel_fpdt_chunk_size == 0
num_chunk_per_gpu = global_seq_len // args.ds_sequence_parallel_fpdt_chunk_size
local_seq_len = global_seq_len // sp_size
assert local_seq_len % num_chunk_per_gpu == 0

self.tokens = tokens
self.labels = labels
self.loss_mask = loss_mask
self.attention_mask = attention_mask
self.position_ids = position_ids

self.num_chunk_per_gpu = num_chunk_per_gpu
self.chunk_size = local_seq_len // num_chunk_per_gpu
self.sp_size = sp_size
Expand Down
49 changes: 48 additions & 1 deletion deepspeed/sequence/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# DeepSpeed Team
import torch
import torch.nn.functional as F

from typing import Any, Tuple
from torch import Tensor
Expand All @@ -14,6 +15,7 @@
from deepspeed.accelerator import get_accelerator
from deepspeed.module_inject.tp_shard import get_shard_size_list, set_num_kv_heads, get_num_kv_heads
from deepspeed.utils import groups
from deepspeed.utils.groups import _get_max_local_seq_len


def _generate_layout_params(scatter_idx, batch_dim_idx, seq_world_size, input):
Expand Down Expand Up @@ -44,7 +46,10 @@ def _generate_layout_params(scatter_idx, batch_dim_idx, seq_world_size, input):
pre_all2all_permute_idx = None

post_all2all_permute_idx = (1, 2, 0, 3, 4)
post_all2all_res_shape = [bs, seq_world_size * global_seq_len, num_local_head // seq_world_size, head_dim]
# Scatter seq → gather heads. After permute the tensor has shape
# [global_seq_len // seq_world_size, bs, seq_world_size, num_local_head, head_dim];
# collapse the last two head dims to get the full head count on each rank.
post_all2all_res_shape = [global_seq_len // seq_world_size, bs, seq_world_size * num_local_head, head_dim]
else:
local_seq_len, bs, num_total_head, head_dim = input.shape
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
Expand Down Expand Up @@ -364,6 +369,21 @@ def layer_sync(self, layer):
if self.sp_overlap_comm and hasattr(layer, 'done_event'):
self.default_stream.wait_event(layer.done_event)

@staticmethod
def _pad_to_seq_world_size(tensor: Tensor, seq_dim: int, target_len: int) -> Tensor:
"""Pad *tensor* along *seq_dim* to *target_len* with zeros."""
pad_amount = target_len - tensor.shape[seq_dim]
if pad_amount == 0:
return tensor
# F.pad expects padding in reversed-dim order, two ints per dim: (pad_left, pad_right).
# Build a flat list of zeros with pad_amount at the right end of seq_dim.
pad_spec = [0] * (tensor.dim() * 2)
# Position for the RIGHT-side pad of seq_dim in the reversed list:
# F.pad list starts from the LAST dim. seq_dim from-the-end index = ndim - 1 - seq_dim.
# Right-side pad is at position 2 * (ndim - 1 - seq_dim) + 1.
pad_spec[2 * (tensor.dim() - 1 - seq_dim) + 1] = pad_amount
return F.pad(tensor, pad_spec)

def forward(self,
query: Tensor,
key: Tensor,
Expand All @@ -379,6 +399,8 @@ def forward(self,
key (Tensor): key input to the layer
value (Tensor): value input to the layer
batch_dim_idx (int): indicating which dim is batch
0 → batch-first [B, S, H, D]
1 → seq-first [S, B, H, D]
args: other args

Returns:
Expand All @@ -389,6 +411,24 @@ def forward(self,
# TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
#in shape : e.g., [s/p:h:]

# When the LOCAL sequence length is not the same across all ranks (which
# happens whenever the global sequence length is not divisible by the
# sequence-parallel world size), the all-to-all collective requires
# equal-sized contributions from every rank. Detect this case and pad
# Q/K/V to the maximum local sequence length before the all-to-all; the
# padding tokens are discarded from the final output.
seq_dim = 1 if batch_dim_idx == 0 else 0
local_seq_len = query.shape[seq_dim]
seq_world_size = dist.get_world_size(self.spg)

max_local_seq_len = _get_max_local_seq_len(local_seq_len, self.spg)
needs_pad = max_local_seq_len != local_seq_len

if needs_pad:
query = self._pad_to_seq_world_size(query, seq_dim, max_local_seq_len)
key = self._pad_to_seq_world_size(key, seq_dim, max_local_seq_len)
value = self._pad_to_seq_world_size(value, seq_dim, max_local_seq_len)

def bwd_hook(layer_type):

def pre_hook_fun(grad):
Expand Down Expand Up @@ -436,5 +476,12 @@ def pre_hook_fun(grad):
output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, batch_dim_idx,
self.sp_stream, self.overlap_handles, 'o')

# Remove padding that was added to align sequence lengths across ranks.
if needs_pad:
if batch_dim_idx == 0:
output = output[:, :local_seq_len]
else:
output = output[:local_seq_len]

#out e.g., [s/p::h]
return output
16 changes: 16 additions & 0 deletions deepspeed/utils/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,22 @@ def _get_sequence_data_parallel_world_size():
return _get_data_parallel_world_size()


def _get_max_local_seq_len(local_seq_len: int, group) -> int:
    """Return the maximum local sequence length across all ranks in *group*.

    A single MAX allreduce guarantees every rank agrees on the same padded
    sequence length, which keeps the implicit equal-sized splits of
    ``dist.all_to_all_single`` consistent when the global sequence length is
    not evenly divisible by the world size.
    """
    # Imported lazily to keep this module importable without pulling in
    # torch / the accelerator at load time.
    import torch
    from deepspeed.accelerator import get_accelerator

    seq_len_t = torch.tensor(local_seq_len,
                             dtype=torch.long,
                             device=get_accelerator().current_device_name())
    dist.all_reduce(seq_len_t, op=dist.ReduceOp.MAX, group=group)
    # .item() forces a host sync, which is unavoidable: the padded length is
    # needed on the CPU to size the subsequent F.pad / all-to-all calls.
    return int(seq_len_t.item())


def _get_sequence_data_parallel_rank():
"""Return my rank for the data parallel group."""
global mpu
Expand Down