Skip to content

Commit

Permalink
Merge branch 'bugfix/error_cosine_sim' of https://github.com/Lightnin…
Browse files Browse the repository at this point in the history
…g-AI/torchmetrics into bugfix/error_cosine_sim
  • Loading branch information
SkafteNicki committed Nov 27, 2023
2 parents f4e96c4 + 68802b6 commit bd2b36e
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 16 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Changed minimum supported Pytorch version from 1.8 to 1.10 ([#2145](https://github.com/Lightning-AI/torchmetrics/pull/2145))


- Use arange and repeat for deterministic bincount ([#2184](https://github.com/Lightning-AI/torchmetrics/pull/2184))


### Deprecated

- Deprecated `metric._update_called` ([#2141](https://github.com/Lightning-AI/torchmetrics/pull/2141))
Expand All @@ -46,7 +49,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Removed

-
- Removed `lpips` third-party package as dependency of `LearnedPerceptualImagePatchSimilarity` metric ([#2230](https://github.com/Lightning-AI/torchmetrics/pull/2230))


### Fixed
Expand Down
1 change: 0 additions & 1 deletion requirements/image.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,3 @@
scipy >1.0.0, <1.11.0
torchvision >=0.8, <0.17.0
torch-fidelity <=0.4.0 # bumping to allow install version from master, now used in testing
lpips <=0.1.4
1 change: 1 addition & 0 deletions requirements/image_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ pytorch-msssim ==1.0.0
sewar >=0.4.4, <=0.4.6
numpy <1.25.0
torch-fidelity @ git+https://github.com/toshas/torch-fidelity@master
lpips <=0.1.4
6 changes: 3 additions & 3 deletions src/torchmetrics/image/lpip.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class LearnedPerceptualImagePatchSimilarity(Metric):
Both input image patches are expected to have shape ``(N, 3, H, W)``. The minimum size of `H, W` depends on the
chosen backbone (see `net_type` arg).
.. note:: using this metrics requires you to have ``lpips`` package installed. Either install
as ``pip install torchmetrics[image]`` or ``pip install lpips``
.. note:: using this metrics requires you to have ``torchvision`` package installed. Either install as
``pip install torchmetrics[image]`` or ``pip install torchvision``.
.. note:: this metric is not scriptable when using ``torch<1.8``. Please update your pytorch installation
if this is a issue.
Expand All @@ -71,7 +71,7 @@ class LearnedPerceptualImagePatchSimilarity(Metric):
Raises:
ModuleNotFoundError:
If ``lpips`` package is not installed
If ``torchvision`` package is not installed
ValueError:
If ``net_type`` is not one of ``"vgg"``, ``"alex"`` or ``"squeeze"``
ValueError:
Expand Down
18 changes: 8 additions & 10 deletions src/torchmetrics/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,10 @@ def _squeeze_if_scalar(data: Any) -> Any:
def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
"""Implement custom bincount.
PyTorch currently does not support ``torch.bincount`` for:
- deterministic mode on GPU.
- MPS devices
This implementation fallback to a for-loop counting occurrences in that case.
PyTorch currently does not support ``torch.bincount`` when running in deterministic mode on GPU or when running
MPS devices or when running on XLA device. This implementation therefore falls back to using a combination of
`torch.arange` and `torch.eq` in these scenarios. A small performance hit can expected and higher memory consumption
as `[batch_size, mincount]` tensor needs to be initialized compared to native ``torch.bincount``.
Args:
x: tensor to count
Expand All @@ -191,11 +189,11 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
"""
if minlength is None:
minlength = len(torch.unique(x))

if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or _TORCH_GREATER_EQUAL_1_12 and x.is_mps:
output = torch.zeros(minlength, device=x.device, dtype=torch.long)
for i in range(minlength):
output[i] = (x == i).sum()
return output
mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1)
return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0)

return torch.bincount(x, minlength=minlength)


Expand Down
5 changes: 4 additions & 1 deletion tests/unittests/image/test_lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torch import Tensor
from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.utilities.imports import _LPIPS_AVAILABLE
from torchmetrics.utilities.imports import _LPIPS_AVAILABLE, _TORCHVISION_AVAILABLE

from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester
Expand Down Expand Up @@ -48,6 +48,7 @@ def _compare_fn(img1: Tensor, img2: Tensor, net_type: str, normalize: bool = Fal
return res.sum()


@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="test requires that torchvision is installed")
@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
class TestLPIPS(MetricTester):
"""Test class for `LearnedPerceptualImagePatchSimilarity` metric."""
Expand Down Expand Up @@ -105,6 +106,7 @@ def test_normalize_arg(normalize):
assert res == res2


@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="test requires that torchvision is installed")
@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
def test_error_on_wrong_init():
"""Test class raises the expected errors."""
Expand All @@ -115,6 +117,7 @@ def test_error_on_wrong_init():
LearnedPerceptualImagePatchSimilarity(net_type="squeeze", reduction=None)


@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="test requires that torchvision is installed")
@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
@pytest.mark.parametrize(
("inp1", "inp2"),
Expand Down

0 comments on commit bd2b36e

Please sign in to comment.