Skip to content

Commit 2c3629a

Browse files
SkafteNicki authored and Borda committed
Fix backprop in LPIPS (Lightning-AI#2326)
* fixes + tests
* fix doctests

(cherry picked from commit 18b181d)
1 parent 2ae0c8b commit 2c3629a

File tree

4 files changed

+22
-3
lines changed

4 files changed

+22
-3
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2727

2828
### Fixed
2929

30+
- Fixed how backprop is handled in `LPIPS` metric ([#2326](https://github.com/Lightning-AI/torchmetrics/pull/2326))
31+
32+
3033
- Fixed `MultitaskWrapper` not being able to be logged in lightning when using metric collections ([#2349](https://github.com/Lightning-AI/torchmetrics/pull/2349))
3134

3235

src/torchmetrics/functional/image/lpips.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def __init__(
275275
net: Indicate backbone to use, choose between ['alex','vgg','squeeze']
276276
spatial: If input should be spatial averaged
277277
pnet_rand: If backbone should be random or use imagenet pre-trained weights
278-
pnet_tune: If backprop should be enabled
278+
pnet_tune: If backprop should be enabled for both backbone and linear layers
279279
use_dropout: If dropout layers should be added
280280
model_path: Model path to load pretained models from
281281
eval_mode: If network should be in evaluation mode
@@ -327,6 +327,10 @@ def __init__(
327327
if eval_mode:
328328
self.eval()
329329

330+
if not self.pnet_tune:
331+
for param in self.parameters():
332+
param.requires_grad = False
333+
330334
def forward(
331335
self, in0: Tensor, in1: Tensor, retperlayer: bool = False, normalize: bool = False
332336
) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]:
@@ -423,7 +427,7 @@ def learned_perceptual_image_patch_similarity(
423427
>>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
424428
>>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
425429
>>> learned_perceptual_image_patch_similarity(img1, img2, net_type='squeeze')
426-
tensor(0.1008, grad_fn=<DivBackward0>)
430+
tensor(0.1008)
427431
428432
"""
429433
net = _NoTrainLpips(net=net_type).to(device=img1.device, dtype=img1.dtype)

src/torchmetrics/image/lpip.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class LearnedPerceptualImagePatchSimilarity(Metric):
8686
>>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
8787
>>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
8888
>>> lpips(img1, img2)
89-
tensor(0.1046, grad_fn=<SqueezeBackward0>)
89+
tensor(0.1046)
9090
9191
"""
9292

tests/unittests/image/test_lpips.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,15 @@ def test_error_on_wrong_update(inp1, inp2):
133133
metric = LearnedPerceptualImagePatchSimilarity()
134134
with pytest.raises(ValueError, match="Expected both input arguments to be normalized tensors .*"):
135135
metric(inp1, inp2)
136+
137+
138+
def test_check_for_backprop():
139+
"""Check that by default the metric supports propagation of gradients, but does not update its parameters."""
140+
metric = LearnedPerceptualImagePatchSimilarity()
141+
assert not metric.net.lin0.model[1].weight.requires_grad
142+
preds, target = _inputs.img1[0], _inputs.img2[0]
143+
preds.requires_grad = True
144+
loss = metric(preds, target)
145+
assert loss.requires_grad
146+
loss.backward()
147+
assert metric.net.lin0.model[1].weight.grad is None

0 commit comments

Comments (0)