Describe the bug
When training only bias parameters (i.e., `requires_grad=True` only for bias) with ZeRO Stage 3 + AMP + gradient checkpointing, `loss.backward()` fails with:

```
File "anaconda3/envs/llava/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py", line 97, in backward
    if dim > 2:
UnboundLocalError: local variable 'dim' referenced before assignment
```
The issue is located in `ZeroLinear.backward()`, where `dim = grad_output.dim()` is only defined inside `if ctx.needs_input_grad[1]:`. However, `dim` is later used in the bias gradient computation block without being conditionally guarded.
To Reproduce
Steps to reproduce the behavior:
- Use a model with gradient checkpointing enabled (e.g., LLaVA).
- Set all parameters except `bias` to `requires_grad=False`.
- Use a DeepSpeed config with:
  - `zero_optimization.stage=3`
  - `fp16.enabled=true` (or bf16)
- Launch training with `Trainer.train()` or `loss.backward()` (a minimal sketch follows this list).
- See the error during the backward pass.
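A minimal end-to-end sketch of the setup (the tiny model, hidden size, and config values below are placeholders rather than the actual LLaVA training script; whether ZeRO-3's memory-efficient linear override is active can also depend on the DeepSpeed version and config):

```python
# Minimal sketch of the failing setup (assumed model/config, not the LLaVA script).
import torch
import torch.nn as nn
import deepspeed
from torch.utils.checkpoint import checkpoint

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {"stage": 3},
    "fp16": {"enabled": True},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
}

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        # Gradient checkpointing, as in the original setup.
        return checkpoint(self.proj, x, use_reentrant=False)

# Build the model under zero.Init so ZeRO-3 can patch F.linear with its
# memory-efficient autograd Function (the code path in the traceback).
with deepspeed.zero.Init(config_dict_or_path=ds_config):
    model = TinyModel()

# Freeze everything except the bias: inside the patched backward,
# needs_input_grad is then False for the weight and True for the bias.
for name, p in model.named_parameters():
    p.requires_grad = name.endswith("bias")

engine, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=[p for p in model.parameters() if p.requires_grad],
    config=ds_config,
)

x = torch.randn(2, 16, device=engine.device, dtype=torch.float16, requires_grad=True)
loss = engine(x).sum()
engine.backward(loss)  # UnboundLocalError: local variable 'dim' referenced before assignment
```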
Expected behavior
Backward should proceed normally when only `bias` is trainable. The function should not reference an undefined local variable.
ds_report output
Screenshots
N/A
System info (please complete the following information):
- OS: Ubuntu 20.04
- GPU count and types: 2x A800 80GB
- Python version: 3.10.16
- DeepSpeed version: 0.12.6
- PyTorch version: 2.1.2
Launcher context
Launched via HuggingFace `Trainer` with DeepSpeed integration (`deepspeed_config.json`)
Docker context
No
Additional context
A simple fix is to move `dim = grad_output.dim()` above all `if` blocks in `ZeroLinear.backward()` to ensure it is always defined:
```python
@staticmethod
def backward(ctx, grad_output):
    input, weight, bias = ctx.saved_tensors
    grad_input = grad_weight = grad_bias = None
    dim = grad_output.dim()  # <- moved here
    ...
```
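With the call hoisted, `dim` is bound before the bias-gradient branch reads it, regardless of which entries of `ctx.needs_input_grad` are set, so bias-only training (as in the sketch under "Describe the bug") completes the backward pass without the `UnboundLocalError`.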