@@ -109,8 +109,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
else:
module = self.model.model

layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
@@ -129,10 +129,10 @@ def get_held_layers(self) -> List[Module]:
stage_manager = self.pipeline_stage_manager

held_layers = []
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
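This hunk captures the core refactor of the PR: the layer-split helpers move off the shardformer `Policy` base class and onto `PipelineStageManager`, which already knows its own stage index, stage count, and model-chunk count. Restated as a plain before/after sketch (names exactly as in the hunk above):

```python
# Before this PR: the policy computed the split itself and had to pass the
# pipeline topology explicitly.
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)

# After this PR: the stage manager owns the split, so the calls drop the
# num_stages / stage arguments.
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_index = stage_manager.get_stage_index(layers_per_stage)
```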
5 changes: 4 additions & 1 deletion colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -26,7 +26,7 @@
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.d_tensor.api import is_distributed_tensor
@@ -930,6 +930,7 @@ class HybridParallelPlugin(PipelinePluginBase):
custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
"""

@@ -969,6 +970,7 @@ def __init__(
custom_policy: Policy = None,
pp_style: str = "1f1b",
num_model_chunks: int = 1,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
) -> None:
super().__init__()
@@ -1043,6 +1045,7 @@ def __init__(
enable_sequence_parallelism=enable_sequence_parallelism,
enable_sequence_overlap=enable_sequence_overlap,
parallel_output=parallel_output,
gradient_checkpoint_config=gradient_checkpoint_config,
)
self.amp_config = dict(
initial_scale=initial_scale,
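With the plugin now accepting `gradient_checkpoint_config` and forwarding it into `ShardConfig`, a user can opt into per-stage checkpointing when building the booster. The sketch below is only an assumed usage: the constructor of `PipelineGradientCheckpointConfig` is not shown in this diff, so `num_ckpt_layers_per_stage` is a hypothetical argument used purely for illustration, and `tp_size`/`pp_size` are just typical plugin parameters added for context.

```python
# Hedged usage sketch; the config constructor arguments are an assumption.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer import PipelineGradientCheckpointConfig

ckpt_config = PipelineGradientCheckpointConfig(
    num_ckpt_layers_per_stage=[4, 2],  # hypothetical: checkpoint 4 layers on stage 0, 2 on stage 1
)
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    gradient_checkpoint_config=ckpt_config,  # new keyword introduced in this diff
)
booster = Booster(plugin=plugin)
```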
4 changes: 2 additions & 2 deletions colossalai/inference/engine/policies/bloom.py
@@ -114,12 +114,12 @@ def get_held_layers(self) -> List[Module]:
stage_manager = self.pipeline_stage_manager

held_layers = []
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
layers_per_stage = stage_manager.distribute_layers(len(module.h))
if stage_manager.is_first_stage():
held_layers.append(module.word_embeddings)
held_layers.append(module.word_embeddings_layernorm)
held_layers.append(self.model.lm_head)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
4 changes: 2 additions & 2 deletions colossalai/inference/engine/policies/chatglm2.py
@@ -69,11 +69,11 @@ def get_held_layers(self) -> List[nn.Module]:
stage_manager = self.pipeline_stage_manager

held_layers = []
layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
layers_per_stage = stage_manager.distribute_layers(module.num_layers)
if stage_manager.is_first_stage():
held_layers.append(module.embedding)
held_layers.append(module.output_layer)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.encoder.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
if module.encoder.post_layer_norm:
4 changes: 2 additions & 2 deletions colossalai/inference/engine/policies/llama.py
@@ -194,11 +194,11 @@ def get_held_layers(self) -> List[Module]:
stage_manager = self.pipeline_stage_manager

held_layers = []
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
held_layers.append(self.model.lm_head)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
87 changes: 86 additions & 1 deletion colossalai/pipeline/stage_manager.py
@@ -1,6 +1,7 @@
import contextlib
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch.distributed as dist
from torch.distributed import ProcessGroup

@@ -29,6 +30,8 @@ def __init__(
) -> None:
assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False"

self.num_layers_per_stage = None

self.pg_mesh = pg_mesh
self.pipeline_axis = pipeline_axis
self.prev_rank: Optional[Tuple[int, ...]] = None
@@ -69,6 +72,88 @@ def __init__(
# for shardformer, hold model chunk id
self.model_chunk_id: Optional[int] = None

@property
def control_distribute_layers(self) -> bool:
return self.num_layers_per_stage is not None

def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None:
"""Set the distribution configuration.
This allows the user to customize the number of layers for each stage.

Args:
num_model_layers (int): Number of layers in the model.
num_layers_per_stage (List[int]): Number of layers for each stage.
"""
assert all([0 < num_layers < num_model_layers for num_layers in num_layers_per_stage])
assert sum(num_layers_per_stage) == num_model_layers
assert len(num_layers_per_stage) == self.num_stages * (self.num_model_chunks if self.is_interleave else 1)
self.num_model_layers = num_model_layers
self.num_layers_per_stage = num_layers_per_stage

def distribute_layers(
self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
) -> List[int]:
"""Divide layers into stages"""
num_stages = self.num_stages if num_stages is None else num_stages
num_model_chunks = (
(self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
)

if self.control_distribute_layers:
assert num_layers == self.num_model_layers
return self.num_layers_per_stage

else:
quotient = num_layers // (num_stages * num_model_chunks)
remainder = num_layers % (num_stages * num_model_chunks)

# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages * num_model_chunks

# deal with the rest layers
if remainder > 0:
start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
for i in range(start_position, start_position + remainder):
layers_per_stage[i] += 1
return layers_per_stage

def get_stage_index(
self,
layers_per_stage: List[int],
stage: Optional[int] = None,
num_model_chunks: Optional[int] = None,
num_stages: Optional[int] = None,
) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
"""
Get the start index and end index of layers for each stage.

Args:
layers_per_stage (List[int]): number of layers for each stage
stage (int): the stage index
num_stages (int): number of stages
num_model_chunks (int): number of model chunks

Returns:
- Tuple[int, int]: the start index and end index of this stage
- List[Tuple[int, int]]: the start index and end index of this stage for each model chunk

"""
stage = self.stage if stage is None else stage
num_model_chunks = (
(self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
)
num_stages = self.num_stages if num_stages is None else num_stages

num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)

stage_indices = []
for model_chunk in range(num_model_chunks):
start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
stage_indices.append([start_idx, end_idx])

return stage_indices[0] if num_model_chunks == 1 else stage_indices

def is_first_stage(self, ignore_chunk: bool = False) -> bool:
"""Is the current stage the first stage.

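The new stage-manager methods either return the user-supplied split from `set_distribution_config` or fall back to the even split computed above. Below is a self-contained sketch of that fallback arithmetic, with the manager's attributes replaced by plain arguments so it runs without constructing a `PipelineStageManager`:

```python
from typing import List
import numpy as np

def distribute_layers(num_layers: int, num_stages: int, num_model_chunks: int = 1) -> List[int]:
    # Mirrors the default (non-customized) branch of PipelineStageManager.distribute_layers.
    chunks = num_stages * num_model_chunks
    quotient, remainder = divmod(num_layers, chunks)
    layers_per_stage = [quotient] * chunks
    # Leftover layers are placed around the middle stages, as in the diff above.
    start = chunks // 2 - remainder // 2
    for i in range(start, start + remainder):
        layers_per_stage[i] += 1
    return layers_per_stage

def get_stage_index(layers_per_stage: List[int], stage: int, num_stages: int,
                    num_model_chunks: int = 1):
    # Mirrors PipelineStageManager.get_stage_index: one (start, end) pair per model chunk.
    acc = np.insert(np.cumsum(layers_per_stage), 0, 0)
    indices = [(int(acc[stage + c * num_stages]), int(acc[stage + c * num_stages + 1]))
               for c in range(num_model_chunks)]
    return indices[0] if num_model_chunks == 1 else indices

print(distribute_layers(10, 4))                                  # [2, 3, 3, 2]
print(get_stage_index([2, 3, 3, 2], stage=1, num_stages=4))      # (2, 5)
print(distribute_layers(8, 2, num_model_chunks=2))               # [2, 2, 2, 2]
print(get_stage_index([2, 2, 2, 2], 0, 2, num_model_chunks=2))   # [(0, 2), (4, 6)]
```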
2 changes: 1 addition & 1 deletion colossalai/shardformer/__init__.py
@@ -1 +1 @@
from .shard import ShardConfig, ShardFormer
from .shard import GradientCheckpointConfig, ModelSharder, PipelineGradientCheckpointConfig, ShardConfig, ShardFormer
14 changes: 13 additions & 1 deletion colossalai/shardformer/modeling/llama.py
@@ -138,13 +138,25 @@ def llama_model_forward(
next_decoder_cache = () if use_cache else None

start_idx, end_idx = stage_index[0], stage_index[1]
num_ckpt_layers = 0
if self.gradient_checkpointing and self.training:
num_ckpt_layers = end_idx - start_idx
# TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
if shard_config.gradient_checkpoint_config is not None:
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
stage=stage_manager.stage,
num_layers=end_idx - start_idx,
model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0,
)
assert num_ckpt_layers <= end_idx - start_idx

for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)

past_key_value = past_key_values[idx] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
if idx - start_idx < num_ckpt_layers:

def create_custom_forward(module):
def custom_forward(*inputs):
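In the llama forward, gradient checkpointing is no longer all-or-nothing per stage: only the first `num_ckpt_layers` layers of the stage's slice are recomputed in the backward pass, and the rest run normally. A minimal, self-contained sketch of that gating with toy linear layers, assuming (as the call above suggests) that `get_num_ckpt_layers` yields an integer for the current stage:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy stand-in: this pipeline stage owns layers [start_idx:end_idx] of the model.
layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(8)])
start_idx, end_idx = 2, 6
num_ckpt_layers = 2  # e.g. value returned by gradient_checkpoint_config.get_num_ckpt_layers(...)
assert num_ckpt_layers <= end_idx - start_idx

hidden_states = torch.randn(4, 16, requires_grad=True)
for idx, layer in enumerate(layers[start_idx:end_idx], start=start_idx):
    if idx - start_idx < num_ckpt_layers:
        # First num_ckpt_layers layers of this stage: recompute activations in backward.
        hidden_states = checkpoint(layer, hidden_states, use_reentrant=False)
    else:
        # Remaining layers: keep activations, trading memory for speed.
        hidden_states = layer(hidden_states)
hidden_states.sum().backward()  # gradients flow through checkpointed and plain layers alike
```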
49 changes: 1 addition & 48 deletions colossalai/shardformer/policies/base_policy.py
@@ -2,9 +2,8 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
@@ -196,49 +195,3 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
"""
return []

def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
"""Divide layers into stages"""
quotient = num_layers // num_stages
remainder = num_layers % num_stages

# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages

# deal with the rest layers
if remainder > 0:
start_position = num_stages // 2 - remainder // 2
for i in range(start_position, start_position + remainder):
layers_per_stage[i] += 1
return layers_per_stage

def get_stage_index(
self,
layers_per_stage: List[int],
stage: int,
num_model_chunks: int = 1,
num_stages: int = 0,
) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
"""
Get the start index and end index of layers for each stage.

Args:
layers_per_stage (List[int]): number of layers for each stage
stage (int): the stage index
num_stages (int): number of stages
num_model_chunks (int): number of model chunks

Returns:
- Tuple[int, int]: the start index and end index of this stage
- List[Tuple[int, int]]: the start index and end index of this stage for each model chunk

"""
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)

stage_indices = []
for model_chunk in range(num_model_chunks):
start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
stage_indices.append([start_idx, end_idx])

return stage_indices[0] if num_model_chunks == 1 else stage_indices
32 changes: 8 additions & 24 deletions colossalai/shardformer/policies/bert.py
@@ -279,16 +279,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
module = self.model.bert

if stage_manager.is_interleave:
layers_per_stage = self.distribute_layers(
len(module.encoder.layer),
stage_manager.num_stages * stage_manager.num_model_chunks,
)
stage_manager.stage_indices = self.get_stage_index(
layers_per_stage,
stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks,
num_stages=stage_manager.num_stages,
)
layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(
new_forward,
@@ -298,8 +290,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
}

else:
layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(
new_forward,
@@ -324,16 +316,8 @@ def get_held_layers(self) -> List[Module]:
held_layers = []
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = self.distribute_layers(
len(module.encoder.layer),
stage_manager.num_stages * stage_manager.num_model_chunks,
)
stage_indices = self.get_stage_index(
layers_per_stage,
stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks,
num_stages=stage_manager.num_stages,
)
layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
stage_indices = stage_manager.get_stage_index(layers_per_stage)
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.embeddings)
for start_idx, end_idx in stage_indices:
@@ -342,10 +326,10 @@ def get_held_layers(self) -> List[Module]:
held_layers.append(module.pooler)

else:
layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
if stage_manager.is_first_stage():
held_layers.append(module.embeddings)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.encoder.layer[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.pooler)
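For the interleaved schedule, `get_stage_index` returns one `(start, end)` pair per model chunk and the BERT policy stores that list on `stage_manager.stage_indices`; held layers are then gathered chunk by chunk. A toy sketch of that gathering step, reusing the `[(0, 2), (4, 6)]` split from the stage-manager example above (2 stages, 2 model chunks, stage 0):

```python
import torch.nn as nn

# Toy encoder with 8 layers standing in for module.encoder.layer.
encoder_layers = nn.ModuleList([nn.Identity() for _ in range(8)])
stage_indices = [(0, 2), (4, 6)]  # one (start, end) pair per model chunk on this stage

held_layers = []
for start_idx, end_idx in stage_indices:
    # Each chunk owned by this stage contributes a contiguous slice of layers.
    held_layers.extend(encoder_layers[start_idx:end_idx])

assert len(held_layers) == 4  # layers 0, 1 and 4, 5 live on stage 0
```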
8 changes: 4 additions & 4 deletions colossalai/shardformer/policies/bloom.py
@@ -203,8 +203,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
else:
module = self.model.transformer

layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
layers_per_stage = stage_manager.distribute_layers(len(module.h))
stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
@@ -226,11 +226,11 @@ def get_held_layers(self) -> List[Module]:
stage_manager = self.pipeline_stage_manager

held_layers = []
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
layers_per_stage = stage_manager.distribute_layers(len(module.h))
if stage_manager.is_first_stage():
held_layers.append(module.word_embeddings)
held_layers.append(module.word_embeddings_layernorm)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)