Skip to content

Fix: no_grad with AMP bug #20921

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/lightning/pytorch/plugins/precision/amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def clip_gradients(
super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)

def autocast_context_manager(self) -> torch.autocast:
return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half))
return torch.autocast(
self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half), cache_enabled=False
)

@override
@contextmanager
Expand Down
18 changes: 18 additions & 0 deletions tests/tests_pytorch/plugins/precision/test_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from unittest.mock import Mock

import pytest
import torch
from torch import nn
from torch.optim import Optimizer

from lightning.pytorch.plugins import MixedPrecision
Expand Down Expand Up @@ -51,3 +53,19 @@ def test_optimizer_amp_scaling_support_in_step_method():

with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
precision.clip_gradients(optimizer, clip_val=1.0)


def test_amp_with_no_grad():
"""Test that asserts using `no_grad` context wrapper with a persistent AMP context wrapper does not break gradient
tracking."""
layer = nn.Linear(2, 1)
x = torch.randn(1, 2)
amp = MixedPrecision(precision="bf16-mixed", device="cpu")

with amp.autocast_context_manager():
with torch.no_grad():
_ = layer(x)

loss = layer(x).mean()
loss.backward()
assert loss.grad_fn is not None
Loading