Skip to content

Commit 6a0e1b0

Browse files
heyufan1995pre-commit-ci[bot]KumoLiuyiheng-wang-nv
authored
Fix vista3d transpose bug (#8059)
Fixes # . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: heyufan1995 <heyufan1995@gmail.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Yiheng Wang <vennw@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Yiheng Wang <vennw@nvidia.com> Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com>
1 parent 7219ee7 commit 6a0e1b0

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

monai/apps/vista3d/inferer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def point_based_window_inferer(
100100
point_labels=point_labels,
101101
class_vector=class_vector,
102102
prompt_class=prompt_class,
103-
patch_coords=unravel_slice,
103+
patch_coords=[unravel_slice],
104104
prev_mask=prev_mask,
105105
**kwargs,
106106
)

monai/networks/nets/vista3d.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False):
336336
def forward(
337337
self,
338338
input_images: torch.Tensor,
339-
patch_coords: Sequence[slice] | None = None,
339+
patch_coords: list[Sequence[slice]] | None = None,
340340
point_coords: torch.Tensor | None = None,
341341
point_labels: torch.Tensor | None = None,
342342
class_vector: torch.Tensor | None = None,
@@ -364,8 +364,12 @@ def forward(
364364
the points are for zero-shot or supported class. When class_vector and point_coords are both
365365
provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b]
366366
will be considered novel class.
367-
patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference.
368-
This value is passed from sliding_window_inferer. This is an indicator for training phase or validation phase.
367+
patch_coords: a list of sequence of the python slice objects representing the patch coordinates during sliding window
368+
inference. This value is passed from sliding_window_inferer.
369+
This is an indicator for training phase or validation phase.
370+
Notice for sliding window batch size > 1 (only supported by automatic segmentation), patch_coords will inlcude
371+
coordinates of multiple patches. If point prompts are included, the batch size can only be one and all the
372+
functions using patch_coords will by default use patch_coords[0].
369373
labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation
370374
label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID,
371375
this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot
@@ -395,14 +399,14 @@ def forward(
395399
if val_point_sampler is None:
396400
# TODO: think about how to refactor this part.
397401
val_point_sampler = self.sample_points_patch_val
398-
point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords, label_set)
402+
point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords[0], label_set)
399403
if prompt_class[0].item() == 0: # type: ignore
400404
point_labels[0] = -1 # type: ignore
401405
labels, prev_mask = None, None
402406
elif point_coords is not None:
403407
# If not performing patch-based point only validation, use user provided click points for inference.
404408
# the point clicks is in original image space, convert it to current patch-coordinate space.
405-
point_coords, point_labels = self.update_point_to_patch(patch_coords, point_coords, point_labels) # type: ignore
409+
point_coords, point_labels = self.update_point_to_patch(patch_coords[0], point_coords, point_labels) # type: ignore
406410

407411
if point_coords is not None and point_labels is not None:
408412
# remove points that used for padding purposes (point_label = -1)
@@ -455,7 +459,7 @@ def forward(
455459
logits[mapping_index] = self.point_head(out, point_coords, point_labels, class_vector=prompt_class)
456460
if prev_mask is not None and patch_coords is not None:
457461
logits = self.connected_components_combine(
458-
prev_mask[patch_coords].transpose(1, 0).to(logits.device),
462+
prev_mask[patch_coords[0]].transpose(1, 0).to(logits.device),
459463
logits[mapping_index],
460464
point_coords, # type: ignore
461465
point_labels, # type: ignore

0 commit comments

Comments
 (0)