diff --git a/deepspeed_configs/zero3_bf16.json b/deepspeed_configs/zero3_bf16.json
index 16e64d76b..49fb75755 100644
--- a/deepspeed_configs/zero3_bf16.json
+++ b/deepspeed_configs/zero3_bf16.json
@@ -14,15 +14,6 @@
   "bf16": {
     "enabled": true
   },
-  "fp16": {
-    "enabled": "auto",
-    "auto_cast": false,
-    "loss_scale": 0,
-    "initial_scale_power": 32,
-    "loss_scale_window": 1000,
-    "hysteresis": 2,
-    "min_loss_scale": 1
-  },
   "gradient_accumulation_steps": "auto",
   "gradient_clipping": "auto",
   "train_batch_size": "auto",
diff --git a/deepspeed_configs/zero3_bf16_cpuoffload_all.json b/deepspeed_configs/zero3_bf16_cpuoffload_all.json
index 09ca6785b..3ccc66db4 100644
--- a/deepspeed_configs/zero3_bf16_cpuoffload_all.json
+++ b/deepspeed_configs/zero3_bf16_cpuoffload_all.json
@@ -24,15 +24,6 @@
   "bf16": {
     "enabled": true
   },
-  "fp16": {
-    "enabled": "auto",
-    "auto_cast": false,
-    "loss_scale": 0,
-    "initial_scale_power": 32,
-    "loss_scale_window": 1000,
-    "hysteresis": 2,
-    "min_loss_scale": 1
-  },
   "gradient_accumulation_steps": "auto",
   "gradient_clipping": "auto",
   "train_batch_size": "auto",
diff --git a/deepspeed_configs/zero3_bf16_cpuoffload_params.json b/deepspeed_configs/zero3_bf16_cpuoffload_params.json
index 41d4a2132..fe21d35f8 100644
--- a/deepspeed_configs/zero3_bf16_cpuoffload_params.json
+++ b/deepspeed_configs/zero3_bf16_cpuoffload_params.json
@@ -20,15 +20,6 @@
   "bf16": {
     "enabled": true
   },
-  "fp16": {
-    "enabled": "auto",
-    "auto_cast": false,
-    "loss_scale": 0,
-    "initial_scale_power": 32,
-    "loss_scale_window": 1000,
-    "hysteresis": 2,
-    "min_loss_scale": 1
-  },
   "gradient_accumulation_steps": "auto",
   "gradient_clipping": "auto",
   "train_batch_size": "auto",
diff --git a/requirements.txt b/requirements.txt
index b6e9a554e..6bb1aa684 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,7 @@ tokenizers>=0.20.1
 bitsandbytes==0.44.1
 accelerate==1.0.1
 datasets==3.0.1
-deepspeed==0.14.4
+deepspeed==0.15.3
 pydantic==2.6.3
 addict
 fire
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 5e53df72c..8b433c366 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -40,7 +40,10 @@
     PreTrainedTokenizerBase,
     ProcessorMixin,
 )
-from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+from transformers.integrations.deepspeed import (
+    HfTrainerDeepSpeedConfig,
+    is_deepspeed_zero3_enabled,
+)
 
 from axolotl.common.architectures import MOE_ARCH_BLOCK
 from axolotl.models.mamba import fix_mamba_attn_for_loss
@@ -705,6 +708,38 @@ def set_attention_config(self) -> None:
             self.model_kwargs["low_cpu_mem_usage"] = True
 
     def build_model(self, qlora_fsdp) -> bool:
+        def _configure_zero3_memory_efficient_loading():
+            """
+            Set the deepspeed config to load the model into RAM first before moving to VRAM.
+
+            We need to return hf_ds_cfg as it needs to exist before model loading.
+            """
+            hf_ds_cfg = None
+
+            if os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
+                hf_ds_cfg = HfTrainerDeepSpeedConfig(self.cfg.deepspeed)
+                hf_ds_cfg.fill_match(
+                    "train_micro_batch_size_per_gpu", self.cfg.micro_batch_size
+                )
+                hf_ds_cfg.fill_match(
+                    "gradient_accumulation_steps", self.cfg.gradient_accumulation_steps
+                )
+                hf_ds_cfg.fill_match(
+                    "train_batch_size",
+                    int(os.getenv("WORLD_SIZE", "1"))
+                    * self.cfg.micro_batch_size
+                    * self.cfg.gradient_accumulation_steps,
+                )
+                if "device_map" in self.model_kwargs:
+                    del self.model_kwargs["device_map"]
+
+                transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: True
+                transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = (
+                    lambda: True
+                )
+
+            return hf_ds_cfg
+
         skip_move_to_device = False
         if (  # pylint: disable=condition-evals-to-constant)
             (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading)
@@ -753,6 +788,8 @@ def build_model(self, qlora_fsdp) -> bool:
             if "device_map" in self.model_kwargs:
                 del self.model_kwargs["device_map"]
 
+            _ = _configure_zero3_memory_efficient_loading()
+
             if self.cfg.is_multimodal:
                 self.model_config.text_config = self.text_model_config
             self.model = self.AutoModelLoader.from_pretrained(
@@ -846,6 +883,8 @@ def build_model(self, qlora_fsdp) -> bool:
             if "device_map" in self.model_kwargs:
                 del self.model_kwargs["device_map"]
 
+            _ = _configure_zero3_memory_efficient_loading()+
+
             if self.cfg.is_multimodal:
                 self.model_config.text_config = self.text_model_config
             self.model = self.AutoModelLoader.from_pretrained(