@@ -5,6 +5,7 @@
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer import AdvancedPipelineConfig
 from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
@@ -24,9 +25,13 @@
 
 
 def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+    enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False)
     org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
         model_fn, loss_fn, test_config
     )
+    if enable_gradient_checkpointing:
+        org_model.gradient_checkpointing_enable()
+        sharded_model.unwrap().gradient_checkpointing_enable()
 
     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
         org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
@@ -101,14 +106,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "use_lazy_init": True,
             "precision": "fp16",
             "initial_scale": 1,
-            "gradient_checkpointing_ratio": 0.5,
+            "enable_gradient_checkpointing": True,
+            "advanced_pipeline_config": AdvancedPipelineConfig(gradient_checkpointing_ratio=0.5),
         },
         {
             "tp_size": 1,
             "pp_size": 2,
             "num_microbatches": 4,
             "use_lazy_init": False,
             "precision": "fp32",
+            "enable_gradient_checkpointing": True,
+            "advanced_pipeline_config": AdvancedPipelineConfig(
+                num_stages=2, num_model_chunks=1, num_model_layers=8, num_layers_per_stage=[5, 3]
+            ),
         },
         {
             "tp_size": 4,
@@ -190,6 +200,14 @@ def run_llama_test(test_config):
             "precision": "fp16",
             "zero_stage": 1,
             "initial_scale": 1,
+            "enable_gradient_checkpointing": True,
+            "advanced_pipeline_config": AdvancedPipelineConfig(
+                num_stages=2,
+                num_model_chunks=2,
+                num_model_layers=8,
+                num_layers_per_stage=[3, 3, 1, 1],
+                num_ckpt_layers_per_stage=[0, 0, 1, 1],
+            ),
         },
     ],
 )
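
A minimal sketch of how the new per-stage numbers appear to fit together, inferred only from the hunks above; the constraints described in the comments (one entry per pipeline stage and model chunk, entries summing to num_model_layers, and num_ckpt_layers_per_stage choosing how many layers in each chunk run with activation checkpointing) are assumptions, not behavior confirmed by this diff:

    from colossalai.shardformer import AdvancedPipelineConfig

    # Mirrors the last case above: num_stages=2 with num_model_chunks=2 gives 4 entries.
    # Assumed: the entries of num_layers_per_stage sum to num_model_layers (3 + 3 + 1 + 1 == 8),
    # and each entry of num_ckpt_layers_per_stage is capped by the matching layer count,
    # so only the last two chunks checkpoint a single layer each.
    cfg = AdvancedPipelineConfig(
        num_stages=2,
        num_model_chunks=2,
        num_model_layers=8,
        num_layers_per_stage=[3, 3, 1, 1],
        num_ckpt_layers_per_stage=[0, 0, 1, 1],
    )

    # The test then passes the config alongside the plain boolean flag that
    # check_forward_backward pops before building the plugin:
    # test_config = {..., "enable_gradient_checkpointing": True, "advanced_pipeline_config": cfg}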