Skip to content

Commit 6f5005f

Browse files
authored
Fix Spacing (#6912)
Fixes #6911. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu <yunl@nvidia.com>
1 parent 2862f53 commit 6f5005f

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

monai/transforms/spatial/dictionary.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,8 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No
517517
output_spatial_shape=output_shape_k if should_match else None,
518518
lazy=lazy_,
519519
)
520-
output_shape_k = d[key].peek_pending_shape() if isinstance(d[key], MetaTensor) else d[key].shape[1:]
520+
if output_shape_k is None:
521+
output_shape_k = d[key].peek_pending_shape() if isinstance(d[key], MetaTensor) else d[key].shape[1:]
521522
return d
522523

523524
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:

tests/test_spacingd.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,20 @@
8383
*device,
8484
)
8585
)
86+
TESTS.append(
87+
(
88+
"interp sep",
89+
{
90+
"image": MetaTensor(torch.ones((2, 1, 10)), affine=torch.eye(4)),
91+
"seg1": MetaTensor(torch.ones((2, 1, 10)), affine=torch.diag(torch.tensor([2, 2, 2, 1]))),
92+
"seg2": MetaTensor(torch.ones((2, 1, 10)), affine=torch.eye(4)),
93+
},
94+
dict(keys=("image", "seg1", "seg2"), mode=("bilinear", "nearest", "nearest"), pixdim=(1, 1, 1)),
95+
(2, 1, 10),
96+
torch.as_tensor(np.diag((1, 1, 1, 1))),
97+
*device,
98+
)
99+
)
86100

87101
TESTS_TORCH = []
88102
for track_meta in (False, True):

0 commit comments

Comments
 (0)