Skip to content

Fix import deepspeed crash on PyTorch v2.3 + Python 3.12#7875

Open
tohtana wants to merge 3 commits intodeepspeedai:masterfrom
tohtana:tohtana/fix_import_toarch_compile
Open

Fix import deepspeed crash on PyTorch v2.3 + Python 3.12#7875
tohtana wants to merge 3 commits intodeepspeedai:masterfrom
tohtana:tohtana/fix_import_toarch_compile

Conversation

@tohtana
Copy link
Collaborator

@tohtana tohtana commented Feb 26, 2026

import deepspeed raises RuntimeError: Dynamo is not supported on Python 3.12+ on PyTorch 2.3 + Python 3.12.
The jit_script_compat decorator (introduced in #7835) calls torch.compile() unconditionally on PyTorch >= 2.0, but Dynamo support for Python 3.12 was only added in PyTorch 2.4.
Multiple eager import chains trigger this decorator at import time, crashing before user code runs.

This PR adds a version gate to skip torch.compile on known-unsupported combinations, plus a double fallback (torch.compile → torch.jit.script → identity function) so the decorator won't crash.

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
@tohtana tohtana requested a review from tjruwase as a code owner February 26, 2026 03:13
Copy link

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: e1c74fb604

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant