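🚀 Feature
The goal of this feature would be to implement the MiFID metric initially proposed in On Training Sample Memorization: Lessons from Benchmarking Generative Modeling with a Large-scale Competition. It is an extension of FID, so it could be implemented similarly.
Motivation
I found myself needing to use this metric with PyTorch, but didn't find any existing implementations that use PyTorch tensors. I've used torchmetrics (including for FID) in the past, so I thought that this metric could be a good addition to the library. Additionally, since FID is already present, I imagined that an MiFID implementation would not be too difficult to write as an extension of the existing FID code.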
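For reference, the score that `_compute_mifid` computes in the code below can be summarized as follows. This is read off the code rather than quoted from the paper, so the paper's notation may differ. F_1 and F_2 are the rows of `features1` and `features2` (the real and generated features in `compute()`) after dropping rows whose entries sum to zero. ε = 0.1 is the threshold in `_distance_thresholding`, and δ = `10e-15` (that is, 10^-14) guards against division by zero.

```latex
d(F_1, F_2) = \frac{1}{|F_1|} \sum_{f_1 \in F_1} \min_{f_2 \in F_2}
  \left( 1 - \frac{\lvert \langle f_1, f_2 \rangle \rvert}{\lVert f_1 \rVert \, \lVert f_2 \rVert} \right)

d_{\mathrm{thr}} =
  \begin{cases}
    d(F_1, F_2) & \text{if } d(F_1, F_2) < \varepsilon \\
    1 & \text{otherwise}
  \end{cases}

\mathrm{MiFID} = \frac{\mathrm{FID}}{d_{\mathrm{thr}} + \delta}
```

Whenever the memorization distance is at least ε, d_thr = 1 and MiFID equals FID up to δ. Below ε, FID is inflated by a factor of roughly 1/d.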
Pitch
I would like to add MiFID to the library as an extension of FID. I've already implemented the metric myself as a Metric, and ran some basic tests against the original implementation from the source repo (which uses NumPy for matrix operations, found here).
Although I've already got the metric working (and I'll paste my class code below), I would need some help integrating it into the library and adding tests in order to follow the contribution guidelines and the steps given at the bottom of the Metric page.
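One way to turn the comparison against the NumPy implementation into a unit test is a parity check like the following pytest-style sketch. It is self-contained. `cosine_distance_np` is a NumPy reference written here for illustration, not copied from the source repository. `cosine_distance_torch` restates `_compute_cosine_distance` from the code below so the snippet runs without torchmetrics. Both names are hypothetical; in an actual PR the test would import the library function instead.

```python
import numpy as np
import torch


def cosine_distance_np(features1: np.ndarray, features2: np.ndarray) -> float:
    """Hypothetical NumPy reference for the mean nearest-neighbour cosine distance."""
    # Drop rows whose entries sum to zero, then L2-normalize each row.
    f1 = features1[np.sum(features1, axis=1) != 0]
    f2 = features2[np.sum(features2, axis=1) != 0]
    f1 = f1 / np.linalg.norm(f1, axis=1, keepdims=True)
    f2 = f2 / np.linalg.norm(f2, axis=1, keepdims=True)
    d = 1.0 - np.abs(f1 @ f2.T)
    # For each row of features1, take the closest row of features2, then average.
    return float(np.mean(np.min(d, axis=1)))


def cosine_distance_torch(features1: torch.Tensor, features2: torch.Tensor) -> torch.Tensor:
    """Same computation in PyTorch, mirroring `_compute_cosine_distance` from the issue."""
    f1 = features1[torch.sum(features1, dim=1) != 0]
    f2 = features2[torch.sum(features2, dim=1) != 0]
    f1 = f1 / torch.norm(f1, dim=1, keepdim=True)
    f2 = f2 / torch.norm(f2, dim=1, keepdim=True)
    d = 1.0 - torch.abs(f1 @ f2.t())
    return torch.mean(torch.min(d, dim=1).values)


def test_cosine_distance_matches_numpy() -> None:
    # Random float64 features; the two implementations should agree to numerical precision.
    rng = np.random.default_rng(0)
    real = rng.standard_normal((64, 16))
    fake = rng.standard_normal((48, 16))
    expected = cosine_distance_np(real, fake)
    actual = cosine_distance_torch(torch.from_numpy(real), torch.from_numpy(fake)).item()
    np.testing.assert_allclose(actual, expected, rtol=1e-10)
```

A similar test comparing the full metric against the source implementation would need fixed images and the Inception weights, so it would probably be slower and better suited to the library's existing FID test setup.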
Alternatives
The current alternative is to use the NumPy-based solution given in the original source repository.
Additional context
Here is my current code for this feature:
```python
from copy import deepcopy
from typing import Any, Union

import torch
from torch import Tensor
from torch.nn import Module
from torchmetrics.image.fid import NoTrainInceptionV3, _compute_fid
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE


def _normalize_rows(x: Tensor) -> Tensor:
    """Normalize each row of the matrix ``x`` to unit length.

    Args:
        x: A PyTorch tensor of shape ``(n, m)``.

    Returns:
        The row-normalized PyTorch tensor.
    """
    return x / torch.norm(x, dim=1, keepdim=True)


def _distance_thresholding(d: Tensor, eps: float = 0.1) -> Tensor:
    # Keep the memorization distance only if it is below eps; otherwise use 1 (no penalty).
    return d if d < eps else torch.ones_like(d)


def _compute_cosine_distance(features1: Tensor, features2: Tensor) -> Tensor:
    # Drop rows whose entries sum to zero (this removes all-zero rows, which cannot be normalized).
    features1_nozero = features1[torch.sum(features1, dim=1) != 0]
    features2_nozero = features2[torch.sum(features2, dim=1) != 0]
    norm_f1 = _normalize_rows(features1_nozero)
    norm_f2 = _normalize_rows(features2_nozero)
    # Pairwise cosine distance, shape (n1, n2).
    d = 1.0 - torch.abs(torch.matmul(norm_f1, norm_f2.t()))
    # For each row of features1, take the closest row of features2, then average.
    mean_min_d = torch.mean(torch.min(d, dim=1).values)
    return mean_min_d


def _compute_mifid(
    mu1: Tensor, sigma1: Tensor, features1: Tensor,
    mu2: Tensor, sigma2: Tensor, features2: Tensor,
) -> Tensor:
    fid_value = _compute_fid(mu1, sigma1, mu2, sigma2)
    distance = _compute_cosine_distance(features1, features2)
    distance_thr = _distance_thresholding(distance, eps=0.1)
    mifid = fid_value / (distance_thr + 10e-15)
    return mifid


class MemorizationInformedFrechetInceptionDistance(Metric):
    higher_is_better: bool = False
    is_differentiable: bool = False
    full_state_update: bool = False

    real_features_stacked: Tensor
    fake_features_stacked: Tensor

    def __init__(
        self,
        feature: Union[int, Module] = 2048,
        reset_real_features: bool = True,
        normalize: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        if isinstance(feature, int):
            num_features = feature
            if not _TORCH_FIDELITY_AVAILABLE:
                raise ModuleNotFoundError(
                    "MemorizationInformedFrechetInceptionDistance metric requires that `Torch-fidelity` is installed."
                    " Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`."
                )
            valid_int_input = [64, 192, 768, 2048]
            if feature not in valid_int_input:
                raise ValueError(
                    f"Integer input to argument `feature` must be one of {valid_int_input}, but got {feature}."
                )
            self.inception = NoTrainInceptionV3(name="inception-v3-compat", features_list=[str(feature)])
        elif isinstance(feature, Module):
            self.inception = feature
            # Workaround: create the dummy image on the network's device.
            # Note: a plain nn.Module has no `.device` attribute; this line assumes the module exposes one.
            dummy_image = torch.randint(
                0, 255, (1, 3, 299, 299), dtype=torch.uint8, device=self.inception.device
            )
            num_features = self.inception(dummy_image).shape[-1]
        else:
            raise TypeError("Got unknown input to argument `feature`")

        if not isinstance(reset_real_features, bool):
            raise ValueError("Argument `reset_real_features` expected to be a bool")
        self.reset_real_features = reset_real_features

        if not isinstance(normalize, bool):
            raise ValueError("Argument `normalize` expected to be a bool")
        self.normalize = normalize

        # Every extracted feature vector is kept, because the cosine distance needs all of them.
        self.add_state("real_features_stacked", torch.zeros((0, num_features)).double(), dist_reduce_fx="cat")
        self.add_state("fake_features_stacked", torch.zeros((0, num_features)).double(), dist_reduce_fx="cat")

    def update(self, imgs: Tensor, real: bool) -> None:  # type: ignore
        """Update the state with extracted features."""
        imgs = (imgs * 255).byte() if self.normalize else imgs
        features = self.inception(imgs)
        self.orig_dtype = features.dtype
        features = features.double()
        if features.dim() == 1:
            features = features.unsqueeze(0)
        if real:
            self.real_features_stacked = torch.cat((self.real_features_stacked, features), dim=0)
        else:
            self.fake_features_stacked = torch.cat((self.fake_features_stacked, features), dim=0)

    def compute(self) -> Tensor:
        """Calculate MiFID score based on accumulated extracted features from the two distributions."""
        mean_real = torch.mean(self.real_features_stacked, dim=0)
        mean_fake = torch.mean(self.fake_features_stacked, dim=0)
        # torch.cov treats rows as variables, so transpose to (num_features, num_samples).
        cov_real = torch.cov(self.real_features_stacked.t())
        cov_fake = torch.cov(self.fake_features_stacked.t())
        return _compute_mifid(
            mean_real, cov_real, self.real_features_stacked,
            mean_fake, cov_fake, self.fake_features_stacked,
        ).to(self.orig_dtype)

    def to(self, *args: Any, **kwargs: Any) -> "MemorizationInformedFrechetInceptionDistance":
        # Accept the same arguments as nn.Module.to (device, dtype, ...).
        self.inception = self.inception.to(*args, **kwargs)
        return super().to(*args, **kwargs)

    def reset(self) -> None:
        if not self.reset_real_features:
            # Preserve the real features across resets.
            real_features_stacked = deepcopy(self.real_features_stacked)
            super().reset()
            self.real_features_stacked = real_features_stacked
        else:
            super().reset()
```
A couple of notes about this:
I added `device=self.inception.device` to the `dummy_image` check from the original FID because this line gave me errors when my Inception network was already on a CUDA device but the Metric wasn't yet (it is still being initialized). This happened with the original FID class as well. I don't know whether this works with the broader torchmetrics library, but it was a workaround for me.
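A plain `torch.nn.Module` has no `.device` attribute (it exists only on some subclasses, such as PyTorch Lightning's `LightningModule`), so `self.inception.device` raises `AttributeError` for a generic feature extractor. A more general variant of the workaround might look the device up from the module's own tensors. In the sketch below, `module_device` is a hypothetical helper name, and the tiny network and 8×8 input are placeholders for the Inception network and 299×299 image. It assumes the module has at least one parameter or buffer and falls back to CPU otherwise.

```python
import torch
from torch import nn


def module_device(module: nn.Module) -> torch.device:
    """Best-effort device lookup for an arbitrary nn.Module (hypothetical helper)."""
    # Use an explicit `.device` attribute when the module defines one (e.g. a LightningModule).
    device = getattr(module, "device", None)
    if isinstance(device, torch.device):
        return device
    # Otherwise take the device of the first parameter, then the first buffer.
    param = next(module.parameters(), None)
    if param is not None:
        return param.device
    buf = next(module.buffers(), None)
    if buf is not None:
        return buf.device
    # Stateless modules (e.g. nn.ReLU) hold no tensors; assume CPU.
    return torch.device("cpu")


# Placeholder feature extractor and image size standing in for Inception and 299x299 inputs.
net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 4))
dummy_image = torch.randint(0, 255, (1, 3, 8, 8), dtype=torch.uint8, device=module_device(net))
# The toy network expects floats, unlike the Inception wrapper in the issue, which takes uint8 images.
num_features = net(dummy_image.float()).shape[-1]
```

With this helper, the `dummy_image` line in `__init__` could pass `device=module_device(self.inception)` instead of relying on a `.device` attribute.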
The MiFID metric uses the cosine distance between the two sets of feature vectors. I wasn't sure how to break this computation down so that the features don't all need to be stacked together and used to compute the distance in `compute`. Perhaps there's a way to avoid storing all the stacked features; I am hoping someone here knows whether that can be done.
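On that question: the distance takes, for each row of `features1`, the minimum over all rows of `features2`, so at least one of the two sets has to be kept. One option is sketched below, under the assumption that all real images are passed before any generated ones. It stores only the normalized real features and keeps a running minimum per real row as fake batches arrive. The means and covariances come from running sums, similar in spirit to the sum-based states of the existing FID metric, instead of from stacked features. `StreamingMiFIDStats` and its methods are hypothetical names, and the sketch ignores distributed synchronization, which real torchmetrics states would need.

```python
from typing import Tuple

import torch
from torch import Tensor


def _normalize_nonzero_rows(x: Tensor) -> Tensor:
    # Same filtering as `_compute_cosine_distance`: drop rows whose entries sum to zero, then L2-normalize.
    x = x[x.sum(dim=1) != 0]
    return x / x.norm(dim=1, keepdim=True)


class StreamingMiFIDStats:
    """Hypothetical accumulator for MiFID inputs that never stores the fake features.

    Assumes all real batches are passed before any fake batch.
    """

    def __init__(self, num_features: int) -> None:
        zeros_vec = torch.zeros(num_features, dtype=torch.float64)
        zeros_mat = torch.zeros(num_features, num_features, dtype=torch.float64)
        self.n = {"real": 0, "fake": 0}
        self.sums = {"real": zeros_vec.clone(), "fake": zeros_vec.clone()}
        self.outer = {"real": zeros_mat.clone(), "fake": zeros_mat.clone()}
        # Normalized real features and, per real row, the smallest distance to any fake row seen so far.
        self.real_normed = torch.zeros(0, num_features, dtype=torch.float64)
        self.min_dist = torch.zeros(0, dtype=torch.float64)

    def _accumulate(self, key: str, feats: Tensor) -> None:
        # Count, sum and sum of outer products are enough to recover mean and covariance later.
        self.n[key] += feats.shape[0]
        self.sums[key] += feats.sum(dim=0)
        self.outer[key] += feats.t() @ feats

    def update_real(self, feats: Tensor) -> None:
        feats = feats.double()
        self._accumulate("real", feats)
        normed = _normalize_nonzero_rows(feats)
        self.real_normed = torch.cat((self.real_normed, normed), dim=0)
        inf = torch.full((normed.shape[0],), float("inf"), dtype=torch.float64)
        self.min_dist = torch.cat((self.min_dist, inf), dim=0)

    def update_fake(self, feats: Tensor) -> None:
        feats = feats.double()
        self._accumulate("fake", feats)
        normed = _normalize_nonzero_rows(feats)
        # Fake batches seen before any real batch cannot contribute to the distance.
        if normed.shape[0] == 0 or self.real_normed.shape[0] == 0:
            return
        # Cosine distance between every stored real row and this fake batch only.
        d = 1.0 - torch.abs(self.real_normed @ normed.t())
        self.min_dist = torch.minimum(self.min_dist, d.min(dim=1).values)

    def mean_cov(self, key: str) -> Tuple[Tensor, Tensor]:
        n = self.n[key]
        mean = self.sums[key] / n
        # Unbiased estimate, matching torch.cov's default correction of 1.
        cov = (self.outer[key] - n * torch.outer(mean, mean)) / (n - 1)
        return mean, cov

    def mean_min_distance(self) -> Tensor:
        return self.min_dist.mean()
```

The `mean_cov("real")`, `mean_cov("fake")` and `mean_min_distance()` results could then replace the stacked-feature computations in `compute()` before calling `_compute_fid` and `_distance_thresholding`. Computing covariance from raw sums is less numerically robust than `torch.cov` on centered data, which is one reason to accumulate in float64.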
Additionally, I haven't tested this extensively yet, but I have found that it gives values very similar to those from the source NumPy solution. I implemented this code for my own project and wanted to contribute it here :).