Skip to content

Commit

Permalink
fix minor bug that could lead to OOM in inference with large images
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianIsensee committed Dec 18, 2023
1 parent b00b41f commit 6309155
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions nnunetv2/inference/predict_from_raw_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,13 +540,16 @@ def _internal_predict_sliding_window_return_logits(self,
do_on_device: bool = True,
):
results_device = self.device if do_on_device else torch.device('cpu')
empty_cache(self.device)

# move data to device
if self.verbose: print(f'move image to device {results_device}')
data = data.to(self.device)
if self.verbose:
print(f'move image to device {results_device}')
data = data.to(results_device)

# preallocate arrays
if self.verbose: print(f'preallocating results arrays on device {results_device}')
if self.verbose:
print(f'preallocating results arrays on device {results_device}')
predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
dtype=torch.half,
device=results_device)
Expand All @@ -555,7 +558,6 @@ def _internal_predict_sliding_window_return_logits(self,
gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
value_scaling_factor=10,
device=results_device)
empty_cache(self.device)

if self.verbose: print('running prediction')
if not self.allow_tqdm and self.verbose: print(f'{len(slicers)} steps')
Expand Down

0 comments on commit 6309155

Please sign in to comment.