Fix DTensor import compatibility for PyTorch < 2.5 #38836
base: main
Conversation
Got a few nits / general design choice question
src/transformers/modeling_utils.py (outdated)

```python
try:
    from torch.distributed.tensor import DTensor
except ImportError:
    DTensor = None
```
Is the try/except really necessary?
src/transformers/modeling_utils.py (outdated)

```diff
@@ -177,8 +177,12 @@
 _is_ds_init_called = False
 _torch_distributed_available = torch.distributed.is_available()

+DTensor = None
```
Wouldn't it make more sense to create a flag, e.g. `_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")`, and when an `isinstance(..., DTensor)` occurs, you simply add the flag check before it (`_is_dtensor_available and ...`)? Not a fan of having an instance check against `None`.
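A minimal sketch of the flag pattern being suggested here (assuming `is_torch_greater_or_equal` is importable from `transformers.utils`; the `_is_dtensor` helper is hypothetical, added only for illustration):

```python
import torch
from transformers.utils import is_torch_greater_or_equal

_torch_distributed_available = torch.distributed.is_available()
# DTensor moved to a public location in PyTorch 2.5, hence the version gate.
_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")

if _is_dtensor_available:
    from torch.distributed.tensor import DTensor


def _is_dtensor(tensor):
    # `and` short-circuits, so the name DTensor is never evaluated on
    # PyTorch < 2.5, where it was never imported.
    return _is_dtensor_available and isinstance(tensor, DTensor)
```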
Thanks @vasqu for the helpful feedback!

I've updated the PR to replace the try/except pattern with an `_is_dtensor_available` flag as you suggested. Now `DTensor` is only imported when available, and all `isinstance(..., DTensor)` checks are guarded by the flag.

Let me know if there's anything else you'd like to adjust.
Force-pushed from 8bb2a9f to c3c6449.
Thanks! Could you revert the linter adjustments? Not sure what happened but now there is more than there should be 👀
Force-pushed from c3c6449 to 5ae8eb4.
Thanks for the feedback! I'm working on reverting the unintended linter changes now. I initially applied formatting to fix the failing code-quality checks (which were blocking CI), but that introduced more changes than necessary. I'll clean it up and make sure only the relevant DTensor logic remains in the PR.
Thanks for sticking with me :D cc @Cyrilvallez for core maintainer review.
What does this PR do?
This PR fixes a compatibility issue with the `DTensor` import introduced in PyTorch 2.5.

Previously, `DTensor` was imported only when a PyTorch >= 2.5 version check passed. On older PyTorch versions the import was skipped, so the name `DTensor` was never defined, and any later use of `isinstance(some_tensor, DTensor)` raised a `NameError`. This PR addresses that issue.

Changes made:
- Added a fallback `DTensor = None` so the name is always defined.
- Updated downstream code to check `DTensor is not None` before calling `isinstance(..., DTensor)`.
- Ensured safe, version-compatible handling of `DTensor` logic across PyTorch versions (see the sketch after this list).
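A minimal sketch of the fallback pattern described above (illustrative only, not the PR's actual code; `maybe_to_local` is a hypothetical helper):

```python
import torch

_torch_distributed_available = torch.distributed.is_available()

# Fallback so the name DTensor is always defined, even on PyTorch < 2.5.
DTensor = None
if _torch_distributed_available:
    try:
        # torch.distributed.tensor.DTensor is available as of PyTorch 2.5.
        from torch.distributed.tensor import DTensor
    except ImportError:
        pass


def maybe_to_local(tensor):
    # Only touch DTensor-specific APIs when the import actually succeeded;
    # on older PyTorch, DTensor is None and the check falls through safely.
    if DTensor is not None and isinstance(tensor, DTensor):
        return tensor.to_local()
    return tensor
```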
Fixes: N/A (no open issue, but addresses a latent compatibility bug)
Before submitting
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.