
Fix: UnboundLocalError for variable 'dim' about issue #7449

Merged
loadams merged 3 commits into deepspeedai:master from weeknan:fix-UnboundLocalError-bug on Jul 28, 2025

Conversation

@weeknan weeknan commented Jul 24, 2025

Fix `UnboundLocalError` in `ZeroLinear.backward()` when training only bias parameters, as mentioned in #7435

This PR addresses an issue in the `ZeroLinear.backward()` method, where the local variable `dim` could be referenced before assignment. This happens specifically when:

- Only the bias parameters are set to `requires_grad=True`, and
- The training setup uses **ZeRO Stage 3**, **AMP**, and **gradient checkpointing**.

### Problem

When only the bias requires gradients, the branch that sets `dim = grad_output.dim()` is skipped, but `dim` is still read later in the bias-gradient computation, leading to an `UnboundLocalError` for the local variable `dim`, as the sketch below reproduces:
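
A minimal, self-contained sketch of the failure pattern (the class name `BuggyLinearSketch` and the shapes are illustrative, not DeepSpeed's actual `ZeroLinear` code, but the control flow mirrors the bug described above):

```python
import torch

class BuggyLinearSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight)
        return input.matmul(weight.t()) + bias

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.matmul(weight)
        if ctx.needs_input_grad[1]:
            dim = grad_output.dim()  # `dim` is assigned only on this branch
            if dim > 2:
                grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(
                    input.reshape(-1, input.shape[-1]))
            else:
                grad_weight = grad_output.t().matmul(input)
        if ctx.needs_input_grad[2]:
            # Bias-only training skips the weight branch above, so this
            # read of `dim` raises UnboundLocalError.
            grad_bias = grad_output.sum([i for i in range(dim - 1)]) if dim > 2 \
                else grad_output.sum(0)
        return grad_input, grad_weight, grad_bias

# Trigger: only the bias requires gradients.
x, w = torch.randn(4, 3, 8), torch.randn(16, 8)   # frozen input and weight
b = torch.zeros(16, requires_grad=True)           # only the bias trains
BuggyLinearSketch.apply(x, w, b).sum().backward() # UnboundLocalError: 'dim'
```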

### Fix

Move the assignment `dim = grad_output.dim()` so it occurs unconditionally, ensuring `dim` is always defined before being used in any branch of the gradient computation logic (see the sketch below).
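
The fix, applied to the same illustrative sketch (mirroring the one-line move this PR makes, not the exact DeepSpeed diff): hoist the assignment so every branch can rely on `dim` being defined.

```python
import torch

class FixedLinearSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight)
        return input.matmul(weight.t()) + bias

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        dim = grad_output.dim()  # moved up: always defined from here on
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.matmul(weight)
        if ctx.needs_input_grad[1]:
            if dim > 2:
                grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(
                    input.reshape(-1, input.shape[-1]))
            else:
                grad_weight = grad_output.t().matmul(input)
        if ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum([i for i in range(dim - 1)]) if dim > 2 \
                else grad_output.sum(0)
        return grad_input, grad_weight, grad_bias

# Bias-only training now works:
x, w = torch.randn(4, 3, 8), torch.randn(16, 8)
b = torch.zeros(16, requires_grad=True)
FixedLinearSketch.apply(x, w, b).sum().backward()
print(b.grad.shape)  # torch.Size([16])
```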

### Impact

This makes the backward pass robust across training setups in which only a subset of parameters (e.g. only the bias) requires gradients, without changing behavior for existing setups.

@weeknan weeknan requested review from tjruwase and tohtana as code owners July 24, 2025 12:28
@weeknan weeknan changed the title from "Fix: UnboundLocalError for variable 'dim' about issue #7435" to "Fix: UnboundLocalError for variable 'dim' about issue" on Jul 24, 2025
@weeknan weeknan force-pushed the fix-UnboundLocalError-bug branch from 0fbabd3 to c18caf0 on July 24, 2025 12:33
@loadams loadams merged commit 56fed13 into deepspeedai:master Jul 28, 2025
9 checks passed
@weeknan weeknan deleted the fix-UnboundLocalError-bug branch July 29, 2025 14:25
lpnpcs pushed a commit to lpnpcs/DeepSpeed that referenced this pull request Jul 30, 2025
qimcis pushed a commit to qimcis/DeepSpeed that referenced this pull request Jul 31, 2025
LYMDLUT pushed a commit to LYMDLUT/DeepSpeed that referenced this pull request Aug 20, 2025