Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MiFID implementation #1580

Merged
merged 29 commits into from
Jul 3, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
47c4a9d
Add initial MiFID implementation
Dibz15 Mar 2, 2023
3e1bd85
Merge branch 'master' into mifid
Borda Jun 28, 2023
8bcf514
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 28, 2023
3d19efe
docs
SkafteNicki Jun 29, 2023
26f1125
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 29, 2023
73aadb2
implementation
SkafteNicki Jun 29, 2023
d134cb6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 29, 2023
c850ae2
tests
SkafteNicki Jun 29, 2023
6bbe432
Merge branch 'master' into mifid
SkafteNicki Jun 29, 2023
6f00517
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 29, 2023
1048730
update fid
SkafteNicki Jun 29, 2023
f591440
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 29, 2023
8979c6c
plotting
SkafteNicki Jun 29, 2023
71631b1
Merge branch 'mifid' of https://github.com/Dibz15/metrics into mifid
SkafteNicki Jun 29, 2023
7a89bdd
changelog
SkafteNicki Jun 29, 2023
1cbbfae
fix ruff
SkafteNicki Jun 29, 2023
1c99bcb
skip on missing import
SkafteNicki Jun 29, 2023
b6096c7
Merge branch 'master' into mifid
mergify[bot] Jun 29, 2023
222bded
fix docs formatting
SkafteNicki Jun 29, 2023
cd49dd1
fix typing issues
SkafteNicki Jun 29, 2023
6ba16d9
Merge branch 'mifid' of https://github.com/Dibz15/metrics into mifid
SkafteNicki Jun 29, 2023
c7aed65
missing doc link
SkafteNicki Jun 29, 2023
fa58a06
fix tests
SkafteNicki Jun 29, 2023
9af0832
Merge branch 'master' into mifid
justusschock Jun 30, 2023
e8a1df6
up lower torch version
SkafteNicki Jul 1, 2023
881bc87
Merge branch 'mifid' of https://github.com/Dibz15/metrics into mifid
SkafteNicki Jul 1, 2023
76e8568
more skip
SkafteNicki Jul 1, 2023
1db5ad1
Merge branch 'master' into mifid
Borda Jul 3, 2023
70197ce
Merge branch 'master' into mifid
mergify[bot] Jul 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `RelativeSquaredError` metric to regression package ([#1765](https://github.com/Lightning-AI/torchmetrics/pull/1765))


- Added `MemorizationInformedFrechetInceptionDistance` metric to image package ([#1580](https://github.com/Lightning-AI/torchmetrics/pull/1580))


### Changed

- Changed `permutation_invariant_training` to allow using a `'permutation-wise'` metric function ([#1794](https://github.com/Lightning-AI/metrics/pull/1794))
Expand Down
17 changes: 17 additions & 0 deletions docs/source/image/mifid.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
.. customcarditem::
:header: Memorization-Informed Frechet Inception Distance (MiFID)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

########################################################
Memorization-Informed Frechet Inception Distance (MiFID)
########################################################

Module Interface
________________

.. autoclass:: torchmetrics.image.mifid.MemorizationInformedFrechetInceptionDistance
:noindex:
:exclude-members: update, compute
2 changes: 2 additions & 0 deletions src/torchmetrics/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from torchmetrics.image.d_lambda import SpectralDistortionIndex
from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis
from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance
from torchmetrics.image.psnr import PeakSignalNoiseRatio
from torchmetrics.image.psnrb import PeakSignalNoiseRatioWithBlockedEffect
from torchmetrics.image.rase import RelativeAverageSpectralError
Expand All @@ -32,6 +33,7 @@
"RootMeanSquaredErrorUsingSlidingWindow",
"SpectralAngleMapper",
"MultiScaleStructuralSimilarityIndexMeasure",
"MemorizationInformedFrechetInceptionDistance",
"StructuralSimilarityIndexMeasure",
"UniversalImageQualityIndex",
"TotalVariation",
Expand Down
280 changes: 280 additions & 0 deletions src/torchmetrics/image/mifid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor
from torch.nn import Module

from torchmetrics.image.fid import NoTrainInceptionV3, _compute_fid
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_FIDELITY_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["MemorizationInformedFrechetInceptionDistance.plot"]

__doctest_requires__ = {
("MemorizationInformedFrechetInceptionDistance", "MemorizationInformedFrechetInceptionDistance.plot"): [
"torch_fidelity"
]
}


def _compute_cosine_distance(features1: Tensor, features2: Tensor, cosine_distance_eps: float = 0.1) -> Tensor:
"""Compute the cosine distance between two sets of features."""
features1_nozero = features1[torch.sum(features1, dim=1) != 0]
features2_nozero = features2[torch.sum(features2, dim=1) != 0]

# normalize
norm_f1 = features1_nozero / torch.norm(features1_nozero, dim=1, keepdim=True)
norm_f2 = features2_nozero / torch.norm(features2_nozero, dim=1, keepdim=True)

d = 1.0 - torch.abs(torch.matmul(norm_f1, norm_f2.t()))
mean_min_d = torch.mean(d.min(dim=1).values)
return mean_min_d if mean_min_d < cosine_distance_eps else torch.ones_like(mean_min_d)


def _mifid_compute(
mu1: Tensor,
sigma1: Tensor,
features1: Tensor,
mu2: Tensor,
sigma2: Tensor,
features2: Tensor,
cosine_distance_eps: float = 0.1,
) -> Tensor:
"""Compute MIFID score given two sets of features and their statistics."""
fid_value = _compute_fid(mu1, sigma1, mu2, sigma2)
distance = _compute_cosine_distance(features1, features2, cosine_distance_eps)
return fid_value / (distance + 10e-15)


class MemorizationInformedFrechetInceptionDistance(Metric):
r"""Calculate Memorization-Informed Frechet Inception Distance (MIFID_).

MIFID is a improved variation of the Frechet Inception Distance (FID_) that penalizes memorization of the training
set by the generator. It is calculated as

.. math::
MIFID = \frac{FID(F_{real}, F_{fake})}{M(F_{real}, F_{fake})}

where :math:`FID` is the normal FID score and :math:`M` is the memorization penalty. The memorization penalty
essentially corresponds to the average minimum cosine distance between the features of the real and fake
distribution.

Using the default feature extraction (Inception v3 using the original weights from `fid ref2`_), the input is
expected to be mini-batches of 3-channel RGB images of shape ``(3 x H x W)``. If argument ``normalize``
is ``True`` images are expected to be dtype ``float`` and have values in the ``[0, 1]`` range, else if
``normalize`` is set to ``False`` images are expected to have dtype ``uint8`` and take values in the ``[0, 255]``
range. All images will be resized to 299 x 299 which is the size of the original training data. The boolian
flag ``real`` determines if the images should update the statistics of the real distribution or the
fake distribution.

.. note:: using this metrics requires you to have ``scipy`` install. Either install as ``pip install
torchmetrics[image]`` or ``pip install scipy``

.. note:: using this metric with the default feature extractor requires that ``torch-fidelity``
is installed. Either install as ``pip install torchmetrics[image]`` or
``pip install torch-fidelity``

As input to ``forward`` and ``update`` the metric accepts the following input

- ``imgs`` (:class:`~torch.Tensor`): tensor with images feed to the feature extractor with
- ``real`` (:class:`~bool`): bool indicating if ``imgs`` belong to the real or the fake distribution

As output of `forward` and `compute` the metric returns the following output

- ``fid`` (:class:`~torch.Tensor`): float scalar tensor with mean FID value over samples

Args:
feature:
Either an integer or ``nn.Module``:
- an integer will indicate the inceptionv3 feature layer to choose. Can be one of the following:
64, 192, 768, 2048
- an ``nn.Module`` for using a custom feature extractor. Expects that its forward method returns
an ``(N,d)`` matrix where ``N`` is the batch size and ``d`` is the feature size.
reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not
change, the features can be cached them to avoid recomputing them which is costly. Set this to ``False`` if
your dataset does not change.
cosine_distance_eps: Epsilon value for the cosine distance. If the cosine distance is larger than this value
it is set to 1 and thus ignored in the MIFID calculation.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ValueError:
If ``feature`` is set to an ``int`` and ``torch-fidelity`` is not installed
ValueError:
If ``feature`` is set to an ``int`` not in [64, 192, 768, 2048]
TypeError:
If ``feature`` is not an ``str``, ``int`` or ``torch.nn.Module``
ValueError:
If ``reset_real_features`` is not an ``bool``

Example::
>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance
>>> mifid = MemorizationInformedFrechetInceptionDistance(feature=64)
>>> # generate two slightly overlapping image intensity distributions
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> mifid.update(imgs_dist1, real=True)
>>> mifid.update(imgs_dist2, real=False)
>>> mifid.compute()
tensor(2959.7734)
"""
higher_is_better: bool = False
is_differentiable: bool = False
full_state_update: bool = False

real_features_stacked: Tensor
fake_features_stacked: Tensor

def __init__(
self,
feature: Union[int, Module] = 2048,
reset_real_features: bool = True,
normalize: bool = False,
cosine_distance_eps: float = 0.1,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)

if isinstance(feature, int):
if not _TORCH_FIDELITY_AVAILABLE:
raise ModuleNotFoundError(
"MemorizationInformedFrechetInceptionDistance metric requires that `Torch-fidelity` is installed."
" Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`."
)
valid_int_input = [64, 192, 768, 2048]
if feature not in valid_int_input:
raise ValueError(
f"Integer input to argument `feature` must be one of {valid_int_input}, but got {feature}."
)

self.inception = NoTrainInceptionV3(name="inception-v3-compat", features_list=[str(feature)])

elif isinstance(feature, Module):
self.inception = feature
else:
raise TypeError("Got unknown input to argument `feature`")

if not isinstance(reset_real_features, bool):
raise ValueError("Argument `reset_real_features` expected to be a bool")
self.reset_real_features = reset_real_features

if not isinstance(normalize, bool):
raise ValueError("Argument `normalize` expected to be a bool")
self.normalize = normalize

if not (isinstance(cosine_distance_eps, float) and 1 >= cosine_distance_eps > 0):
raise ValueError("Argument `cosine_distance_eps` expected to be a float greater than 0 and less than 1")
self.cosine_distance_eps = cosine_distance_eps

# states for extracted features
self.add_state("real_features", [], dist_reduce_fx=None)
self.add_state("fake_features", [], dist_reduce_fx=None)

def update(self, imgs: Tensor, real: bool) -> None:
"""Update the state with extracted features."""
imgs = (imgs * 255).byte() if self.normalize else imgs
features = self.inception(imgs)
self.orig_dtype = features.dtype
features = features.double()

if real:
self.real_features.append(features)
else:
self.fake_features.append(features)

def compute(self) -> Tensor:
"""Calculate FID score based on accumulated extracted features from the two distributions."""
real_features = dim_zero_cat(self.real_features)
fake_features = dim_zero_cat(self.fake_features)

mean_real, mean_fake = torch.mean(real_features, dim=0), torch.mean(fake_features, dim=0)
cov_real, cov_fake = torch.cov(real_features.t()), torch.cov(fake_features.t())

return _mifid_compute(
mean_real,
cov_real,
real_features,
mean_fake,
cov_fake,
fake_features,
cosine_distance_eps=self.cosine_distance_eps,
).to(self.orig_dtype)

def reset(self) -> None:
"""Reset metric states."""
if not self.reset_real_features:
# remove temporarily to avoid resetting
value = self._defaults.pop("real_features")
super().reset()
self._defaults["real_features"] = value
else:
super().reset()

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> metric = MemorizationInformedFrechetInceptionDistance(feature=64)
>>> metric.update(imgs_dist1, real=True)
>>> metric.update(imgs_dist2, real=False)
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance
>>> imgs_dist1 = lambda: torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = lambda: torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> metric = MemorizationInformedFrechetInceptionDistance(feature=64)
>>> values = [ ]
>>> for _ in range(3):
... metric.update(imgs_dist1(), real=True)
... metric.update(imgs_dist2(), real=False)
... values.append(metric.compute())
... metric.reset()
>>> fig_, ax_ = metric.plot(values)


"""
return self._plot(val, ax)
27 changes: 14 additions & 13 deletions tests/unittests/image/test_fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import pickle
from contextlib import nullcontext as does_not_raise
from functools import partial

import pytest
import torch
Expand Down Expand Up @@ -119,8 +120,7 @@ def test_compare_fid(tmpdir, equal_size, feature=768):

metric = FrechetInceptionDistance(feature=feature).cuda()

n = 100
m = 100 if equal_size else 90
n, m = 100, 100 if equal_size else 90

# Generate some synthetic data
torch.manual_seed(42)
Expand Down Expand Up @@ -179,20 +179,21 @@ def test_reset_real_features_arg(reset_real_features):


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_9, reason="test requires torch>=1.9")
def test_normalize_arg_true():
@pytest.mark.parametrize("normalize", [True, False])
def test_normalize_arg(normalize):
"""Test that normalize argument works as expected."""
img = torch.rand(2, 3, 299, 299)
metric = FrechetInceptionDistance(normalize=True)
with does_not_raise():
metric.update(img, real=True)

metric = FrechetInceptionDistance(normalize=normalize)

context = (
partial(
pytest.raises, expected_exception=ValueError, match="Expecting image as torch.Tensor with dtype=torch.uint8"
)
if not normalize
else does_not_raise
)

@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_9, reason="test requires torch>=1.9")
def test_normalize_arg_false():
"""Test that normalize argument works as expected."""
img = torch.rand(2, 3, 299, 299)
metric = FrechetInceptionDistance(normalize=False)
with pytest.raises(ValueError, match="Expecting image as torch.Tensor with dtype=torch.uint8"):
with context():
metric.update(img, real=True)


Expand Down
Loading