Skip to content

Commit 6b9e4d6

Browse files
committed
Fixes
1 parent 11be4cb commit 6b9e4d6

File tree

2 files changed

+16
-17
lines changed

2 files changed

+16
-17
lines changed

monai/transforms/inverse.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111

1212
from __future__ import annotations
1313

14+
import threading
1415
import warnings
1516
from collections.abc import Hashable, Mapping
1617
from contextlib import contextmanager
1718
from typing import Any
18-
import threading
1919

2020
import torch
2121

@@ -76,7 +76,7 @@ def _init_trace_threadlocal(self):
7676
if not hasattr(self, "_tracing"):
7777
self._tracing = threading.local()
7878

79-
# This is True while the above initialising _tracing is False when this is
79+
# This is True while the above initialising _tracing is False when this is
8080
# called from a different thread than the one initialising _tracing.
8181
if not hasattr(self._tracing, "value"):
8282
self._tracing.value = MONAIEnvVars.trace_transform() != "0"
@@ -87,7 +87,7 @@ def tracing(self) -> bool:
8787
Returns the tracing state, which is thread-local and initialised to `MONAIEnvVars.trace_transform() != "0"`.
8888
"""
8989
self._init_trace_threadlocal()
90-
return self._tracing.value
90+
return bool(self._tracing.value)
9191

9292
@tracing.setter
9393
def tracing(self, val: bool):
@@ -338,18 +338,18 @@ def get_most_recent_transform(self, data, key: Hashable = None, check: bool = Tr
338338

339339
# Find the last transform whose name matches that of this class, this allows Invertd to ignore applied
340340
# operations added by transforms it is not trying to invert, ie. those added in postprocessing.
341-
idx=-1
341+
idx = -1
342342
for i in reversed(range(len(all_transforms))):
343343
xform_name = all_transforms[i].get(TraceKeys.CLASS_NAME, "")
344344
if xform_name == self.__class__.__name__:
345-
idx=i # if nothing found, idx remains -1 so replicating previous behaviour
345+
idx = i # if nothing found, idx remains -1 so replicating previous behaviour
346346
break
347347

348348
if not all_transforms:
349349
raise ValueError(f"Item of type {type(data)} (key: {key}, pop: {pop}) has empty 'applied_operations'")
350350

351351
if check:
352-
if not (-len(all_transforms)<=idx<len(all_transforms)):
352+
if not (-len(all_transforms) <= idx < len(all_transforms)):
353353
raise IndexError(f"Index '{idx}' not valid for list of applied operations '{all_transforms}'")
354354

355355
self.check_transforms_match(all_transforms[idx])

tests/transforms/inverse/test_inverse_dict.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,17 @@
1111

1212
from __future__ import annotations
1313

14-
from itertools import product
1514
import unittest
15+
from itertools import product
1616

1717
import torch
1818
from parameterized import parameterized
1919

20-
from monai.data import MetaTensor, create_test_image_2d, Dataset, ThreadDataLoader, DataLoader
20+
from monai.data import DataLoader, Dataset, MetaTensor, ThreadDataLoader, create_test_image_2d
2121
from monai.engines.evaluator import SupervisedEvaluator
2222
from monai.transforms import Compose, EnsureChannelFirstd, Invertd, Spacingd
2323
from monai.transforms.utility.dictionary import Lambdad
2424
from monai.utils.enums import CommonKeys
25-
2625
from tests.test_utils import TEST_DEVICES
2726

2827

@@ -48,12 +47,12 @@ def setUp(self):
4847
@parameterized.expand(TEST_DEVICES)
4948
def test_simple_processing(self, device):
5049
"""
51-
Tests postprocessing operations perform correctly, in particular that `Invertd` does inversion correctly.
52-
53-
This will apply the preprocessing sequence which resizes the result, then the postprocess sequence which
54-
returns it to the original shape using Invertd. This tests that the shape of the output is the same as the
55-
original image. This will also test that Invertd doesn't get confused if transforms in the postprocessing
56-
sequence are tracing and so adding information to `applied_operations`, this is what `Lambdad` is doing in
50+
Tests postprocessing operations perform correctly, in particular that `Invertd` does inversion correctly.
51+
52+
This will apply the preprocessing sequence which resizes the result, then the postprocess sequence which
53+
returns it to the original shape using Invertd. This tests that the shape of the output is the same as the
54+
original image. This will also test that Invertd doesn't get confused if transforms in the postprocessing
55+
sequence are tracing and so adding information to `applied_operations`, this is what `Lambdad` is doing in
5756
`self.postprocessing`.
5857
"""
5958

@@ -77,8 +76,8 @@ def test_simple_processing(self, device):
7776
@parameterized.expand(product(sum(TEST_DEVICES, []), [True, False]))
7877
def test_workflow(self, device, use_threads):
7978
"""
80-
This tests the interaction between pre and postprocesing transform sequences being executed in parallel.
81-
79+
This tests the interaction between pre and postprocesing transform sequences being executed in parallel.
80+
8281
When the `ThreadDataLoader` is used to load batches, this is done in parallel at times with the execution of
8382
the post-process transform sequence. Previously this encountered a race condition at times because the
8483
`TraceableTransform.tracing` variables of transforms was being toggled in different threads, so at times a

0 commit comments

Comments
 (0)