Commit
Merge branch 'master' into bugfix/metric_collection_and_aggregation
Showing 8 changed files with 287 additions and 1 deletion.
@@ -0,0 +1,23 @@
.. customcarditem::
    :header: Visual Information Fidelity (VIF)
    :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
    :tags: Image

.. include:: ../links.rst

#################################
Visual Information Fidelity (VIF)
#################################

Module Interface
________________

.. autoclass:: torchmetrics.image.VisualInformationFidelity
    :noindex:
    :exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.image.visual_information_fidelity
    :noindex:
@@ -0,0 +1,114 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor
from torch.nn.functional import conv2d

from torchmetrics.utilities.distributed import reduce


def _filter(win_size: float, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor:
    # This code is inspired by
    # https://github.com/andrewekhalel/sewar/blob/ac76e7bc75732fde40bb0d3908f4b6863400cc27/sewar/utils.py#L45
    # https://github.com/photosynthesis-team/piq/blob/01e16b7d8c76bc8765fb6a69560d806148b8046a/piq/functional/filters.py#L38
    # Both links do the same, but the second one is cleaner
    coords = torch.arange(win_size, dtype=dtype, device=device) - (win_size - 1) / 2
    g = coords**2
    g = torch.exp(-(g.unsqueeze(0) + g.unsqueeze(1)) / (2.0 * sigma**2))
    g /= torch.sum(g)
    return g


def _vif_per_channel(preds: Tensor, target: Tensor, sigma_n_sq: float) -> Tensor:
    dtype = preds.dtype
    device = preds.device

    preds = preds.unsqueeze(1)  # Add channel dimension
    target = target.unsqueeze(1)
    # Constant for numerical stability
    eps = torch.tensor(1e-10, dtype=dtype, device=device)

    sigma_n_sq = torch.tensor(sigma_n_sq, dtype=dtype, device=device)

    preds_vif, target_vif = torch.zeros(1, dtype=dtype, device=device), torch.zeros(1, dtype=dtype, device=device)
    for scale in range(4):
        n = 2.0 ** (4 - scale) + 1
        kernel = _filter(n, n / 5, dtype=dtype, device=device)[None, None, :]

        if scale > 0:
            target = conv2d(target, kernel)[:, :, ::2, ::2]
            preds = conv2d(preds, kernel)[:, :, ::2, ::2]

        mu_target = conv2d(target, kernel)
        mu_preds = conv2d(preds, kernel)
        mu_target_sq = mu_target**2
        mu_preds_sq = mu_preds**2
        mu_target_preds = mu_target * mu_preds

        sigma_target_sq = torch.clamp(conv2d(target**2, kernel) - mu_target_sq, min=0.0)
        sigma_preds_sq = torch.clamp(conv2d(preds**2, kernel) - mu_preds_sq, min=0.0)
        sigma_target_preds = conv2d(target * preds, kernel) - mu_target_preds

        g = sigma_target_preds / (sigma_target_sq + eps)
        sigma_v_sq = sigma_preds_sq - g * sigma_target_preds

        mask = sigma_target_sq < eps
        g[mask] = 0
        sigma_v_sq[mask] = sigma_preds_sq[mask]
        sigma_target_sq[mask] = 0

        mask = sigma_preds_sq < eps
        g[mask] = 0
        sigma_v_sq[mask] = 0

        mask = g < 0
        sigma_v_sq[mask] = sigma_preds_sq[mask]
        g[mask] = 0
        sigma_v_sq = torch.clamp(sigma_v_sq, min=eps)

        preds_vif_scale = torch.log10(1.0 + (g**2.0) * sigma_target_sq / (sigma_v_sq + sigma_n_sq))
        preds_vif = preds_vif + torch.sum(preds_vif_scale, dim=[1, 2, 3])
        target_vif = target_vif + torch.sum(torch.log10(1.0 + sigma_target_sq / sigma_n_sq), dim=[1, 2, 3])
    return preds_vif / target_vif


def visual_information_fidelity(preds: Tensor, target: Tensor, sigma_n_sq: float = 2.0) -> Tensor:
    """Compute Pixel Based Visual Information Fidelity (VIF_).

    Args:
        preds: predicted images of shape ``(N,C,H,W)``. ``(H, W)`` has to be at least ``(41, 41)``.
        target: ground truth images of shape ``(N,C,H,W)``. ``(H, W)`` has to be at least ``(41, 41)``.
        sigma_n_sq: variance of the visual noise

    Return:
        Tensor with vif-p score

    Raises:
        ValueError:
            If the spatial size ``(H, W)`` of ``preds`` or ``target`` is smaller than ``(41, 41)``
    """
    # This code is inspired by
    # https://github.com/photosynthesis-team/piq/blob/01e16b7d8c76bc8765fb6a69560d806148b8046a/piq/vif.py and
    # https://github.com/andrewekhalel/sewar/blob/ac76e7bc75732fde40bb0d3908f4b6863400cc27/sewar/full_ref.py#L357.

    if preds.size(-1) < 41 or preds.size(-2) < 41:
        raise ValueError(f"Invalid size of preds. Expected at least 41x41, but got {preds.size(-1)}x{preds.size(-2)}!")

    if target.size(-1) < 41 or target.size(-2) < 41:
        raise ValueError(
            f"Invalid size of target. Expected at least 41x41, but got {target.size(-1)}x{target.size(-2)}!"
        )

    per_channel = [_vif_per_channel(preds[:, i, :, :], target[:, i, :, :], sigma_n_sq) for i in range(preds.size(1))]
    return reduce(torch.cat(per_channel), "elementwise_mean")
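For reference: in _vif_per_channel above, each of the four Gaussian scales adds log10(1 + g**2 * sigma_target_sq / (sigma_v_sq + sigma_n_sq)) to the numerator and log10(1 + sigma_target_sq / sigma_n_sq) to the denominator, and visual_information_fidelity then averages the per-channel ratios. The following is a minimal usage sketch of the functional interface; it is not part of the commit and simply assumes the import path exercised by the tests further down.

```python
import torch
from torchmetrics.functional.image.vif import visual_information_fidelity

# H and W must be at least 41 so that all four filter-and-downsample scales
# still produce non-empty feature maps.
preds = torch.rand(4, 3, 64, 64)
target = torch.rand(4, 3, 64, 64)

score = visual_information_fidelity(preds, target, sigma_n_sq=2.0)
print(score)  # scalar tensor: vif-p averaged over channels and batch
```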
@@ -0,0 +1,82 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import torch
from torch import Tensor, tensor

from torchmetrics.functional.image.vif import _vif_per_channel
from torchmetrics.metric import Metric


class VisualInformationFidelity(Metric):
    """Compute Pixel Based Visual Information Fidelity (VIF_).

    As input to ``forward`` and ``update`` the metric accepts the following input

    - ``preds`` (:class:`~torch.Tensor`): Predictions from model of shape ``(N,C,H,W)`` with H,W ≥ 41
    - ``target`` (:class:`~torch.Tensor`): Ground truth values of shape ``(N,C,H,W)`` with H,W ≥ 41

    As output of ``forward`` and ``compute`` the metric returns the following output

    - ``vif-p`` (:class:`~torch.Tensor`): Tensor with vif-p score

    Args:
        sigma_n_sq: variance of the visual noise
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Example:
        >>> import torch
        >>> _ = torch.manual_seed(42)
        >>> from torchmetrics.image import VisualInformationFidelity
        >>> preds = torch.randn([32, 3, 41, 41])
        >>> target = torch.randn([32, 3, 41, 41])
        >>> vif = VisualInformationFidelity()
        >>> vif(preds, target)
        tensor(0.0032)
    """

    is_differentiable = True
    higher_is_better = True
    full_state_update = False

    vif_score: Tensor
    total: Tensor

    def __init__(self, sigma_n_sq: float = 2.0, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        if not isinstance(sigma_n_sq, float) and not isinstance(sigma_n_sq, int):
            raise ValueError(f"Argument `sigma_n_sq` is expected to be a positive float or int, but got {sigma_n_sq}")

        if sigma_n_sq < 0:
            raise ValueError(f"Argument `sigma_n_sq` is expected to be a positive float or int, but got {sigma_n_sq}")

        self.add_state("vif_score", default=tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=tensor(0.0), dist_reduce_fx="sum")
        self.sigma_n_sq = sigma_n_sq

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update state with predictions and targets."""
        channels = preds.size(1)
        vif_per_channel = [
            _vif_per_channel(preds[:, i, :, :], target[:, i, :, :], self.sigma_n_sq) for i in range(channels)
        ]
        vif_per_channel = torch.mean(torch.stack(vif_per_channel), 0) if channels > 1 else torch.cat(vif_per_channel)
        self.vif_score += torch.sum(vif_per_channel)
        self.total += preds.shape[0]

    def compute(self) -> Tensor:
        """Compute vif-p over state."""
        return self.vif_score / self.total
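Since update adds per-image scores into vif_score and counts images in total, compute returns the mean score over everything seen since the last reset. A minimal accumulation sketch, illustrative only and not part of the commit:

```python
import torch
from torchmetrics.image import VisualInformationFidelity

vif = VisualInformationFidelity(sigma_n_sq=2.0)
for _ in range(3):  # e.g. three validation batches
    preds = torch.rand(8, 3, 41, 41)
    target = torch.rand(8, 3, 41, 41)
    vif.update(preds, target)

print(vif.compute())  # mean vif-p over the 24 images accumulated above
vif.reset()  # clear the summed states before the next epoch
```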
@@ -0,0 +1,62 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import numpy as np
import pytest
import torch
from sewar.full_ref import vifp
from torchmetrics.functional.image.vif import visual_information_fidelity
from torchmetrics.image.vif import VisualInformationFidelity

from unittests import BATCH_SIZE, NUM_BATCHES
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

seed_all(42)

Input = namedtuple("Input", ["preds", "target"])
_inputs = [
    Input(
        preds=torch.randint(0, 255, size=(NUM_BATCHES, BATCH_SIZE, channels, 41, 41), dtype=torch.float),
        target=torch.randint(0, 255, size=(NUM_BATCHES, BATCH_SIZE, channels, 41, 41), dtype=torch.float),
    )
    for channels in [1, 3]
]


def _sewar_vif(preds, target, sigma_nsq=2):
    preds = torch.movedim(preds, 1, -1)
    target = torch.movedim(target, 1, -1)
    preds = preds.cpu().numpy()
    target = target.cpu().numpy()
    vif = [vifp(GT=target[batch], P=preds[batch], sigma_nsq=sigma_nsq) for batch in range(preds.shape[0])]
    return np.mean(vif)


@pytest.mark.parametrize("preds, target", [(inputs.preds, inputs.target) for inputs in _inputs])
class TestVIF(MetricTester):
    """Test class for `VisualInformationFidelity` metric."""

    atol = 1e-7

    @pytest.mark.parametrize("ddp", [True, False])
    def test_vif(self, preds, target, ddp):
        """Test class implementation of metric."""
        self.run_class_metric_test(ddp, preds, target, VisualInformationFidelity, _sewar_vif)

    def test_vif_functional(self, preds, target):
        """Test functional implementation of metric."""
        self.run_functional_metric_test(preds, target, visual_information_fidelity, _sewar_vif)
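The MetricTester harness above checks the new implementation against sewar's vifp with atol = 1e-7. The same cross-check can be reproduced standalone; this sketch is illustrative rather than part of the commit, assumes sewar is installed, and mirrors _sewar_vif for a single batch:

```python
import torch
from sewar.full_ref import vifp
from torchmetrics.functional.image.vif import visual_information_fidelity

preds = torch.randint(0, 255, (2, 3, 41, 41), dtype=torch.float)
target = torch.randint(0, 255, (2, 3, 41, 41), dtype=torch.float)

tm_score = visual_information_fidelity(preds, target)

# sewar expects HWC numpy arrays and scores one image at a time
ref_scores = [
    vifp(GT=target[i].movedim(0, -1).numpy(), P=preds[i].movedim(0, -1).numpy(), sigma_nsq=2)
    for i in range(preds.shape[0])
]
print(tm_score.item(), sum(ref_scores) / len(ref_scores))  # the two means should closely agree
```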