Skip to content

Commit

Permalink
7088 compatible metric util torch inf nan (#7080)
Browse files Browse the repository at this point in the history
Fixes #7088


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli authored Oct 5, 2023
1 parent 141bcf0 commit 4cb7bb2
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion monai/metrics/hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def _compute_percentile_hausdorff_distance(

# for both pred and gt do not have foreground
if surface_distance.shape == (0,):
return torch.tensor(torch.nan, dtype=torch.float, device=surface_distance.device)
return torch.tensor(np.nan, dtype=torch.float, device=surface_distance.device)

if not percentile:
return surface_distance.max()
Expand Down
2 changes: 1 addition & 1 deletion monai/metrics/surface_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def compute_surface_dice(
boundary_correct = gt_true + pred_true
if boundary_complete == 0:
# the class is neither present in the prediction, nor in the reference segmentation
nsd[b, c] = torch.nan
nsd[b, c] = torch.tensor(np.nan)
else:
nsd[b, c] = boundary_correct / boundary_complete

Expand Down
2 changes: 1 addition & 1 deletion monai/metrics/surface_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,6 @@ def compute_average_surface_distance(
class_index=c,
)
surface_distance = torch.cat(distances)
asd[b, c] = torch.nan if surface_distance.shape == (0,) else surface_distance.mean()
asd[b, c] = torch.tensor(np.nan) if surface_distance.shape == (0,) else surface_distance.mean()

return convert_data_type(asd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0]

0 comments on commit 4cb7bb2

Please sign in to comment.