From 4cb7bb2da8c0dc2cbb026a47be21ea4598959445 Mon Sep 17 00:00:00 2001 From: Wenqi Li <831580+wyli@users.noreply.github.com> Date: Thu, 5 Oct 2023 16:19:20 +0100 Subject: [PATCH] 7088 compatible metric util torch inf nan (#7080) Fixes #7088 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wenqi Li --- monai/metrics/hausdorff_distance.py | 2 +- monai/metrics/surface_dice.py | 2 +- monai/metrics/surface_distance.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 440b2b9518..d9bbf17db3 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -208,7 +208,7 @@ def _compute_percentile_hausdorff_distance( # for both pred and gt do not have foreground if surface_distance.shape == (0,): - return torch.tensor(torch.nan, dtype=torch.float, device=surface_distance.device) + return torch.tensor(np.nan, dtype=torch.float, device=surface_distance.device) if not percentile: return surface_distance.max() diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py index 1b19a26216..f8c402a756 100644 --- a/monai/metrics/surface_dice.py +++ b/monai/metrics/surface_dice.py @@ -275,7 +275,7 @@ def compute_surface_dice( boundary_correct = gt_true + pred_true if boundary_complete == 0: # the class is neither present in the prediction, nor in the reference segmentation - nsd[b, c] = torch.nan + nsd[b, c] = torch.tensor(np.nan) else: nsd[b, c] = boundary_correct / boundary_complete diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index bdc4395562..7ce632c588 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -181,6 +181,6 @@ def compute_average_surface_distance( class_index=c, ) surface_distance = torch.cat(distances) - asd[b, c] = torch.nan if surface_distance.shape == (0,) else surface_distance.mean() + asd[b, c] = torch.tensor(np.nan) if surface_distance.shape == (0,) else surface_distance.mean() return convert_data_type(asd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0]