Skip to content

Commit 4759e01

Browse files
sdesrozisDesroziers
andauthored
[skip ci] Add doctest for GAN metrics (#2349)
* doctest for gan metrics * fix lint Co-authored-by: Desroziers <sylvain.desroziers@michelin.com>
1 parent b20677f commit 4759e01

File tree

2 files changed

+35
-12
lines changed

2 files changed

+35
-12
lines changed

ignite/metrics/gan/fid.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,28 @@ class FID(_BaseInceptionMetric):
9595
non-blocking. By default, CPU.
9696
9797
Examples:
98+
9899
.. code-block:: python
99100
100-
import torch
101-
from ignite.metric.gan import FID
101+
metric = FID()
102+
metric.attach(default_evaluator, "fid")
103+
y_true = torch.rand(10, 3, 299, 299)
104+
y_pred = torch.rand(10, 3, 299, 299)
105+
state = default_evaluator.run([[y_pred, y_true]])
106+
print(state.metrics["fid"])
107+
108+
.. testcode::
109+
110+
metric = FID(num_features=1, feature_extractor=default_model)
111+
metric.attach(default_evaluator, "fid")
112+
y_true = torch.ones(10, 4)
113+
y_pred = torch.ones(10, 4)
114+
state = default_evaluator.run([[y_pred, y_true]])
115+
print(state.metrics["fid"])
116+
117+
.. testoutput::
102118
103-
y_pred, y = torch.rand(10, 3, 299, 299), torch.rand(10, 3, 299, 299)
104-
m = FID()
105-
m.update((y_pred, y))
106-
print(m.compute())
119+
0.0
107120
108121
.. versionadded:: 0.4.6
109122
"""

ignite/metrics/gan/inception_score.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,26 @@ class InceptionScore(_BaseInceptionMetric):
4848
The default Inception model requires the `torchvision` module to be installed.
4949
5050
Examples:
51+
5152
.. code-block:: python
5253
53-
from ignite.metric.gan import InceptionScore
54-
import torch
54+
metric = InceptionScore()
55+
metric.attach(default_evaluator, "is")
56+
y = torch.rand(10, 3, 299, 299)
57+
state = default_evaluator.run([y])
58+
print(state.metrics["is"])
59+
60+
.. testcode::
61+
62+
metric = InceptionScore(num_features=1, feature_extractor=default_model)
63+
metric.attach(default_evaluator, "is")
64+
y = torch.zeros(10, 4)
65+
state = default_evaluator.run([y])
66+
print(state.metrics["is"])
5567
56-
images = torch.rand(10, 3, 299, 299)
68+
.. testoutput::
5769
58-
m = InceptionScore()
59-
m.update(images)
60-
print(m.compute())
70+
1.0
6171
6272
.. versionadded:: 0.4.6
6373
"""

0 commit comments

Comments
 (0)