6794 optionally pass coordinates to predictor during sliding window (#6795)

Fixes #6794 


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
wyli authored Sep 21, 2023
1 parent 2a0ed97 commit f3200b9
Showing 5 changed files with 15 additions and 1 deletion.
1 change: 1 addition & 0 deletions monai/apps/pathology/inferers/inferer.py
@@ -178,6 +178,7 @@ def __call__(
self.process_output,
self.buffer_steps,
self.buffer_dim,
False,
*args,
**kwargs,
)
5 changes: 5 additions & 0 deletions monai/inferers/inferer.py
@@ -426,6 +426,8 @@ class SlidingWindowInferer(Inferer):
(i.e. no overlapping among the buffers) non_blocking copy may be automatically enabled for efficiency.
buffer_dim: the spatial dimension along which the buffers are created.
0 indicates the first spatial dimension. Default is -1, the last spatial dimension.
with_coord: whether to pass the window coordinates to ``network``. Defaults to False.
If True, the ``network``'s 2nd input argument should accept the window coordinates.
Note:
``sw_batch_size`` denotes the max number of windows per network inference iteration,
@@ -449,6 +451,7 @@ def __init__(
cpu_thresh: int | None = None,
buffer_steps: int | None = None,
buffer_dim: int = -1,
with_coord: bool = False,
) -> None:
super().__init__()
self.roi_size = roi_size
@@ -464,6 +467,7 @@ def __init__(
self.cpu_thresh = cpu_thresh
self.buffer_steps = buffer_steps
self.buffer_dim = buffer_dim
self.with_coord = with_coord

# compute_importance_map takes long time when computing on cpu. We thus
# compute it once if it's static and then save it for future usage
@@ -525,6 +529,7 @@ def __call__(
None,
buffer_steps,
buffer_dim,
self.with_coord,
*args,
**kwargs,
)
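
For illustration, a minimal sketch of how the new `with_coord` option of `SlidingWindowInferer` can be used; `CoordAwareNet` is a hypothetical toy network (not part of this change) whose second positional argument receives the window coordinates:

```python
import torch
from monai.inferers import SlidingWindowInferer


class CoordAwareNet(torch.nn.Module):
    """Hypothetical network: with_coord=True makes the inferer pass the
    window coordinates as the second positional argument."""

    def forward(self, x, coords):
        # coords has one entry per window in the current batch; each entry
        # indexes that window's location in the full input tensor
        return x  # identity output, just to keep the example runnable


inferer = SlidingWindowInferer(roi_size=(64, 64), sw_batch_size=2, with_coord=True)
image = torch.rand(1, 1, 128, 128)  # (batch, channel, H, W)
output = inferer(image, CoordAwareNet())  # same spatial size as the input image
```
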
8 changes: 7 additions & 1 deletion monai/inferers/utils.py
@@ -57,6 +57,7 @@ def sliding_window_inference(
process_fn: Callable | None = None,
buffer_steps: int | None = None,
buffer_dim: int = -1,
with_coord: bool = False,
*args: Any,
**kwargs: Any,
) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:
@@ -125,6 +126,8 @@ def sliding_window_inference(
(i.e. no overlapping among the buffers) non_blocking copy may be automatically enabled for efficiency.
buffer_dim: the spatial dimension along which the buffers are created.
0 indicates the first spatial dimension. Default is -1, the last spatial dimension.
with_coord: whether to pass the window coordinates to ``predictor``. Default is False.
If True, the signature of ``predictor`` should be ``predictor(patch_data, patch_coord, ...)``.
args: optional args to be passed to ``predictor``.
kwargs: optional keyword args to be passed to ``predictor``.
@@ -220,7 +223,10 @@ def sliding_window_inference(
win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
else:
win_data = inputs[unravel_slice[0]].to(sw_device)
seg_prob_out = predictor(win_data, *args, **kwargs) # batched patch
if with_coord:
seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs) # batched patch
else:
seg_prob_out = predictor(win_data, *args, **kwargs) # batched patch

# convert seg_prob_out to tuple seg_tuple, this does not allocate new memory.
dict_keys, seg_tuple = _flatten_struct(seg_prob_out)
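
A similar sketch for the functional API: `predictor` below is a toy callable following the `predictor(patch_data, patch_coord, ...)` signature described in the docstring above; only the `with_coord=True` flag itself comes from this change.

```python
import torch
from monai.inferers import sliding_window_inference


def predictor(patch_data, patch_coord):
    # patch_coord lists the index expressions (slices) that locate each
    # window of this batch inside the original input tensor
    return patch_data  # identity output for illustration


image = torch.rand(1, 1, 96, 96, 96)  # (batch, channel, D, H, W)
output = sliding_window_inference(
    inputs=image,
    roi_size=(32, 32, 32),
    sw_batch_size=4,
    predictor=predictor,
    with_coord=True,
)
```
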
1 change: 1 addition & 0 deletions tests/test_sliding_window_hovernet_inference.py
@@ -237,6 +237,7 @@ def compute(data, test1, test2):
None,
None,
0,
False,
t1,
test2=t2,
)
1 change: 1 addition & 0 deletions tests/test_sliding_window_inference.py
@@ -294,6 +294,7 @@ def compute(data, test1, test2):
None,
None,
0,
False,
t1,
test2=t2,
)
