@@ -26,7 +26,7 @@
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer import GradCkptCollection, ShardConfig, ShardFormer
+from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
 from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.tensor.d_tensor.api import is_distributed_tensor
@@ -930,7 +930,7 @@ class HybridParallelPlugin(PipelinePluginBase): |
         custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
         pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
         num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
-        gradient_ckpt_collection (GradCkptCollection, optional): The configuration for gradient checkpointing. Defaults to None.
+        gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
         enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
     """

@@ -970,7 +970,7 @@ def __init__( |
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
-        gradient_ckpt_collection: Optional[GradCkptCollection] = None,
+        gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,
     ) -> None:
         super().__init__()
@@ -1045,7 +1045,7 @@ def __init__( |
             enable_sequence_parallelism=enable_sequence_parallelism,
             enable_sequence_overlap=enable_sequence_overlap,
             parallel_output=parallel_output,
-            gradient_ckpt_collection=gradient_ckpt_collection,
+            gradient_checkpoint_config=gradient_checkpoint_config,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
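The rename is mechanical: `GradCkptCollection` becomes `GradientCheckpointConfig`, and the plugin's `gradient_ckpt_collection` parameter becomes `gradient_checkpoint_config`, threaded from `HybridParallelPlugin.__init__` into its `ShardConfig`. Below is a minimal sketch of how a caller would pass the renamed option; the `tp_size`/`pp_size` arguments and the `gradient_checkpointing_ratio` field are assumptions, as neither appears in this diff.

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer import GradientCheckpointConfig

# Assumption: the config exposes a gradient_checkpointing_ratio field;
# only the class name and the plugin parameter are confirmed by this diff.
# A ratio of 0.5 would checkpoint roughly half of the layers, trading
# recomputation for activation memory.
ckpt_config = GradientCheckpointConfig(gradient_checkpointing_ratio=0.5)

plugin = HybridParallelPlugin(
    tp_size=2,                               # tensor parallel degree (assumed arg)
    pp_size=2,                               # pipeline parallel degree (assumed arg)
    pp_style="1f1b",                         # one-forward-one-backward schedule
    num_model_chunks=1,
    gradient_checkpoint_config=ckpt_config,  # renamed from gradient_ckpt_collection
)
booster = Booster(plugin=plugin)             # plugin is applied through the Booster API
```

Note that call sites still passing `gradient_ckpt_collection=` will raise a `TypeError` for an unexpected keyword argument after this change.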