Skip to content

Commit b6bdedf

Browse files
sayantan1410actions-user
authored andcommitted
autopep8 fix
1 parent a6a8996 commit b6bdedf

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

ignite/contrib/metrics/precision_recall_curve.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from typing import Any, Callable, Tuple,cast
1+
from typing import Any, Callable, cast, Tuple
22

33
import torch
44

55
import ignite.distributed as idist
6-
from ignite.metrics import EpochMetric
76
from ignite.exceptions import NotComputableError
7+
from ignite.metrics import EpochMetric
88

99

1010
def precision_recall_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> Tuple[Any, Any, Any]:
@@ -63,6 +63,7 @@ def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: b
6363
super(PrecisionRecallCurve, self).__init__(
6464
precision_recall_curve_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn
6565
)
66+
6667
def compute(self) -> float:
6768
if len(self._predictions) < 1 or len(self._targets) < 1:
6869
raise NotComputableError("EpochMetric must have at least one example before it can be computed.")
@@ -77,18 +78,17 @@ def compute(self) -> float:
7778
_target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))
7879
self._is_reduced = True
7980

80-
precision = torch.zeros(1,len(self._predictions))
81-
recall = torch.zeros(1,len(self._predictions))
82-
thresholds = torch.zeros(1,len(self._predictions)-1)
81+
precision = torch.zeros(1, len(self._predictions))
82+
recall = torch.zeros(1, len(self._predictions))
83+
thresholds = torch.zeros(1, len(self._predictions) - 1)
8384
if idist.get_rank() == 0:
8485
# Run compute_fn on zero rank only
85-
precision,recall,thresholds = self.compute_fn(_prediction_tensor, _target_tensor)
86-
86+
precision, recall, thresholds = self.compute_fn(_prediction_tensor, _target_tensor)
8787

8888
if ws > 1:
8989
# broadcast result to all processes
9090
precision = cast(float, idist.broadcast(precision, src=0))
9191
recall = cast(float, idist.broadcast(recall, src=0))
9292
thresholds = cast(float, idist.broadcast(thresholds, src=0))
9393

94-
return precision,recall,thresholds
94+
return precision, recall, thresholds

tests/ignite/contrib/metrics/test_precision_recall_curve.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def test_check_compute_fn():
126126
em = PrecisionRecallCurve(check_compute_fn=False)
127127
em.update(output)
128128

129+
129130
def _test_distrib_binary_input(device):
130131

131132
rank = idist.get_rank()

0 commit comments

Comments
 (0)