1
- from typing import Any , Callable , Tuple , cast
1
+ from typing import Any , Callable , cast , Tuple
2
2
3
3
import torch
4
4
5
5
import ignite .distributed as idist
6
- from ignite .metrics import EpochMetric
7
6
from ignite .exceptions import NotComputableError
7
+ from ignite .metrics import EpochMetric
8
8
9
9
10
10
def precision_recall_curve_compute_fn (y_preds : torch .Tensor , y_targets : torch .Tensor ) -> Tuple [Any , Any , Any ]:
@@ -63,6 +63,7 @@ def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: b
63
63
super (PrecisionRecallCurve , self ).__init__ (
64
64
precision_recall_curve_compute_fn , output_transform = output_transform , check_compute_fn = check_compute_fn
65
65
)
66
+
66
67
def compute (self ) -> float :
67
68
if len (self ._predictions ) < 1 or len (self ._targets ) < 1 :
68
69
raise NotComputableError ("EpochMetric must have at least one example before it can be computed." )
@@ -77,18 +78,17 @@ def compute(self) -> float:
77
78
_target_tensor = cast (torch .Tensor , idist .all_gather (_target_tensor ))
78
79
self ._is_reduced = True
79
80
80
- precision = torch .zeros (1 ,len (self ._predictions ))
81
- recall = torch .zeros (1 ,len (self ._predictions ))
82
- thresholds = torch .zeros (1 ,len (self ._predictions )- 1 )
81
+ precision = torch .zeros (1 , len (self ._predictions ))
82
+ recall = torch .zeros (1 , len (self ._predictions ))
83
+ thresholds = torch .zeros (1 , len (self ._predictions ) - 1 )
83
84
if idist .get_rank () == 0 :
84
85
# Run compute_fn on zero rank only
85
- precision ,recall ,thresholds = self .compute_fn (_prediction_tensor , _target_tensor )
86
-
86
+ precision , recall , thresholds = self .compute_fn (_prediction_tensor , _target_tensor )
87
87
88
88
if ws > 1 :
89
89
# broadcast result to all processes
90
90
precision = cast (float , idist .broadcast (precision , src = 0 ))
91
91
recall = cast (float , idist .broadcast (recall , src = 0 ))
92
92
thresholds = cast (float , idist .broadcast (thresholds , src = 0 ))
93
93
94
- return precision ,recall ,thresholds
94
+ return precision , recall , thresholds
0 commit comments