Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix device and dtype for LPIPS functional metric #2234

Merged
merged 9 commits into from
Nov 25, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed numerical stability issue in `UniversalImageQualityIndex` metric ([#2222](https://github.com/Lightning-AI/torchmetrics/pull/2222))


- Fix device and dtype for `LearnedPerceptualImagePatchSimilarity` functional metric ([#2234](https://github.com/Lightning-AI/torchmetrics/pull/2234))


## [1.2.0] - 2023-09-22

### Added
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/image/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,6 @@ def learned_perceptual_image_patch_similarity(
tensor(0.1008, grad_fn=<DivBackward0>)
"""
net = _NoTrainLpips(net=net_type)
net = _NoTrainLpips(net=net_type).to(device=img1.device, dtype=img1.dtype)
loss, total = _lpips_update(img1, img2, net, normalize)
return _lpips_compute(loss.sum(), total, reduction)
11 changes: 11 additions & 0 deletions tests/unittests/image/test_lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch
from lpips import LPIPS as LPIPS_reference # noqa: N811
from torch import Tensor
from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.utilities.imports import _LPIPS_AVAILABLE

Expand Down Expand Up @@ -68,6 +69,16 @@ def test_lpips(self, net_type, ddp):
metric_args={"net_type": net_type},
)

def test_lpips_functional(self):
"""Test functional implementation of metric."""
self.run_functional_metric_test(
preds=_inputs.img1,
target=_inputs.img2,
metric_functional=learned_perceptual_image_patch_similarity,
reference_metric=partial(_compare_fn, net_type="alex"),
metric_args={"net_type": "alex"},
)

def test_lpips_differentiability(self):
"""Test the differentiability of the metric, according to its `is_differentiable` attribute."""
self.run_differentiability_test(
Expand Down
Loading