Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,45 @@ def __init__(
self.set_random_state(seed=get_seed())
self.overrides = overrides

# Automatically assign group ID to child transforms for inversion tracking
self._set_transform_groups()

def _set_transform_groups(self):
"""
Automatically set group IDs on child transforms for inversion tracking.
This allows Invertd to identify which transforms belong to this Compose instance.
Recursively sets groups on wrapped transforms (e.g., array transforms inside dictionary transforms).
"""
from monai.transforms.inverse import TraceableTransform

group_id = str(id(self))
visited = set() # Track visited objects to avoid infinite recursion

def set_group_recursive(obj, gid):
"""Recursively set group on transform and its wrapped transforms."""
# Avoid infinite recursion
obj_id = id(obj)
if obj_id in visited:
return
visited.add(obj_id)

if isinstance(obj, TraceableTransform):
obj._group = gid

# Handle wrapped transforms in dictionary transforms
# Check common attribute patterns for wrapped transforms
for attr_name in dir(obj):
# Skip magic methods and common non-transform attributes
if attr_name.startswith("__") or attr_name in ("transforms", "transform"):
continue
attr = getattr(obj, attr_name, None)
if attr is not None and isinstance(attr, TraceableTransform) and not isinstance(attr, Compose):
# Recursively set group on nested transforms
set_group_recursive(attr, gid)

for transform in self.transforms:
set_group_recursive(transform, group_id)

@LazyTransform.lazy.setter # type: ignore
def lazy(self, val: bool):
self._lazy = val
Expand Down
15 changes: 14 additions & 1 deletion monai/transforms/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def _init_trace_threadlocal(self):
if not hasattr(self._tracing, "value"):
self._tracing.value = MONAIEnvVars.trace_transform() != "0"

# Initialize group identifier (set by Compose for automatic group tracking)
if not hasattr(self, "_group"):
self._group: str | None = None

def __getstate__(self):
"""When pickling, remove the `_tracing` member from the output, if present, since it's not picklable."""
_dict = dict(getattr(self, "__dict__", {})) # this makes __dict__ always present in the unpickled object
Expand Down Expand Up @@ -119,13 +123,22 @@ def get_transform_info(self) -> dict:
"""
Return a dictionary with the relevant information pertaining to an applied transform.
"""
# Ensure _group is initialized
self._init_trace_threadlocal()

vals = (
self.__class__.__name__,
id(self),
self.tracing,
self._do_transform if hasattr(self, "_do_transform") else True,
)
return dict(zip(self.transform_info_keys(), vals))
info = dict(zip(self.transform_info_keys(), vals))

# Add group if set (automatically set by Compose)
if self._group is not None:
info[TraceKeys.GROUP] = self._group

return info

def push_transform(self, data, *args, **kwargs):
"""
Expand Down
28 changes: 27 additions & 1 deletion monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,29 @@ def __init__(
self.post_func = ensure_tuple_rep(post_func, len(self.keys))
self._totensor = ToTensor()

def _filter_transforms_by_group(self, all_transforms: list[dict]) -> list[dict]:
"""
Filter applied_operations to only include transforms from the target Compose instance.
Uses automatic group tracking where Compose assigns its ID to child transforms.
"""
from monai.utils import TraceKeys

# Get the group ID of the transform (Compose instance)
target_group = str(id(self.transform))

# Filter transforms that match the target group
filtered = []
for xform in all_transforms:
xform_group = xform.get(TraceKeys.GROUP)
if xform_group == target_group:
filtered.append(xform)

# If no transforms match (backward compatibility), return all transforms
if not filtered:
return all_transforms

return filtered

def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
d = dict(data)
for (
Expand Down Expand Up @@ -894,8 +917,11 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:

orig_meta_key = orig_meta_key or f"{orig_key}_{meta_key_postfix}"
if orig_key in d and isinstance(d[orig_key], MetaTensor):
transform_info = d[orig_key].applied_operations
all_transforms = d[orig_key].applied_operations
meta_info = d[orig_key].meta

# Automatically filter by Compose instance group ID
transform_info = self._filter_transforms_by_group(all_transforms)
else:
transform_info = d[InvertibleTransform.trace_key(orig_key)]
meta_info = d.get(orig_meta_key, {})
Expand Down
1 change: 1 addition & 0 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ class TraceKeys(StrEnum):
TRACING: str = "tracing"
STATUSES: str = "statuses"
LAZY: str = "lazy"
GROUP: str = "group"


class TraceStatusKeys(StrEnum):
Expand Down
218 changes: 218 additions & 0 deletions tests/transforms/inverse/test_invertd.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,224 @@ def test_invert(self):

set_determinism(seed=None)

def test_invertd_with_postprocessing_transforms(self):
"""Test that Invertd ignores postprocessing transforms using automatic group tracking.

This is a regression test for the issue where Invertd would fail when
postprocessing contains invertible transforms before Invertd is called.
The fix uses automatic group tracking where Compose assigns its ID to child transforms.
"""
from monai.data import MetaTensor, create_test_image_2d
from monai.transforms.utility.dictionary import Lambdad

img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2)
img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]})
key = "image"

# Preprocessing pipeline
preprocessing = Compose([EnsureChannelFirstd(key), Spacingd(key, pixdim=[2.0, 2.0])])

# Postprocessing with Lambdad before Invertd
# Previously this would raise RuntimeError about transform ID mismatch
postprocessing = Compose(
[
Lambdad(key, func=lambda x: x), # Should be ignored during inversion
Invertd(key, transform=preprocessing, orig_keys=key),
]
)

# Apply transforms
item = {key: img}
pre = preprocessing(item)

# This should NOT raise an error (was failing before the fix)
try:
post = postprocessing(pre)
# If we get here, the bug is fixed
self.assertIsNotNone(post)
self.assertIn(key, post)
print("SUCCESS! Automatic group tracking fixed the bug.")
print(f" Preprocessing group ID: {id(preprocessing)}")
print(f" Postprocessing group ID: {id(postprocessing)}")
except RuntimeError as e:
if "getting the most recently applied invertible transform" in str(e):
self.fail(f"Invertd still has the postprocessing transform bug: {e}")

def test_invertd_multiple_pipelines(self):
"""Test that Invertd correctly handles multiple independent preprocessing pipelines."""
from monai.data import MetaTensor, create_test_image_2d
from monai.transforms.utility.dictionary import Lambdad

img1, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2)
img1 = MetaTensor(img1, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]})
img2, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2)
img2 = MetaTensor(img2, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]})

# Two different preprocessing pipelines
preprocessing1 = Compose([EnsureChannelFirstd("image1"), Spacingd("image1", pixdim=[2.0, 2.0])])

preprocessing2 = Compose([EnsureChannelFirstd("image2"), Spacingd("image2", pixdim=[1.5, 1.5])])

# Postprocessing that inverts both
postprocessing = Compose(
[
Lambdad(["image1", "image2"], func=lambda x: x),
Invertd("image1", transform=preprocessing1, orig_keys="image1"),
Invertd("image2", transform=preprocessing2, orig_keys="image2"),
]
)

# Apply transforms
item = {"image1": img1, "image2": img2}
pre1 = preprocessing1(item)
pre2 = preprocessing2(pre1)

# Should not raise error - each Invertd should only invert its own pipeline
post = postprocessing(pre2)
self.assertIn("image1", post)
self.assertIn("image2", post)

def test_invertd_multiple_postprocessing_transforms(self):
"""Test Invertd with multiple invertible transforms in postprocessing before Invertd."""
from monai.data import MetaTensor, create_test_image_2d
from monai.transforms.utility.dictionary import Lambdad

img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2)
img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]})
key = "image"

preprocessing = Compose([EnsureChannelFirstd(key), Spacingd(key, pixdim=[2.0, 2.0])])

# Multiple transforms in postprocessing before Invertd
postprocessing = Compose(
[
Lambdad(key, func=lambda x: x * 2),
Lambdad(key, func=lambda x: x + 1),
Lambdad(key, func=lambda x: x - 1),
Invertd(key, transform=preprocessing, orig_keys=key),
]
)

item = {key: img}
pre = preprocessing(item)
post = postprocessing(pre)

self.assertIsNotNone(post)
self.assertIn(key, post)

def test_invertd_group_isolation(self):
"""Test that groups correctly isolate transforms from different Compose instances."""
from monai.data import MetaTensor, create_test_image_2d

img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2)
img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]})
key = "image"

# First preprocessing
preprocessing1 = Compose([EnsureChannelFirstd(key), Spacingd(key, pixdim=[2.0, 2.0])])

# Second preprocessing (different pipeline)
preprocessing2 = Compose([Spacingd(key, pixdim=[1.5, 1.5])])

item = {key: img}
pre1 = preprocessing1(item)

# Verify group IDs are in applied_operations
self.assertTrue(len(pre1[key].applied_operations) > 0)
group1 = pre1[key].applied_operations[0].get("group")
self.assertIsNotNone(group1)
self.assertEqual(group1, str(id(preprocessing1)))

# Apply second preprocessing
pre2 = preprocessing2(pre1)

# Should have operations from both pipelines with different groups
groups = [op.get("group") for op in pre2[key].applied_operations]
self.assertIn(str(id(preprocessing1)), groups)
self.assertIn(str(id(preprocessing2)), groups)

# Inverting preprocessing1 should only invert its transforms
inverter = Invertd(key, transform=preprocessing1, orig_keys=key)
inverted = inverter(pre2)
self.assertIsNotNone(inverted)

def test_compose_inverse_with_groups(self):
"""Test that Compose.inverse() works correctly with automatic group tracking."""
from monai.data import MetaTensor, create_test_image_2d

img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2)
img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]})
key = "image"

# Create a preprocessing pipeline
preprocessing = Compose([EnsureChannelFirstd(key), Spacingd(key, pixdim=[2.0, 2.0])])

# Apply preprocessing
item = {key: img}
pre = preprocessing(item)

# Call inverse() directly on the Compose object
inverted = preprocessing.inverse(pre)

# Should successfully invert
self.assertIsNotNone(inverted)
self.assertIn(key, inverted)
# Shape should be restored after inversion
self.assertEqual(inverted[key].shape[1:], img.shape)

def test_compose_inverse_with_postprocessing_groups(self):
"""Test Compose.inverse() when data has been through multiple pipelines with different groups."""
from monai.data import MetaTensor, create_test_image_2d
from monai.transforms.utility.dictionary import Lambdad

img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2)
img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]})
key = "image"

# Preprocessing pipeline
preprocessing = Compose([EnsureChannelFirstd(key), Spacingd(key, pixdim=[2.0, 2.0])])

# Postprocessing pipeline (different group)
postprocessing = Compose([Lambdad(key, func=lambda x: x * 2)])

# Apply both pipelines
item = {key: img}
pre = preprocessing(item)
post = postprocessing(pre)

# Now call inverse() directly on preprocessing
# This tests that inverse() can handle data that has transforms from multiple groups
# This WILL fail because applied_operations contains postprocessing transforms
# and inverse() doesn't do group filtering (only Invertd does)
with self.assertRaises(RuntimeError):
preprocessing.inverse(post)

def test_mixed_invertd_and_compose_inverse(self):
"""Test mixing Invertd (with group filtering) and Compose.inverse() (without filtering)."""
from monai.data import MetaTensor, create_test_image_2d

img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2)
img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]})
key = "image"

# First pipeline
pipeline1 = Compose([EnsureChannelFirstd(key), Spacingd(key, pixdim=[2.0, 2.0])])

# Apply first pipeline
item = {key: img}
result1 = pipeline1(item)

# Use Compose.inverse() directly - should work fine
inverted1 = pipeline1.inverse(result1)
self.assertIsNotNone(inverted1)
self.assertEqual(inverted1[key].shape[1:], img.shape)

# Now apply pipeline again and use Invertd
result2 = pipeline1(item)
inverter = Invertd(key, transform=pipeline1, orig_keys=key)
inverted2 = inverter(result2)
self.assertIsNotNone(inverted2)


if __name__ == "__main__":
unittest.main()
Loading