…vice_type=...) Signed-off-by: Ma, Guokai <guokai.ma@intel.com>
The amp() method on each accelerator returned a device-specific torch.<device>.amp module, but since PyTorch 2.4 the unified torch.amp API (torch.amp.custom_fwd, torch.amp.custom_bwd, torch.amp.autocast) accepts a device_type argument and works across all backends. The previous commit already migrated the two call sites; this commit removes the now-unused amp() abstract method and all 8 accelerator implementations, plus simplifies the custom_fwd/custom_bwd setup in zero/linear.py by dropping the pre-2.4 fallback path. Signed-off-by: Ma, Guokai <guokai.ma@intel.com>
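The unified API these commits migrate to can be sketched as follows (a minimal sketch, assuming PyTorch >= 2.4 is installed; the `device_type` selection here is illustrative, whereas DeepSpeed derives it from the active accelerator):

```python
# Sketch of the unified torch.amp API (assumes torch >= 2.4). The same call
# works on any backend because the device is passed as a string, instead of
# importing a per-device module like torch.cuda.amp.
result = None
try:
    import torch

    device_type = "cuda" if torch.cuda.is_available() else "cpu"

    # Legacy (device-specific): with torch.cuda.amp.autocast(): ...
    # Unified:                  torch.amp.autocast(device_type=...)
    with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):
        a = torch.ones(4, 4)
        b = a @ a  # matmul runs under autocast on whichever backend is active
    result = tuple(b.shape)
except (ImportError, AttributeError):
    # torch not installed (or too old); nothing to demonstrate.
    result = (4, 4)
```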
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: a5b04f024f
```diff
-except (ImportError, AttributeError) as exp:
-    autocast_custom_fwd = noop_decorator
-    autocast_custom_bwd = noop_decorator
+autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=get_accelerator().device_name())
```
Restore compatibility guard for torch.amp decorators
Binding autocast_custom_fwd/autocast_custom_bwd directly to torch.amp.custom_fwd/custom_bwd at import time removes the previous compatibility path, so module import will crash with an AttributeError in environments where these symbols are missing (the old code explicitly handled that case). When this happens, deepspeed.runtime.zero.linear cannot be imported, and ZeRO-3 linear execution fails before runtime. This is a real compatibility regression, since the repo still runs older torch variants in CI (for example, .github/workflows/nv-transformers-v100.yml installs torch 2.0.1).
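A guarded setup in the spirit of the removed fallback might look like this (a sketch only; DeepSpeed derives the device type from get_accelerator().device_name(), while this example hard-codes "cpu" for illustration):

```python
import functools


def noop_decorator(func):
    # Pass-through used when the torch.amp decorators are unavailable.
    return func


try:
    import torch

    # torch.amp.custom_fwd / custom_bwd accept device_type since PyTorch 2.4.
    # "cpu" is an illustrative placeholder for the accelerator's device name.
    autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type="cpu")
    autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type="cpu")
except (ImportError, AttributeError):
    # Older torch (e.g. the 2.0.1 used in CI) lacks these symbols; fall back
    # so importing the module does not raise at import time.
    autocast_custom_fwd = noop_decorator
    autocast_custom_bwd = noop_decorator

# Either way, callers can apply the decorators unconditionally.
```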
PyTorch now provides torch.amp (https://docs.pytorch.org/docs/stable/amp.html) as the recommended AMP API, replacing the torch.<device_type>.amp modules used by DeepSpeed's abstract accelerator amp() method. Some PyTorch backends, such as XPU, do not provide the legacy torch.xpu.amp module. This PR replaces get_accelerator().amp() with torch.amp, which is the recommended way of using AMP.

Related issues and PRs
#7876
#7877
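Applied to a custom autograd function, the replacement described above looks roughly like this (a sketch assuming torch >= 2.4; LinearFunction and the hard-coded shapes are illustrative, not DeepSpeed's actual zero/linear.py code):

```python
import functools

grad_shape = None
try:
    import torch

    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    # Unified decorators, parameterized by device_type as in the PR.
    custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device_type)
    custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device_type)

    class LinearFunction(torch.autograd.Function):

        @staticmethod
        @custom_fwd
        def forward(ctx, input, weight):
            ctx.save_for_backward(input, weight)
            return input @ weight.t()

        @staticmethod
        @custom_bwd
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            return grad_output @ weight, grad_output.t() @ input

    x = torch.randn(3, 5, requires_grad=True)
    w = torch.randn(2, 5, requires_grad=True)
    LinearFunction.apply(x, w).sum().backward()
    grad_shape = tuple(x.grad.shape)
except (ImportError, AttributeError):
    # torch missing or older than 2.4: this is exactly where the guarded
    # fallback discussed in the review would apply.
    grad_shape = (3, 5)
```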