Skip to content

Commit c239979

Browse files
Fix auto3dseg swin template infer bug and sigmoid bug (#256)
* Update algo docs and remove disclaimer Signed-off-by: heyufan1995 <heyufan1995@gmail.com> * fix infer cuda error and sigmoid Signed-off-by: heyufan1995 <heyufan1995@gmail.com> * [MONAI] code formatting Signed-off-by: monai-bot <monai.miccai2019@gmail.com> --------- Signed-off-by: heyufan1995 <heyufan1995@gmail.com> Signed-off-by: monai-bot <monai.miccai2019@gmail.com> Co-authored-by: monai-bot <monai.miccai2019@gmail.com>
1 parent 58af562 commit c239979

File tree

4 files changed

+13
-4
lines changed

4 files changed

+13
-4
lines changed

auto3dseg/algorithm_templates/segresnet2d/scripts/segmenter.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,6 @@ def get_crop_transforms(self):
324324
max_samples_per_class = None
325325
indices_key = None
326326

327-
328327
if cache_class_indices:
329328
ts.append(
330329
ClassesToIndicesd(

auto3dseg/algorithm_templates/swinunetr/scripts/infer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from typing import Optional, Sequence, Union
1616

1717
import torch
18+
import torch.distributed as dist
1819

1920
import monai
2021
from monai import transforms
@@ -85,7 +86,11 @@ def __init__(self, config_file: Optional[Union[str, Sequence[str]]] = None, **ov
8586
infer_ds = monai.data.Dataset(data=self.infer_files, transform=self.infer_transforms)
8687
self.infer_loader = ThreadDataLoader(infer_ds, num_workers=8, batch_size=1, shuffle=False)
8788

88-
self.device = torch.device("cuda:0")
89+
try:
90+
device = f"cuda:{dist.get_rank()}"
91+
except BaseException:
92+
device = f"cuda:0"
93+
self.device = device
8994

9095
self.model = parser.get_parsed_content("network")
9196
self.model = self.model.to(self.device)
@@ -141,6 +146,7 @@ def infer(self, image_file, save_mask=False):
141146
device_list_output = [self.device, "cpu", "cpu"]
142147
for _device_in, _device_out in zip(device_list_input, device_list_output):
143148
try:
149+
logger.debug(f"Working on {image_file} on device {_device_in}/{_device_out} in/out.")
144150
with torch.cuda.amp.autocast(enabled=self.amp):
145151
batch_data["pred"] = sliding_window_inference(
146152
inputs=batch_data["image"].to(_device_in),
@@ -165,6 +171,7 @@ def infer(self, image_file, save_mask=False):
165171
break
166172
if not finished:
167173
raise RuntimeError("Infer not finished due to OOM.")
174+
logger.debug(f"{image_file} fininshed.")
168175
return batch_data[0]["pred"]
169176

170177
@torch.no_grad()

auto3dseg/algorithm_templates/swinunetr/scripts/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
366366
post_transforms += [transforms.AsDiscreted(keys="pred", argmax=True)]
367367
else:
368368
post_transforms += [
369-
transforms.Activations(sigmoid=True),
369+
transforms.Activationsd(keys="pred", sigmoid=True),
370370
transforms.AsDiscreted(keys="pred", threshold=0.5),
371371
]
372372
post_transforms = transforms.Compose(post_transforms)

auto3dseg/algorithm_templates/swinunetr/scripts/validate.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,10 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
124124
if softmax:
125125
post_transforms += [transforms.AsDiscreted(keys="pred", argmax=True)]
126126
else:
127-
post_transforms += [transforms.Activations(sigmoid=True), transforms.AsDiscreted(keys="pred", threshold=0.5)]
127+
post_transforms += [
128+
transforms.Activationsd(keys="pred", sigmoid=True),
129+
transforms.AsDiscreted(keys="pred", threshold=0.5),
130+
]
128131

129132
if save_mask:
130133
post_transforms += [

0 commit comments

Comments
 (0)