
Commit f595645

test: update llama tests
1 parent d0256b3 commit f595645

2 files changed: +22 -4 lines

tests/kit/model_zoo/transformers/llama.py

Lines changed: 3 additions & 3 deletions
@@ -49,9 +49,9 @@ def data_gen_for_casual_lm():
 loss_fn_for_seq_classification = lambda output: output["logits"].mean()
 
 config = LlamaConfig(
-    num_hidden_layers=4,
-    hidden_size=128,
-    intermediate_size=256,
+    num_hidden_layers=8,
+    hidden_size=32,
+    intermediate_size=64,
     num_attention_heads=4,
     max_position_embeddings=128,
     num_labels=16,
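
The model-zoo change trades width for depth: the tiny test LLaMA now has 8 hidden layers at hidden_size=32 instead of 4 layers at hidden_size=128, which keeps it cheap to run while giving the pipeline tests below enough layers to split across stages (their num_model_layers=8 matches this depth). A minimal sketch of instantiating the resulting config with Hugging Face transformers; the surrounding model-zoo registration code is not shown in this diff and is assumed.

from transformers import LlamaConfig, LlamaForCausalLM

# Rebuild the tiny test config from the hunk above.
config = LlamaConfig(
    num_hidden_layers=8,        # was 4: enough layers to split across pipeline stages
    hidden_size=32,             # was 128: a narrower model keeps the test fast
    intermediate_size=64,       # was 256
    num_attention_heads=4,
    max_position_embeddings=128,
    num_labels=16,
)
model = LlamaForCausalLM(config)
print(sum(p.numel() for p in model.parameters()))  # tiny parameter count, suitable for CI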

tests/test_shardformer/test_model/test_shard_llama.py

Lines changed: 19 additions & 1 deletion
@@ -5,6 +5,7 @@
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer import AdvancedPipelineConfig
 from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
@@ -24,9 +25,13 @@
 
 
 def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+    enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False)
     org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
         model_fn, loss_fn, test_config
     )
+    if enable_gradient_checkpointing:
+        org_model.gradient_checkpointing_enable()
+        sharded_model.unwrap().gradient_checkpointing_enable()
 
     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
         org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
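
check_forward_backward now pops enable_gradient_checkpointing out of test_config before the hybrid plugin is built, then turns on the standard Hugging Face gradient_checkpointing_enable() hook on both the reference model and the unwrapped sharded model, so both sides recompute activations the same way. A standalone sketch of the same pattern on a plain transformers model; the ColossalAI build/unwrap helpers are not reproduced here.

from transformers import LlamaConfig, LlamaForCausalLM

test_config = {"precision": "fp16", "enable_gradient_checkpointing": True}

# Pop the flag first so only plugin kwargs remain in test_config.
enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False)

model = LlamaForCausalLM(
    LlamaConfig(num_hidden_layers=8, hidden_size=32, intermediate_size=64, num_attention_heads=4)
)
if enable_gradient_checkpointing:
    # Hugging Face hook: recompute activations in the backward pass to save memory.
    model.gradient_checkpointing_enable()

print(model.model.gradient_checkpointing)  # True when the flag was set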
@@ -101,14 +106,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "use_lazy_init": True,
             "precision": "fp16",
             "initial_scale": 1,
-            "gradient_checkpointing_ratio": 0.5,
+            "enable_gradient_checkpointing": True,
+            "advanced_pipeline_config": AdvancedPipelineConfig(gradient_checkpointing_ratio=0.5),
         },
         {
             "tp_size": 1,
             "pp_size": 2,
             "num_microbatches": 4,
             "use_lazy_init": False,
             "precision": "fp32",
+            "enable_gradient_checkpointing": True,
+            "advanced_pipeline_config": AdvancedPipelineConfig(
+                num_stages=2, num_model_chunks=1, num_model_layers=8, num_layers_per_stage=[5, 3]
+            ),
         },
         {
             "tp_size": 4,
@@ -190,6 +200,14 @@ def run_llama_test(test_config):
             "precision": "fp16",
             "zero_stage": 1,
             "initial_scale": 1,
+            "enable_gradient_checkpointing": True,
+            "advanced_pipeline_config": AdvancedPipelineConfig(
+                num_stages=2,
+                num_model_chunks=2,
+                num_model_layers=8,
+                num_layers_per_stage=[3, 3, 1, 1],
+                num_ckpt_layers_per_stage=[0, 0, 1, 1],
+            ),
         },
     ],
 )
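
The last config combines interleaved pipeline parallelism with partial checkpointing: 2 stages x 2 model chunks gives 4 layer groups of sizes [3, 3, 1, 1], and num_ckpt_layers_per_stage=[0, 0, 1, 1] is read here as the number of activation-checkpointed layers in each group. A sketch of the effective checkpointing ratio under that assumption:

# Values from the hunk above; the per-group interpretation of
# num_ckpt_layers_per_stage is an assumption, not confirmed by this diff.
num_layers_per_stage = [3, 3, 1, 1]        # 2 stages x 2 chunks, 8 layers in total
num_ckpt_layers_per_stage = [0, 0, 1, 1]   # only the last two groups recompute activations

assert all(c <= n for c, n in zip(num_ckpt_layers_per_stage, num_layers_per_stage))
ratio = sum(num_ckpt_layers_per_stage) / sum(num_layers_per_stage)
print(ratio)  # 0.25: 2 of the 8 layers are checkpointed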
