diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py index e57156467..276dcd167 100644 --- a/nnunetv2/inference/predict_from_raw_data.py +++ b/nnunetv2/inference/predict_from_raw_data.py @@ -540,13 +540,16 @@ def _internal_predict_sliding_window_return_logits(self, do_on_device: bool = True, ): results_device = self.device if do_on_device else torch.device('cpu') + empty_cache(self.device) # move data to device - if self.verbose: print(f'move image to device {results_device}') - data = data.to(self.device) + if self.verbose: + print(f'move image to device {results_device}') + data = data.to(results_device) # preallocate arrays - if self.verbose: print(f'preallocating results arrays on device {results_device}') + if self.verbose: + print(f'preallocating results arrays on device {results_device}') predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]), dtype=torch.half, device=results_device) @@ -555,7 +558,6 @@ def _internal_predict_sliding_window_return_logits(self, gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8, value_scaling_factor=10, device=results_device) - empty_cache(self.device) if self.verbose: print('running prediction') if not self.allow_tqdm and self.verbose: print(f'{len(slicers)} steps')