1 file changed: 42 additions, 1 deletion

@@ -67,7 +67,7 @@ class FID(_BaseInceptionMetric):
     Remark:
 
-        This implementation is inspired by pytorch_fid package which can be found `here`__
+        This implementation is inspired by `pytorch_fid` package which can be found `here`__
 
         __ https://github.com/mseitzer/pytorch-fid
 
@@ -114,6 +114,47 @@ class FID(_BaseInceptionMetric):
 
             0.0
 
+    .. note::
+
+        The default `torchvision` model used is InceptionV3, pretrained on ImageNet.
+        This can lead to differences in results compared to `pytorch_fid`. To obtain
+        comparable results, the following model wrapper should be used:
+
+        .. code::
+
+            import torch
+            import torch.nn as nn
+            from pytorch_fid.inception import InceptionV3
+
+            # wrapper class used as feature_extractor
+            class WrapperInceptionV3(nn.Module):
+
+                def __init__(self, fid_incv3):
+                    super().__init__()
+                    self.fid_incv3 = fid_incv3
+
+                @torch.no_grad()
+                def forward(self, x):
+                    y = self.fid_incv3(x)
+                    y = y[0]
+                    y = y[:, :, 0, 0]
+                    return y
+
+            # use cpu rather than cuda to get comparable results
+            device = "cpu"
+
+            # pytorch_fid model
+            dims = 2048
+            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+            model = InceptionV3([block_idx]).to(device)
+
+            # wrapper model around the pytorch_fid model
+            wrapper_model = WrapperInceptionV3(model)
+            wrapper_model.eval()
+
+            # comparable metric
+            pytorch_fid_metric = FID(num_features=dims, feature_extractor=wrapper_model)
+
+        Important: `pytorch_fid` results depend on the batch size if the device is `cuda`.
+
     .. versionadded:: 0.4.6
     """
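Beyond the diff itself, the following is a minimal, self-contained sketch of how the metric configured in the new note might be exercised manually. It is an illustration, not part of the PR: it assumes `pytorch-fid`, `torchvision`, `scipy`, and `pytorch-ignite` are installed, and the random input tensors, batch size, and names introduced outside the diff (`inception`, `y_pred`, `y_true`) are placeholders only.

.. code:: python

    import torch
    import torch.nn as nn
    from ignite.metrics import FID
    from pytorch_fid.inception import InceptionV3

    # build the pytorch_fid extractor and the wrapper as in the docstring note
    dims = 2048
    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
    inception = InceptionV3([block_idx]).to("cpu")

    class WrapperInceptionV3(nn.Module):
        def __init__(self, fid_incv3):
            super().__init__()
            self.fid_incv3 = fid_incv3

        @torch.no_grad()
        def forward(self, x):
            # pytorch_fid returns a list with one (N, dims, 1, 1) feature map
            return self.fid_incv3(x)[0][:, :, 0, 0]

    wrapper_model = WrapperInceptionV3(inception).eval()
    metric = FID(num_features=dims, feature_extractor=wrapper_model, device="cpu")

    # two small fake batches of RGB images in [0, 1]; a real evaluation needs far
    # more samples than `dims` for a stable covariance estimate
    y_pred = torch.rand(32, 3, 299, 299)
    y_true = torch.rand(32, 3, 299, 299)

    metric.update((y_pred, y_true))
    print(metric.compute())

Keeping everything on CPU mirrors the note's warning that `pytorch_fid` results depend on the batch size when run on `cuda`.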