12 changes: 9 additions & 3 deletions colossalai/pipeline/schedule/one_f_one_b.py
@@ -7,7 +7,7 @@
from torch.utils._pytree import tree_map

from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils import get_current_device
@@ -327,7 +327,10 @@ def run_forward_only(
self.send_forward(output_obj)

if outputs is not None:
outputs = merge_batch(outputs)
if isinstance(model, ModelWrapper):
model = model.unwrap()
batch_size_dim = getattr(model, "batch_size_dim", 0)
outputs = merge_batch(outputs, batch_size_dim)
return {"loss": accum_loss, "outputs": outputs}

def run_forward_backward(
@@ -410,7 +413,10 @@ def run_forward_backward(
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)

if outputs is not None:
outputs = merge_batch(outputs)
if isinstance(model, ModelWrapper):
model = model.unwrap()
batch_size_dim = getattr(model, "batch_size_dim", 0)
outputs = merge_batch(outputs, batch_size_dim)
return {"loss": accum_loss, "outputs": outputs}

def forward_backward_step(
121 changes: 79 additions & 42 deletions colossalai/shardformer/README.md
@@ -114,30 +114,30 @@ We will follow this roadmap to develop Shardformer:
- [x] Unit Testing
- [ ] Policy Implementation

| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
|:-----------:|:---------------:|:-----------------:|:-------------------:|:-------:|:-----------:|:------------------:|:---------------:|:-----------------:|:-------:|
| bert | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| t5 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| llama V1/V2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| gpt2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| opt | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| bloom | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| chatglm2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| vit | [√] | [√] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| whisper | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
| sam | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| blip2 | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| falcon | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| mistral | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |


## 💡 API Design
@@ -391,6 +391,43 @@ _POLICY_LIST = {
}
```

#### How to support models that are on the Hugging Face model hub but not in the `transformers` library

There are two cases:

1. The modeling file is in the `transformers` library, but the model weights are not. For example, "01-ai/Yi-34B" shares its model structure with LLaMA; only the weights live outside the `transformers` library. In this case, we support LLaMA as usual and Yi-34B is covered by the llama policy automatically, so no new policy is needed.
2. The modeling file itself is not in the `transformers` library, as with "THUDM/chatglm2-6b".

Take "THUDM/chatglm2-6b" as an example, we clearly illustrate how to support this model in the `shardformer`.

Unlike llama which is in `transformers` library, we cannot import chatglm2 model directly. Thus, the key in policy should be str of class name, rather than class itself.

For example, for llama:
```python
policy[LlamaDecoderLayer] = ModulePolicyDescription(...)
```

and for chatglm2:
```python
policy["GLMBlock"] = ModulePolicyDescription(...)
```

Then, when registering such models in the auto policy, follow the format below:
```python
"transformers_modules.<modeling_filename>.<class_name>": PolicyLocation(
file_name="<policy_filename>", class_name="<policy_class_name>"
)
```

For the chatglm2 model, this becomes:
```python
"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
)
```
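
This key format follows from how remote-code classes are named at runtime. For a model loaded with `trust_remote_code=True`, the class's `__module__` looks like `transformers_modules.THUDM.chatglm2-6b.<revision_hash>.modeling_chatglm` (downloaded from the hub) or `transformers_modules.chatglm.modeling_chatglm` (loaded from a local directory), and the auto policy collapses it to `transformers_modules.<modeling_filename>` before the lookup. Below is a minimal sketch of that normalization, mirroring `_fullname` in `colossalai/shardformer/policies/auto_policy.py`; the helper name `policy_key` is illustrative, not part of the API.

```python
def policy_key(model) -> str:
    """Illustrative re-implementation of the key used for the auto policy lookup."""
    klass = model.__class__
    module = klass.__module__
    # Remote-code modules look like
    #   'transformers_modules.THUDM.chatglm2-6b.<revision_hash>.modeling_chatglm'  (hub)
    #   'transformers_modules.chatglm.modeling_chatglm'                            (local)
    # Keep only the leading package and the modeling filename.
    if module.startswith("transformers_modules"):
        parts = module.split(".")
        if len(parts) >= 2:
            module = f"{parts[0]}.{parts[-1]}"
    return f"{module}.{klass.__qualname__}"

# A ChatGLMForConditionalGeneration instance therefore maps to
# 'transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration'.
```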

When using such models, `AutoModel` is supported as usual, and the matching policy is loaded automatically by the auto policy.
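
As a minimal end-to-end sketch (not a verbatim recipe: the launch call and parallelism settings below are illustrative and may need adjusting to your ColossalAI version and distributed setup), loading and sharding such a model could look like this:

```python
import colossalai
from transformers import AutoModel

from colossalai.shardformer import ShardConfig, ShardFormer

# assumes the script is launched with torchrun so a distributed environment is available
colossalai.launch_from_torch(config={})

# trust_remote_code=True is required because the modeling file lives on the hub,
# not inside the transformers library
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)

shard_config = ShardConfig(enable_tensor_parallelism=True)
shard_former = ShardFormer(shard_config=shard_config)
# no explicit policy is passed, so the auto policy resolves
# 'transformers_modules.modeling_chatglm.ChatGLMModel' to ChatGLMModelPolicy
sharded_model, shared_params = shard_former.optimize(model)
```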

### Write Your Unit Testing

This section serves as the guideline for testing the `shardformer` module.
@@ -424,13 +461,13 @@ We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate
We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length.

In the case of using 2 GPUs, the training times are as follows.
| N_CTX | org_model | shard_model |
|:-----:|:---------:|:-----------:|
| 256 | 11.2ms | 17.2ms |
| 512 | 9.8ms | 19.5ms |
| 1024 | 19.6ms | 18.9ms |
| 2048 | 46.6ms | 30.8ms |
| 4096 | 160.5ms | 90.4ms |


<p align="center">
@@ -440,13 +477,13 @@ In the case of using 2 GPUs, the training times are as follows.

In the case of using 4 GPUs, the training times are as follows.

| N_CTX | org_model | shard_model |
|:-----:|:---------:|:-----------:|
| 256 | 10.0ms | 21.1ms |
| 512 | 11.5ms | 20.2ms |
| 1024 | 22.1ms | 20.6ms |
| 2048 | 46.9ms | 24.8ms |
| 4096 | 160.4ms | 68.0ms |



@@ -475,10 +512,10 @@ warmup_fraction = 0.03


| accuracy | f1 | loss | GPU number | model sharded |
|:--------:|:-------:|:-------:|:----------:|:-------------:|
| 0.82971 | 0.87713 | 0.23194 | 4 | True |
| 0.83797 | 0.88006 | 0.22683 | 2 | True |
| 0.84521 | 0.88700 | 0.21822 | 1 | False |


Overall, the results demonstrate that using Shardformer during model training does not affect convergence.
19 changes: 8 additions & 11 deletions colossalai/shardformer/layer/normalization.py
@@ -275,19 +275,16 @@ def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *arg
)

LazyInitContext.materialize(module)
# to check if it is huggingface LlamaRMSNorm or MistralRMSNorm
if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]:
normalized_shape = module.weight.shape[0]
eps = module.variance_epsilon
elementwise_affine = True
else:
# get the attributes of the module
normalized_shape = module.normalized_shape
eps = module.eps
elementwise_affine = module.elementwise_affine

# try to get normalized_shape, eps, elementwise_affine from the module
normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
elementwise_affine = getattr(module, "elementwise_affine", True)

rmsnorm = FusedRMSNormWithHook(
normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine
normalized_shape=normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
)

rmsnorm.weight = module.weight
21 changes: 15 additions & 6 deletions colossalai/shardformer/modeling/chatglm2.py
@@ -12,7 +12,6 @@
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel


def get_flash_core_attention_forward():
@@ -31,7 +30,12 @@ def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_
device=query_layer.device,
)
temp_mask = (
torch.ones(query_layer.shape[2], key_layer.shape[2], dtype=torch.bool, device=query_layer.device)
torch.ones(
query_layer.shape[2],
key_layer.shape[2],
dtype=torch.bool,
device=query_layer.device,
)
.tril(diagonal=0)
.expand(query_layer.shape[0], 1, -1, -1)
)
@@ -49,6 +53,7 @@ def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_
attention_mask=attn_bias,
attention_mask_type=attention_mask_type,
dropout_p=dropout_p,
scale=1.0 / self.norm_factor,
)
context_layer = context_layer.permute(2, 0, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
@@ -115,7 +120,7 @@ class ChatGLMPipelineForwards:

@staticmethod
def chatglm_model_forward(
self: ChatGLMModel,
self: "ChatGLMModel",
input_ids,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
@@ -194,7 +199,9 @@ def chatglm_model_forward(
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = split_forward_gather_backward(
hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
)
for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx)
@@ -224,7 +231,9 @@ def chatglm_model_forward(
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = gather_forward_split_backward(
hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@@ -254,7 +263,7 @@ def chatglm_model_forward(

@staticmethod
def chatglm_for_conditional_generation_forward(
self: ChatGLMForConditionalGeneration,
self: "ChatGLMForConditionalGeneration",
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
13 changes: 10 additions & 3 deletions colossalai/shardformer/policies/auto_policy.py
@@ -151,10 +151,10 @@ class PolicyLocation:
file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"
),
# ChatGLM
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
"transformers_modules.modeling_chatglm.ChatGLMModel": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMModelPolicy"
),
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
),
# Falcon
@@ -202,6 +202,13 @@ def _fullname(obj):
module = klass.__module__
if module == "builtins":
return klass.__qualname__ # avoid outputs like 'builtins.str'
# patch custom models which are not in transformers
# it can be like 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm' (from huggingface hub)
# or like 'transformers_modules.chatglm.modeling_chatglm' (from local directory)
if module.startswith("transformers_modules"):
split_module = module.split(".")
if len(split_module) >= 2:
module = f"{split_module[0]}.{split_module[-1]}"
return module + "." + klass.__qualname__


Expand All @@ -220,7 +227,7 @@ def get_autopolicy(model: nn.Module) -> Policy:

if policy_location is None:
raise NotImplementedError(
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
f"Auto policy for {model.__class__.__qualname__} ({full_name}) is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
)
else:
policy = import_policy(policy_location)