Skip to content
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
bb6f4e8
Implement feature and fix a few bugs
sadra-barikbin May 14, 2022
65c8261
autopep8 fix
sadra-barikbin May 14, 2022
c960f7b
Fix MyPy issues
sadra-barikbin May 14, 2022
9f9d323
Remove unused imports
sadra-barikbin May 14, 2022
aa3fc9b
Fix flake8 issue and some bugs
sadra-barikbin May 14, 2022
288258c
Fix affected metrics
sadra-barikbin May 14, 2022
15e57a4
autopep8 fix
sadra-barikbin May 14, 2022
2b41176
Empty commit
sadra-barikbin May 14, 2022
db91bf2
Fix docstring
sadra-barikbin May 14, 2022
2a9d59f
Fix average parameter docstring
sadra-barikbin May 16, 2022
d7f7f2d
autopep8 fix
sadra-barikbin May 16, 2022
1c66ae0
Fix bug
sadra-barikbin May 16, 2022
1bd436e
Merge branch 'improve-precision-recall-metric-issue-2571' of https://…
sadra-barikbin May 16, 2022
bebacd4
Fix bug and classification_report
sadra-barikbin May 16, 2022
fdf1842
Fix bug in doctests and tests
sadra-barikbin May 16, 2022
711c86a
Fix bug in doctests and tests
sadra-barikbin May 16, 2022
b311002
Make recall like precision, undo classification_report changes
sadra-barikbin May 21, 2022
a3c7f5a
Merge branch 'improve-precision-recall-metric-issue-2571' of https://…
sadra-barikbin May 21, 2022
93728fc
Resolve mypy and flake issues
sadra-barikbin May 21, 2022
adcc1c5
Undo change
sadra-barikbin May 21, 2022
6de29b0
Merge branch 'master' into improve-precision-recall-metric-issue-2571
sadra-barikbin May 25, 2022
3ae3cdc
Add more description to docstrings
sadra-barikbin May 26, 2022
9cd7c4c
autopep8 fix
sadra-barikbin May 26, 2022
5341a8d
empty commit
sadra-barikbin May 26, 2022
c9df80c
Improve code
sadra-barikbin May 26, 2022
48e06c1
Add 'macro' option
sadra-barikbin Jun 1, 2022
c2637aa
Merge branch 'improve-precision-recall-metric-issue-2571' of https://…
sadra-barikbin Jun 1, 2022
cfbb04a
Merge branch 'master' into improve-precision-recall-metric-issue-2571
sadra-barikbin Jun 1, 2022
405634d
Add None option to average parameter
sadra-barikbin Jun 2, 2022
e206852
Merge branch 'improve-precision-recall-metric-issue-2571' of https://…
sadra-barikbin Jun 2, 2022
8a562ce
Fix affected tests
sadra-barikbin Jun 3, 2022
214088b
Fix affected doctests
sadra-barikbin Jun 3, 2022
d450903
Do some refactors and improvements
sadra-barikbin Jun 6, 2022
485e4e4
Reduce internal vars to three
sadra-barikbin Jun 7, 2022
5c813c8
Fix a few bugs and do a few improvements
sadra-barikbin Jun 7, 2022
da3853d
Fix bugs, tests and do a few refactors
sadra-barikbin Jun 9, 2022
4e007b0
Fix bug in doctests
sadra-barikbin Jun 9, 2022
300fbb9
Fix mypy issue
sadra-barikbin Jun 9, 2022
0a1bd52
A little improvement
sadra-barikbin Jun 9, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions ignite/metrics/classification_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch

from ignite.metrics.fbeta import Fbeta
from ignite.metrics.metric import Metric
from ignite.metrics.metrics_lambda import MetricsLambda
from ignite.metrics.precision import Precision
from ignite.metrics.recall import Recall
Expand Down Expand Up @@ -85,14 +84,14 @@ def ClassificationReport(
[0, 0, 0],
[1, 0, 0],
[0, 1, 1],
]).unsqueeze(0)
])
y_pred = torch.tensor([
[1, 1, 0],
[1, 0, 1],
[1, 0, 0],
[1, 0, 1],
[1, 1, 0],
]).unsqueeze(0)
])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["cr"].keys())
print(state.metrics["cr"]["0"])
Expand All @@ -119,25 +118,24 @@ def ClassificationReport(
averaged_fbeta = fbeta.mean()

def _wrapper(
recall_metric: Metric, precision_metric: Metric, f: Metric, a_recall: Metric, a_precision: Metric, a_f: Metric
re: torch.Tensor, pr: torch.Tensor, f: torch.Tensor, a_re: torch.Tensor, a_pr: torch.Tensor, a_f: torch.Tensor
) -> Union[Collection[str], Dict]:
p_tensor, r_tensor, f_tensor = precision_metric, recall_metric, f
if p_tensor.shape != r_tensor.shape:
if pr.shape != re.shape:
raise ValueError(
"Internal error: Precision and Recall have mismatched shapes: "
f"{p_tensor.shape} vs {r_tensor.shape}. Please, open an issue "
f"{pr.shape} vs {re.shape}. Please, open an issue "
"with a reference on this error. Thank you!"
)
dict_obj = {}
for idx, p_label in enumerate(p_tensor):
for idx, p_label in enumerate(pr):
dict_obj[_get_label_for_class(idx)] = {
"precision": p_label.item(),
"recall": r_tensor[idx].item(),
"f{0}-score".format(beta): f_tensor[idx].item(),
"recall": re[idx].item(),
"f{0}-score".format(beta): f[idx].item(),
}
dict_obj["macro avg"] = {
"precision": a_precision.item(),
"recall": a_recall.item(),
"precision": a_pr.item(),
"recall": a_re.item(),
"f{0}-score".format(beta): a_f.item(),
}
return dict_obj if output_dict else json.dumps(dict_obj)
Expand Down
8 changes: 4 additions & 4 deletions ignite/metrics/fbeta.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def Fbeta(

.. testcode:: 1

P = Precision(average=False)
R = Recall(average=False)
P = Precision(average=False)[1] # `[1]` is to select the value for class 1 that is TP/(TP+FP)
R = Recall(average=False)[1] # `[1]` is to select the value for class 1 that is TP/(TP+FN)
metric = Fbeta(beta=1.0, precision=P, recall=R)
metric.attach(default_evaluator, "f-beta")
y_true = torch.tensor([1, 0, 1, 1, 0, 1])
Expand Down Expand Up @@ -127,8 +127,8 @@ def thresholded_output_transform(output):
y_pred = torch.round(y_pred)
return y_pred, y

P = Precision(average=False, output_transform=thresholded_output_transform)
R = Recall(average=False, output_transform=thresholded_output_transform)
P = Precision(average=False, output_transform=thresholded_output_transform)[1]
R = Recall(average=False, output_transform=thresholded_output_transform)[1]
metric = Fbeta(beta=1.0, precision=P, recall=R)
metric.attach(default_evaluator, "f-beta")
y_true = torch.tensor([1, 0, 1, 1, 0, 1])
Expand Down
4 changes: 2 additions & 2 deletions ignite/metrics/metrics_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ class MetricsLambda(Metric):

.. testcode::

precision = Precision(average=False)
recall = Recall(average=False)
precision = Precision(average=False)[1] # `[1]` is to select the value for class 1 that is TP/(TP+FP)
recall = Recall(average=False)[1] # `[1]` is to select the value for class 1 that is TP/(TP+FN)

def Fbeta(r, p, beta):
return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r + 1e-20)).item()
Expand Down
Loading