Skip to content

Commit

Permalink
fix zero3 (#1994)
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian authored Oct 28, 2024
1 parent 2501c1a commit d3c45d2
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 29 deletions.
9 changes: 0 additions & 9 deletions deepspeed_configs/zero3_bf16.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,6 @@
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
Expand Down
9 changes: 0 additions & 9 deletions deepspeed_configs/zero3_bf16_cpuoffload_all.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,6 @@
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
Expand Down
9 changes: 0 additions & 9 deletions deepspeed_configs/zero3_bf16_cpuoffload_params.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,6 @@
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ tokenizers>=0.20.1
bitsandbytes==0.44.1
accelerate==1.0.1
datasets==3.0.1
deepspeed==0.14.4
deepspeed==0.15.3
pydantic==2.6.3
addict
fire
Expand Down
41 changes: 40 additions & 1 deletion src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@
PreTrainedTokenizerBase,
ProcessorMixin,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.integrations.deepspeed import (
HfTrainerDeepSpeedConfig,
is_deepspeed_zero3_enabled,
)

from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.models.mamba import fix_mamba_attn_for_loss
Expand Down Expand Up @@ -705,6 +708,38 @@ def set_attention_config(self) -> None:
self.model_kwargs["low_cpu_mem_usage"] = True

def build_model(self, qlora_fsdp) -> bool:
def _configure_zero3_memory_efficient_loading():
    """
    Set the deepspeed config to load the model into RAM first before moving to VRAM.

    We need to return hf_ds_cfg as it needs to exist before model loading:
    constructing HfTrainerDeepSpeedConfig registers the config with
    transformers' deepspeed integration, so it must outlive this call and be
    alive when ``from_pretrained`` runs. Only the caller holding the returned
    reference keeps it from being garbage collected.
    """
    hf_ds_cfg = None

    # ACCELERATE_DEEPSPEED_ZERO_STAGE is set by accelerate's launcher when a
    # deepspeed zero config is in use; only stage 3 needs this special handling.
    if os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
        hf_ds_cfg = HfTrainerDeepSpeedConfig(self.cfg.deepspeed)
        # fill_match resolves "auto" placeholders in the deepspeed JSON with
        # the concrete values from the axolotl config.
        hf_ds_cfg.fill_match(
            "train_micro_batch_size_per_gpu", self.cfg.micro_batch_size
        )
        hf_ds_cfg.fill_match(
            "gradient_accumulation_steps", self.cfg.gradient_accumulation_steps
        )
        # train_batch_size = world_size * micro_batch * grad_accum; WORLD_SIZE
        # defaults to "1" for single-process runs.
        hf_ds_cfg.fill_match(
            "train_batch_size",
            int(os.getenv("WORLD_SIZE", "1"))
            * self.cfg.micro_batch_size
            * self.cfg.gradient_accumulation_steps,
        )
        # device_map conflicts with zero3's own parameter partitioning, so
        # drop it from the from_pretrained kwargs.
        if "device_map" in self.model_kwargs:
            del self.model_kwargs["device_map"]

        # Monkeypatch both lookup sites so transformers treats zero3 as
        # enabled during model loading (NOTE(review): presumably so
        # from_pretrained uses deepspeed's memory-efficient zero.Init path
        # even before the Trainer initializes deepspeed — confirm).
        transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: True
        transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = (
            lambda: True
        )

    return hf_ds_cfg

skip_move_to_device = False
if ( # pylint: disable=condition-evals-to-constant)
(self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading)
Expand Down Expand Up @@ -753,6 +788,8 @@ def build_model(self, qlora_fsdp) -> bool:
if "device_map" in self.model_kwargs:
del self.model_kwargs["device_map"]

_ = _configure_zero3_memory_efficient_loading()

if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained(
Expand Down Expand Up @@ -846,6 +883,8 @@ def build_model(self, qlora_fsdp) -> bool:
if "device_map" in self.model_kwargs:
del self.model_kwargs["device_map"]

_ = _configure_zero3_memory_efficient_loading()

if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained(
Expand Down

0 comments on commit d3c45d2

Please sign in to comment.