
Adding MiFID (Memorization-Informed Frechet Inception Distance) Metric #1578

Closed · Dibz15 opened this issue Mar 2, 2023 · 2 comments · Fixed by #1580

Dibz15 (Contributor) commented Mar 2, 2023

🚀 Feature

The goal of this feature would be to implement the MiFID metric initially proposed in On Training Sample Memorization: Lessons from Benchmarking Generative Modeling with a Large-scale Competition. It is an extension of FID, so it could be implemented similarly.
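
For context, my reading of the paper (and of its reference code) is that MiFID scales FID by a memorization penalty derived from the cosine distance between generated and real features, roughly:

    MiFID = FID / (d_thr + eps)
    d_thr = d  if d < tau, else 1
    d     = mean_j min_i (1 - |cos(f_fake_j, f_real_i)|)

with tau = 0.1 and eps = 10e-15 in the reference code.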

Motivation

I found myself needing to use this metric with PyTorch, but didn't find any existing implementations that use PyTorch tensors. I've used torchmetrics (including for FID) in the past, so I thought this metric could be a good addition to the library. Additionally, since FID is already present, I imagined a MiFID implementation should not be too difficult to write as an extension of the existing FID code.

Pitch

I would like to add MiFID to the library as an extension of FID. I've already implemented the metric myself as a Metric and run some basic tests against the original implementation from the source repo (which uses NumPy for matrix operations, found here).

Although I've already got the metric working (and I'll paste my class code below), I would need some help integrating it into the library and adding tests in order to follow the contribution guidelines and the steps given at the bottom of the Metric page.

Alternatives

The current alternative is just using the NumPy-based solution given in the original source repository.
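
For comparison, the core of that NumPy computation looks roughly like this (paraphrased from the reference repo, so treat the exact names as approximate):

import numpy as np

def cosine_distance(features1: np.ndarray, features2: np.ndarray) -> float:
    # Drop all-zero rows, then normalize each row to unit length.
    f1 = features1[np.sum(features1, axis=1) != 0]
    f2 = features2[np.sum(features2, axis=1) != 0]
    f1 = f1 / np.linalg.norm(f1, axis=1, keepdims=True)
    f2 = f2 / np.linalg.norm(f2, axis=1, keepdims=True)
    # Mean over rows of f1 of the minimum cosine distance to any row of f2.
    d = 1.0 - np.abs(np.matmul(f1, f2.T))
    return float(np.mean(np.min(d, axis=1)))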

Additional context

Here is my current code for this feature:

from copy import deepcopy
from typing import Any, Union

import torch
from torch import Tensor
from torch.nn import Module

from torchmetrics.image.fid import NoTrainInceptionV3, _compute_fid
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE

def _normalize_rows(x: Tensor) -> Tensor:
    """Normalize each row of the matrix ``x`` to unit length.

    Args:
        x: A tensor of shape ``(n, m)``.

    Returns:
        The row-normalized tensor.
    """
    return x / torch.norm(x, dim=1, keepdim=True)

def _distance_thresholding(d: Tensor, eps: float = 0.1) -> Tensor:
    """Keep the distance as a penalty only when it falls below ``eps``; otherwise use 1."""
    return d if d < eps else torch.ones_like(d)

def _compute_cosine_distance(features1: Tensor, features2: Tensor) -> Tensor:
    # Drop all-zero feature rows, mirroring the reference NumPy implementation.
    features1_nozero = features1[torch.sum(features1, dim=1) != 0]
    features2_nozero = features2[torch.sum(features2, dim=1) != 0]
    norm_f1 = _normalize_rows(features1_nozero)
    norm_f2 = _normalize_rows(features2_nozero)

    # For each row of features1, take the minimum cosine distance to any row of
    # features2, then average those minima.
    d = 1.0 - torch.abs(torch.matmul(norm_f1, norm_f2.t()))
    return torch.mean(torch.min(d, dim=1).values)

def _compute_mifid(mu1: Tensor, sigma1: Tensor, features1: Tensor, mu2: Tensor, sigma2: Tensor, features2: Tensor) -> Tensor:
    fid_value = _compute_fid(mu1, sigma1, mu2, sigma2)
    distance = _compute_cosine_distance(features1, features2)
    distance_thr = _distance_thresholding(distance, eps=0.1)
    # The small constant guards against division by zero; 10e-15 matches the reference code.
    return fid_value / (distance_thr + 10e-15)

class MemorizationInformedFrechetInceptionDistance(Metric):
    higher_is_better: bool = False
    is_differentiable: bool = False
    full_state_update: bool = False

    real_features_stacked: Tensor
    fake_features_stacked: Tensor

    def __init__(
        self,
        feature: Union[int, Module] = 2048,
        reset_real_features: bool = True,
        normalize: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        if isinstance(feature, int):
            num_features = feature
            if not _TORCH_FIDELITY_AVAILABLE:
                raise ModuleNotFoundError(
                    "MemorizationInformedFrechetInceptionDistance metric requires that `Torch-fidelity` is installed."
                    " Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`."
                )
            valid_int_input = [64, 192, 768, 2048]
            if feature not in valid_int_input:
                raise ValueError(
                    f"Integer input to argument `feature` must be one of {valid_int_input}, but got {feature}."
                )

            self.inception = NoTrainInceptionV3(name="inception-v3-compat", features_list=[str(feature)])

        elif isinstance(feature, Module):
            self.inception = feature
            dummy_image = torch.randint(0, 255, (1, 3, 299, 299), dtype=torch.uint8, device=self.inception.device)
            num_features = self.inception(dummy_image).shape[-1]
        else:
            raise TypeError("Got unknown input to argument `feature`")

        if not isinstance(reset_real_features, bool):
            raise ValueError("Argument `reset_real_features` expected to be a bool")
        self.reset_real_features = reset_real_features

        if not isinstance(normalize, bool):
            raise ValueError("Argument `normalize` expected to be a bool")
        self.normalize = normalize

        self.add_state("real_features_stacked", torch.zeros((0, num_features)).double(), dist_reduce_fx="cat")
        self.add_state("fake_features_stacked", torch.zeros((0, num_features)).double(), dist_reduce_fx="cat")

    def update(self, imgs: Tensor, real: bool) -> None:  # type: ignore
        """Update the state with extracted features."""
        imgs = (imgs * 255).byte() if self.normalize else imgs
        features = self.inception(imgs)
        self.orig_dtype = features.dtype
        features = features.double()

        # Ensure a 2D (batch, features) shape even for a single sample.
        if features.dim() == 1:
            features = features.unsqueeze(0)

        if real:
            self.real_features_stacked = torch.cat((self.real_features_stacked, features), dim=0)
        else:
            self.fake_features_stacked = torch.cat((self.fake_features_stacked, features), dim=0)

    def compute(self) -> Tensor:
        """Calculate MiFID score based on accumulated extracted features from the two distributions."""
        mean_real = torch.mean(self.real_features_stacked, dim=0)
        mean_fake = torch.mean(self.fake_features_stacked, dim=0)

        cov_real = torch.cov(self.real_features_stacked.t())
        cov_fake = torch.cov(self.fake_features_stacked.t())

        return _compute_mifid(
            mean_real, cov_real, self.real_features_stacked, mean_fake, cov_fake, self.fake_features_stacked
        ).to(self.orig_dtype)

    def to(self, device):
        self.inception = self.inception.to(device)
        return super().to(device)

    def reset(self) -> None:
        if not self.reset_real_features:
            real_features_stacked = deepcopy(self.real_features_stacked)
            super().reset()
            self.real_features_stacked = real_features_stacked
        else:
            super().reset()
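
For reference, here's a minimal usage sketch (random uint8 images purely for illustration; the API mirrors torchmetrics' FID):

import torch

mifid = MemorizationInformedFrechetInceptionDistance(feature=2048)

real_imgs = torch.randint(0, 255, (16, 3, 299, 299), dtype=torch.uint8)
fake_imgs = torch.randint(0, 255, (16, 3, 299, 299), dtype=torch.uint8)

mifid.update(real_imgs, real=True)
mifid.update(fake_imgs, real=False)
score = mifid.compute()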

A few notes about this:

  • I added device=self.inception.device to the dummy_image test from the original FID because that line gave me errors when my Inception network was already on a CUDA device but the Metric wasn't yet (as it was still being initialized). This happened with the original FID class as well. I don't know whether this plays well with the broader torchmetrics library, but it was a workaround for me.
  • The MiFID metric uses the cosine distance between the two sets of feature vectors. I wasn't sure how to decompose this computation so that all the features don't need to be stacked together and stored until compute to calculate the distance. Perhaps there's a way to avoid storing all the stacked features; I'm hoping someone here knows whether that can be done.
  • Additionally, I haven't tested this extensively yet, but I have found that it gives values very close to those from the source NumPy solution. I just found myself implementing this for a project of my own and wanted to contribute it here :).
Dibz15 added the enhancement label Mar 2, 2023

github-actions bot commented Mar 2, 2023

Hi! Thanks for your contribution, great first issue!

stancld (Contributor) commented Mar 2, 2023

Hi @Dibz15, thanks for a nice initiative. I think you can try to open a draft PR and we can help you there to finish it! :]

stancld added this to the future milestone Mar 2, 2023
Dibz15 mentioned this issue Mar 2, 2023