
Upgrade from 4.37.2 to 4.38.2 causes CUDA out of memory error with identical configuration. #29484

@richardodliu

Description


System Info

  • transformers version: 4.38.2
  • Platform: Linux-5.15.0-88-generic-x86_64-with-glibc2.31
  • Python version: 3.9.18
  • Huggingface_hub version: 0.21.3
  • Safetensors version: 0.4.2
  • Accelerate version: 0.27.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.1
  • Tensorflow version (GPU?): not installed
  • Flax version (CPU?/GPU?/TPU?): not installed
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Yes

Who can help?

@muellerz
@pacman100
@gante
@Narsil

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I was using the official DeepSeek-Coder training script (https://github.com/deepseek-ai/DeepSeek-Coder/blob/main/finetune/finetune_deepseekcoder.py) to fine-tune a model on the Evol-Instruct-Code-80k dataset (https://huggingface.co/datasets/nickrosh/Evol-Instruct-Code-80k-v1) when I noticed that with transformers 4.38.2, training fails with torch.cuda.OutOfMemoryError: CUDA out of memory. Downgrading transformers to 4.37.2, with everything else unchanged, resolves the issue.
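Since the report hinges on which transformers release is active, a small helper like the one below can confirm the installed version before launching a run. This is a minimal sketch: the helper names (release_tuple, classify) are hypothetical, and it only labels the two releases covered by this report (4.37.2 working, 4.38.2 failing) without claiming anything about other versions.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# The only two releases this report covers; no claim is made about others.
REPORTED = {
    (4, 37, 2): "reported working",
    (4, 38, 2): "reported CUDA OOM",
}


def release_tuple(v: str) -> tuple:
    """Extract (major, minor, patch) from strings like '4.38.2' or '4.38.0.dev0'."""
    m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", v)
    if m is None:
        raise ValueError(f"unrecognized version string: {v!r}")
    # A missing patch component (e.g. '4.38') is treated as 0.
    return tuple(int(g) for g in m.groups(default="0"))


def classify(v: str) -> str:
    """Map a version string to its status in this report."""
    return REPORTED.get(release_tuple(v), "not covered by this report")


if __name__ == "__main__":
    try:
        installed = version("transformers")
        print(f"transformers {installed}: {classify(installed)}")
    except PackageNotFoundError:
        print("transformers is not installed in this environment")
```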

The command line I used was as follows:

#!/bin/bash
export DS_SKIP_CUDA_CHECK=1

DATA_PATH="/data2/user/llm/Evol-Instruct-Code-80k-v1/EvolInstruct-Code-80k.json"
OUTPUT_PATH="/data2/user/llm/sft/0306"
MODEL_PATH="/data2/user/llm/deepseek-coder-1.3b-base"

export WANDB_PROJECT="llm"
export WANDB_LOG_MODEL="false"
export WANDB_WATCH="false"

# Launch on all 8 local GPUs; every flag line except the last ends with a backslash.
deepspeed --include localhost:0,1,2,3,4,5,6,7 DeepSeek-Coder/finetune/finetune_deepseekcoder.py \
    --model_name_or_path "$MODEL_PATH" \
    --data_path "$DATA_PATH" \
    --output_dir "$OUTPUT_PATH" \
    --num_train_epochs 3 \
    --model_max_length 1024 \
    --per_device_train_batch_size 64 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "no" \
    --save_strategy "epoch" \
    --save_steps 100 \
    --save_total_limit 100 \
    --learning_rate 2e-5 \
    --warmup_steps 10 \
    --lr_scheduler_type "cosine" \
    --gradient_checkpointing True \
    --report_to "wandb" \
    --logging_steps 1 \
    --run_name "test" \
    --deepspeed DeepSeek-Coder/finetune/configs/ds_config_zero3.json \
    --bf16 True \
    --seed 42
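For reference, the batch-related flags above translate into the following per-step workload. Gradient accumulation runs micro-batches one after another, so per-GPU activation memory is driven by the per-device batch rather than the accumulated global batch. This sketch assumes the script truncates every sample to --model_max_length, so 1024 tokens per sequence is an upper bound; the function names are illustrative only.

```python
# Values copied from the deepspeed command in this report.
NUM_GPUS = 8            # --include localhost:0,1,2,3,4,5,6,7
PER_DEVICE_BATCH = 64   # --per_device_train_batch_size
GRAD_ACCUM = 4          # --gradient_accumulation_steps
MAX_LEN = 1024          # --model_max_length


def global_batch(num_gpus: int, per_device: int, accum: int) -> int:
    """Sequences contributing to one optimizer step across all ranks."""
    return num_gpus * per_device * accum


def max_tokens_per_forward(per_device: int, max_len: int) -> int:
    """Upper bound on tokens a single GPU handles in one forward pass."""
    return per_device * max_len


print(global_batch(NUM_GPUS, PER_DEVICE_BATCH, GRAD_ACCUM))  # → 2048
print(max_tokens_per_forward(PER_DEVICE_BATCH, MAX_LEN))     # → 65536
```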

The full error traceback is as follows:

Traceback (most recent call last):
  File "/data2/user/llm/DeepSeek-Coder/finetune/finetune_deepseekcoder.py", line 200, in <module>
    train()
  File "/data2/user/llm/DeepSeek-Coder/finetune/finetune_deepseekcoder.py", line 194, in train
    trainer.train()
  File "/home/user/anaconda3/envs/llm/lib/python3.9/site-packages/transformers/trainer.py", line 1624, in train
    return inner_training_loop(
  File "/home/user/anaconda3/envs/llm/lib/python3.9/site-packages/transformers/trainer.py", line 1961, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/user/anaconda3/envs/llm/lib/python3.9/site-packages/transformers/trainer.py", line 2902, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/user/anaconda3/envs/llm/lib/python3.9/site-packages/transformers/trainer.py", line 2925, in compute_loss
    outputs = model(**inputs)
  File "/home/user/anaconda3/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/llm/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1523, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/home/user/anaconda3/envs/llm/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/home/user/anaconda3/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/llm/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 1176, in forward
    outputs = self.model(
  File "/home/user/anaconda3/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/llm/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 993, in forward
    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
  File "/home/user/anaconda3/envs/llm/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 1093, in _update_causal_mask
    causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 GiB. GPU 6 has a total capacity of 79.15 GiB of which 8.24 GiB is free. Including non-PyTorch memory, this process has 70.91 GiB memory in use. Of the allocated memory 38.74 GiB is allocated by PyTorch, and 31.00 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

When this error occurs, GPU monitoring does show that the GPU is nearly full, even though the training job is the only process running on the server. This suggests that the training process itself, while already holding a large amount of VRAM (70.91 GiB on GPU 6, according to the error message), is requesting a further 32 GiB. This did not happen with the previous transformers version (4.37.2).
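A back-of-envelope sketch of how a single attention-mask operation could request 32 GiB. It rests on assumptions not verified against the 4.38.2 source: that _update_causal_mask materializes a dense mask of shape (batch, 1, L, L) in bf16 (2 bytes per element, per the --bf16 flag), and that L may be the model's maximum position count rather than the batch's actual sequence length. The value 16384 below is an assumed maximum; the real value is max_position_embeddings in the model's config.json.

```python
GIB = 1024 ** 3


def dense_mask_bytes(batch: int, q_len: int, kv_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes for a dense (batch, 1, q_len, kv_len) mask; 2 bytes/elem for bf16."""
    return batch * 1 * q_len * kv_len * bytes_per_elem


# Mask sized by the batch's sequence length (capped by --model_max_length 1024).
print(dense_mask_bytes(64, 1024, 1024) / GIB)    # → 0.125
# ASSUMPTION: mask sized by a 16384-position maximum instead of the batch length.
print(dense_mask_bytes(64, 16384, 16384) / GIB)  # → 32.0
```

With the per-device batch of 64, a 1024-token mask needs only 0.125 GiB, while a 16384-token mask needs exactly the 32.00 GiB reported in the traceback. If the assumption holds, the allocation grows with the square of L, independently of the training sequence length.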

Based on repeated experiments, I believe the script and parameter settings can essentially be ruled out as the cause, since the identical configuration trains successfully with 4.37.2. Because I use DeepSpeed for training, I suspect the problem may arise in the parallel training path. In particular, the code changes introduced between 4.37.2 and 4.38.2 seem likely to contribute to the issue; the traceback places the failing 32 GiB allocation in _update_causal_mask (transformers/models/llama/modeling_llama.py).

Expected behavior

First of all, I would like to express my gratitude to the developers and maintainers for their continuous effort in maintaining and improving this repository. Your hard work is truly appreciated by the community.
Regarding the expected behavior: with the same script and arguments, training on transformers 4.38.2 should complete without a torch.cuda.OutOfMemoryError and with GPU memory usage comparable to version 4.37.2, without requiring manual adjustments or a downgrade to an earlier version.
Thank you again for reviewing this.
