Skip to content

Commit 5521424

Browse files
977 fix dynunet task2 type issue (#978)
Fixes #977 . ### Description This PR adds the cast transform to prevent type issue on dynunet pipeline for decathlon task 2 dataset. ### Checks <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [ ] Notebook runs automatically `./runner [-p <regex_pattern>]` Signed-off-by: Yiheng Wang <vennw@nvidia.com>
1 parent a9eaf33 commit 5521424

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

modules/dynunet_pipeline/transforms.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,22 @@ def get_task_transforms(mode, task_id, pos_sample_num, neg_sample_num, num_sampl
6767
RandFlipd(["image", "label"], spatial_axis=[0], prob=0.5),
6868
RandFlipd(["image", "label"], spatial_axis=[1], prob=0.5),
6969
RandFlipd(["image", "label"], spatial_axis=[2], prob=0.5),
70+
CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
71+
EnsureTyped(keys=["image", "label"]),
72+
]
73+
elif mode == "validation":
74+
other_transforms = [
75+
CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
76+
EnsureTyped(keys=["image", "label"]),
7077
]
71-
72-
return Compose(load_transforms + sample_transforms + other_transforms)
7378
else:
74-
return Compose(load_transforms + sample_transforms)
79+
other_transforms = [
80+
CastToTyped(keys=["image"], dtype=(np.float32)),
81+
EnsureTyped(keys=["image"]),
82+
]
83+
84+
all_transforms = load_transforms + sample_transforms + other_transforms
85+
return Compose(all_transforms)
7586

7687
def resample_image(image, shape, anisotrophy_flag):
7788
resized_channels = []

0 commit comments

Comments
 (0)