Skip to content

Commit

Permalink
Make RetinaNet throw errors for NaN only when training (#6479)
Browse files Browse the repository at this point in the history
Fixes #6478 .

### Description

I assume that the `amp=True` setting affects the numeric stability in
evaluators and causes the issues in #6478 ,

The fix in this PR aims to continue the training by not raising errors
during the evaluation.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).

Signed-off-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com>
  • Loading branch information
mingxin-zheng authored May 5, 2023
1 parent 4a0afc8 commit d688769
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
15 changes: 12 additions & 3 deletions monai/apps/detection/networks/retinanet_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,10 @@ def _reshape_maps(self, result_maps: list[Tensor]) -> Tensor:
reshaped_result_map = reshaped_result_map.reshape(batch_size, -1, num_channel)

if torch.isnan(reshaped_result_map).any() or torch.isinf(reshaped_result_map).any():
raise ValueError("Concatenated result is NaN or Inf.")
if torch.is_grad_enabled():
raise ValueError("Concatenated result is NaN or Inf.")
else:
warnings.warn("Concatenated result is NaN or Inf.")

all_reshaped_result_map.append(reshaped_result_map)

Expand Down Expand Up @@ -893,7 +896,10 @@ def get_cls_train_sample_per_image(
"""

if torch.isnan(cls_logits_per_image).any() or torch.isinf(cls_logits_per_image).any():
raise ValueError("NaN or Inf in predicted classification logits.")
if torch.is_grad_enabled():
raise ValueError("NaN or Inf in predicted classification logits.")
else:
warnings.warn("NaN or Inf in predicted classification logits.")

foreground_idxs_per_image = matched_idxs_per_image >= 0

Expand Down Expand Up @@ -973,7 +979,10 @@ def get_box_train_sample_per_image(
"""

if torch.isnan(box_regression_per_image).any() or torch.isinf(box_regression_per_image).any():
raise ValueError("NaN or Inf in predicted box regression.")
if torch.is_grad_enabled():
raise ValueError("NaN or Inf in predicted box regression.")
else:
warnings.warn("NaN or Inf in predicted box regression.")

foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
num_gt_box = targets_per_image[self.target_box_key].shape[0]
Expand Down
11 changes: 9 additions & 2 deletions monai/apps/detection/networks/retinanet_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from __future__ import annotations

import math
import warnings
from collections.abc import Callable, Sequence
from typing import Any, Dict

Expand Down Expand Up @@ -125,7 +126,10 @@ def forward(self, x: list[Tensor]) -> list[Tensor]:
cls_logits_maps.append(cls_logits)

if torch.isnan(cls_logits).any() or torch.isinf(cls_logits).any():
raise ValueError("cls_logits is NaN or Inf.")
if torch.is_grad_enabled():
raise ValueError("cls_logits is NaN or Inf.")
else:
warnings.warn("cls_logits is NaN or Inf.")

return cls_logits_maps

Expand Down Expand Up @@ -194,7 +198,10 @@ def forward(self, x: list[Tensor]) -> list[Tensor]:
box_regression_maps.append(box_regression)

if torch.isnan(box_regression).any() or torch.isinf(box_regression).any():
raise ValueError("box_regression is NaN or Inf.")
if torch.is_grad_enabled():
raise ValueError("box_regression is NaN or Inf.")
else:
warnings.warn("box_regression is NaN or Inf.")

return box_regression_maps

Expand Down

0 comments on commit d688769

Please sign in to comment.