-
Notifications
You must be signed in to change notification settings - Fork 4.5k
Fix: UnboundLocalError for variable 'dim' about issue #7449
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
loadams
merged 3 commits into
deepspeedai:master
from
weeknan:fix-UnboundLocalError-bug
Jul 28, 2025
Merged
Fix: UnboundLocalError for variable 'dim' about issue #7449
loadams
merged 3 commits into
deepspeedai:master
from
weeknan:fix-UnboundLocalError-bug
Jul 28, 2025
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Signed-off-by: weeknan <zhounan0431@163.com>
0fbabd3
to
c18caf0
Compare
hwchen2017
approved these changes
Jul 24, 2025
loadams
approved these changes
Jul 28, 2025
lpnpcs
pushed a commit
to lpnpcs/DeepSpeed
that referenced
this pull request
Jul 30, 2025
## Fix `UnboundLocalError` in `ZeroLinear.backward()` when training only bias parameters, as mentioned in deepspeedai#7435 This PR addresses an issue in the `ZeroLinear.backward()` method, where the local variable `dim` could be referenced before assignment. This happens specifically when: - Only the bias parameters are set to `requires_grad=True`, and - The training setup uses **ZeRO Stage 3**, **AMP**, and **gradient checkpointing**. ### Problem When only the bias requires gradients, the condition for setting `dim = grad_output.dim()` is skipped, but the value of `dim` is still used later in the computation, leading to: ### Fix Move the assignment `dim = grad_output.dim()` to occur unconditionally, so that `dim` is always defined before being used in any branch of the gradient computation logic. ### Impact This makes the backward pass more robust across different training setups. Signed-off-by: weeknan <zhounan0431@163.com> Co-authored-by: Olatunji Ruwase <tjruwase@gmail.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
qimcis
pushed a commit
to qimcis/DeepSpeed
that referenced
this pull request
Jul 31, 2025
## Fix `UnboundLocalError` in `ZeroLinear.backward()` when training only bias parameters, as mentioned in deepspeedai#7435 This PR addresses an issue in the `ZeroLinear.backward()` method, where the local variable `dim` could be referenced before assignment. This happens specifically when: - Only the bias parameters are set to `requires_grad=True`, and - The training setup uses **ZeRO Stage 3**, **AMP**, and **gradient checkpointing**. ### Problem When only the bias requires gradients, the condition for setting `dim = grad_output.dim()` is skipped, but the value of `dim` is still used later in the computation, leading to: ### Fix Move the assignment `dim = grad_output.dim()` to occur unconditionally, so that `dim` is always defined before being used in any branch of the gradient computation logic. ### Impact This makes the backward pass more robust across different training setups. Signed-off-by: weeknan <zhounan0431@163.com> Co-authored-by: Olatunji Ruwase <tjruwase@gmail.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Signed-off-by: qimcis <chixie.mcisaac@gmail.com>
LYMDLUT
pushed a commit
to LYMDLUT/DeepSpeed
that referenced
this pull request
Aug 20, 2025
## Fix `UnboundLocalError` in `ZeroLinear.backward()` when training only bias parameters, as mentioned in deepspeedai#7435 This PR addresses an issue in the `ZeroLinear.backward()` method, where the local variable `dim` could be referenced before assignment. This happens specifically when: - Only the bias parameters are set to `requires_grad=True`, and - The training setup uses **ZeRO Stage 3**, **AMP**, and **gradient checkpointing**. ### Problem When only the bias requires gradients, the condition for setting `dim = grad_output.dim()` is skipped, but the value of `dim` is still used later in the computation, leading to: ### Fix Move the assignment `dim = grad_output.dim()` to occur unconditionally, so that `dim` is always defined before being used in any branch of the gradient computation logic. ### Impact This makes the backward pass more robust across different training setups. Signed-off-by: weeknan <zhounan0431@163.com> Co-authored-by: Olatunji Ruwase <tjruwase@gmail.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Signed-off-by: lym <letusgo126@126.com>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Fix
UnboundLocalError
inZeroLinear.backward()
when training only bias parameters, as mentioned in #7435This PR addresses an issue in the
ZeroLinear.backward()
method, where the local variabledim
could be referenced before assignment. This happens specifically when:requires_grad=True
, andProblem
When only the bias requires gradients, the condition for setting
dim = grad_output.dim()
is skipped, but the value ofdim
is still used later in the computation, leading to:Fix
Move the assignment
dim = grad_output.dim()
to occur unconditionally, so thatdim
is always defined before being used in any branch of the gradient computation logic.Impact
This makes the backward pass more robust across different training setups.