Skip to content

Commit

Permalink
Merge branch 'master' into bugfix/metric_collection_and_aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
mergify[bot] authored Jul 11, 2023
2 parents 208f6e1 + 93ac13f commit f70b28d
Show file tree
Hide file tree
Showing 8 changed files with 287 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Added `VisualInformationFidelity` to image package ([#1830](https://github.com/Lightning-AI/torchmetrics/pull/1830))


### Changed
Expand Down
23 changes: 23 additions & 0 deletions docs/source/image/visual_information_fidelity.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
.. customcarditem::
:header: Visual Information Fidelity (VIF)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

#################################
Visual Information Fidelity (VIF)
#################################

Module Interface
________________

.. autoclass:: torchmetrics.image.VisualInformationFidelity
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.image.visual_information_fidelity
:noindex:
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,4 @@
.. _Equal opportunity: https://proceedings.neurips.cc/paper/2016/hash/9d2682367c3935defcb1f9e247a97c0d-Abstract.html
.. _Seamless Scene Segmentation paper: https://arxiv.org/abs/1905.01220
.. _Fleiss kappa: https://en.wikipedia.org/wiki/Fleiss%27_kappa
.. _VIF: https://ieeexplore.ieee.org/abstract/document/1576816
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)
from torchmetrics.functional.image.tv import total_variation
from torchmetrics.functional.image.uqi import universal_image_quality_index
from torchmetrics.functional.image.vif import visual_information_fidelity

__all__ = [
"spectral_distortion_index",
Expand All @@ -39,4 +40,5 @@
"structural_similarity_index_measure",
"total_variation",
"universal_image_quality_index",
"visual_information_fidelity",
]
114 changes: 114 additions & 0 deletions src/torchmetrics/functional/image/vif.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor
from torch.nn.functional import conv2d

from torchmetrics.utilities.distributed import reduce


def _filter(win_size: float, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor:
# This code is inspired by
# https://github.com/andrewekhalel/sewar/blob/ac76e7bc75732fde40bb0d3908f4b6863400cc27/sewar/utils.py#L45
# https://github.com/photosynthesis-team/piq/blob/01e16b7d8c76bc8765fb6a69560d806148b8046a/piq/functional/filters.py#L38
# Both links do the same, but the second one is cleaner
coords = torch.arange(win_size, dtype=dtype, device=device) - (win_size - 1) / 2
g = coords**2
g = torch.exp(-(g.unsqueeze(0) + g.unsqueeze(1)) / (2.0 * sigma**2))
g /= torch.sum(g)
return g


def _vif_per_channel(preds: Tensor, target: Tensor, sigma_n_sq: float) -> Tensor:
dtype = preds.dtype
device = preds.device

preds = preds.unsqueeze(1) # Add channel dimension
target = target.unsqueeze(1)
# Constant for numerical stability
eps = torch.tensor(1e-10, dtype=dtype, device=device)

sigma_n_sq = torch.tensor(sigma_n_sq, dtype=dtype, device=device)

preds_vif, target_vif = torch.zeros(1, dtype=dtype, device=device), torch.zeros(1, dtype=dtype, device=device)
for scale in range(4):
n = 2.0 ** (4 - scale) + 1
kernel = _filter(n, n / 5, dtype=dtype, device=device)[None, None, :]

if scale > 0:
target = conv2d(target, kernel)[:, :, ::2, ::2]
preds = conv2d(preds, kernel)[:, :, ::2, ::2]

mu_target = conv2d(target, kernel)
mu_preds = conv2d(preds, kernel)
mu_target_sq = mu_target**2
mu_preds_sq = mu_preds**2
mu_target_preds = mu_target * mu_preds

sigma_target_sq = torch.clamp(conv2d(target**2, kernel) - mu_target_sq, min=0.0)
sigma_preds_sq = torch.clamp(conv2d(preds**2, kernel) - mu_preds_sq, min=0.0)
sigma_target_preds = conv2d(target * preds, kernel) - mu_target_preds

g = sigma_target_preds / (sigma_target_sq + eps)
sigma_v_sq = sigma_preds_sq - g * sigma_target_preds

mask = sigma_target_sq < eps
g[mask] = 0
sigma_v_sq[mask] = sigma_preds_sq[mask]
sigma_target_sq[mask] = 0

mask = sigma_preds_sq < eps
g[mask] = 0
sigma_v_sq[mask] = 0

mask = g < 0
sigma_v_sq[mask] = sigma_preds_sq[mask]
g[mask] = 0
sigma_v_sq = torch.clamp(sigma_v_sq, min=eps)

preds_vif_scale = torch.log10(1.0 + (g**2.0) * sigma_target_sq / (sigma_v_sq + sigma_n_sq))
preds_vif = preds_vif + torch.sum(preds_vif_scale, dim=[1, 2, 3])
target_vif = target_vif + torch.sum(torch.log10(1.0 + sigma_target_sq / sigma_n_sq), dim=[1, 2, 3])
return preds_vif / target_vif


def visual_information_fidelity(preds: Tensor, target: Tensor, sigma_n_sq: float = 2.0) -> Tensor:
"""Compute Pixel Based Visual Information Fidelity (VIF_).
Args:
preds: predicted images of shape ``(N,C,H,W)``. ``(H, W)`` has to be at least ``(41, 41)``.
target: ground truth images of shape ``(N,C,H,W)``. ``(H, W)`` has to be at least ``(41, 41)``
sigma_n_sq: variance of the visual noise
Return:
Tensor with vif-p score
Raises:
ValueError:
If ``data_range`` is neither a ``tuple`` nor a ``float``
"""
# This code is inspired by
# https://github.com/photosynthesis-team/piq/blob/01e16b7d8c76bc8765fb6a69560d806148b8046a/piq/vif.py and
# https://github.com/andrewekhalel/sewar/blob/ac76e7bc75732fde40bb0d3908f4b6863400cc27/sewar/full_ref.py#L357.

if preds.size(-1) < 41 or preds.size(-2) < 41:
raise ValueError(f"Invalid size of preds. Expected at least 41x41, but got {preds.size(-1)}x{preds.size(-2)}!")

if target.size(-1) < 41 or target.size(-2) < 41:
raise ValueError(
f"Invalid size of target. Expected at least 41x41, but got {target.size(-1)}x{target.size(-2)}!"
)

per_channel = [_vif_per_channel(preds[:, i, :, :], target[:, i, :, :], sigma_n_sq) for i in range(preds.size(1))]
return reduce(torch.cat(per_channel), "elementwise_mean")
2 changes: 2 additions & 0 deletions src/torchmetrics/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure, StructuralSimilarityIndexMeasure
from torchmetrics.image.tv import TotalVariation
from torchmetrics.image.uqi import UniversalImageQualityIndex
from torchmetrics.image.vif import VisualInformationFidelity
from torchmetrics.utilities.imports import _LPIPS_AVAILABLE, _TORCH_FIDELITY_AVAILABLE

__all__ = [
Expand All @@ -36,6 +37,7 @@
"MemorizationInformedFrechetInceptionDistance",
"StructuralSimilarityIndexMeasure",
"UniversalImageQualityIndex",
"VisualInformationFidelity",
"TotalVariation",
]

Expand Down
82 changes: 82 additions & 0 deletions src/torchmetrics/image/vif.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import torch
from torch import Tensor, tensor

from torchmetrics.functional.image.vif import _vif_per_channel
from torchmetrics.metric import Metric


class VisualInformationFidelity(Metric):
"""Compute Pixel Based Visual Information Fidelity (VIF_).
As input to ``forward`` and ``update`` the metric accepts the following input
- ``preds`` (:class:`~torch.Tensor`): Predictions from model of shape ``(N,C,H,W)`` with H,W ≥ 41
- ``target`` (:class:`~torch.Tensor`): Ground truth values of shape ``(N,C,H,W)`` with H,W ≥ 41
As output of `forward` and `compute` the metric returns the following output
- ``vif-p`` (:class:`~torch.Tensor`): Tensor with vif-p score
Args:
sigma_n_sq: variance of the visual noise
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import VisualInformationFidelity
>>> preds = torch.randn([32, 3, 41, 41])
>>> target = torch.randn([32, 3, 41, 41])
>>> vif = VisualInformationFidelity()
>>> vif(preds, target)
tensor(0.0032)
"""

is_differentiable = True
higher_is_better = True
full_state_update = False

vif_score: Tensor
total: Tensor

def __init__(self, sigma_n_sq: float = 2.0, **kwargs: Any) -> None:
super().__init__(**kwargs)

if not isinstance(sigma_n_sq, float) and not isinstance(sigma_n_sq, int):
raise ValueError(f"Argument `sigma_n_sq` is expected to be a positive float or int, but got {sigma_n_sq}")

if sigma_n_sq < 0:
raise ValueError(f"Argument `sigma_n_sq` is expected to be a positive float or int, but got {sigma_n_sq}")

self.add_state("vif_score", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0.0), dist_reduce_fx="sum")
self.sigma_n_sq = sigma_n_sq

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
channels = preds.size(1)
vif_per_channel = [
_vif_per_channel(preds[:, i, :, :], target[:, i, :, :], self.sigma_n_sq) for i in range(channels)
]
vif_per_channel = torch.mean(torch.stack(vif_per_channel), 0) if channels > 1 else torch.cat(vif_per_channel)
self.vif_score += torch.sum(vif_per_channel)
self.total += preds.shape[0]

def compute(self) -> Tensor:
"""Compute vif-p over state."""
return self.vif_score / self.total
62 changes: 62 additions & 0 deletions tests/unittests/image/test_vif.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import numpy as np
import pytest
import torch
from sewar.full_ref import vifp
from torchmetrics.functional.image.vif import visual_information_fidelity
from torchmetrics.image.vif import VisualInformationFidelity

from unittests import BATCH_SIZE, NUM_BATCHES
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

seed_all(42)

Input = namedtuple("Input", ["preds", "target"])
_inputs = [
Input(
preds=torch.randint(0, 255, size=(NUM_BATCHES, BATCH_SIZE, channels, 41, 41), dtype=torch.float),
target=torch.randint(0, 255, size=(NUM_BATCHES, BATCH_SIZE, channels, 41, 41), dtype=torch.float),
)
for channels in [1, 3]
]


def _sewar_vif(preds, target, sigma_nsq=2):
preds = torch.movedim(preds, 1, -1)
target = torch.movedim(target, 1, -1)
preds = preds.cpu().numpy()
target = target.cpu().numpy()
vif = [vifp(GT=target[batch], P=preds[batch], sigma_nsq=sigma_nsq) for batch in range(preds.shape[0])]
return np.mean(vif)


@pytest.mark.parametrize("preds, target", [(inputs.preds, inputs.target) for inputs in _inputs])
class TestVIF(MetricTester):
"""Test class for `VisualInformationFidelity` metric."""

atol = 1e-7

@pytest.mark.parametrize("ddp", [True, False])
def test_vif(self, preds, target, ddp):
"""Test class implementation of metric."""
self.run_class_metric_test(ddp, preds, target, VisualInformationFidelity, _sewar_vif)

def test_vif_functional(self, preds, target):
"""Test functional implementation of metric."""
self.run_functional_metric_test(preds, target, visual_information_fidelity, _sewar_vif)

0 comments on commit f70b28d

Please sign in to comment.