Skip to content

Commit 4c2e341

Browse files
committed
fixes #5509
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent 696e411 commit 4c2e341

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

monai/data/meta_tensor.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -523,10 +523,11 @@ def ensure_torch_and_prune_meta(
523523
By default, a `MetaTensor` is returned.
524524
However, if `get_track_meta()` is `False`, a `torch.Tensor` is returned.
525525
"""
526-
img = convert_to_tensor(im) # potentially ascontiguousarray
526+
tracking_meta = get_track_meta() and meta is not None
527+
img = convert_to_tensor(im, track_meta=tracking_meta) # potentially ascontiguousarray
527528

528529
# if not tracking metadata, return `torch.Tensor`
529-
if not get_track_meta() or meta is None:
530+
if not tracking_meta:
530531
return img
531532

532533
# remove any superfluous metadata.
@@ -540,7 +541,7 @@ def ensure_torch_and_prune_meta(
540541
meta = monai.transforms.DeleteItemsd(keys=pattern, sep=sep, use_re=True)(meta)
541542

542543
# return the `MetaTensor`
543-
return MetaTensor(img, meta=meta)
544+
return img.copy_meta_from(meta)
544545

545546
def __repr__(self):
546547
"""

monai/transforms/inverse.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ def track_transform_tensor(
170170

171171
if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)):
172172
if isinstance(data, Mapping):
173+
if not isinstance(data, dict):
174+
data = dict(data)
173175
data[key] = data_t.copy_meta_from(out_obj) if isinstance(data_t, MetaTensor) else data_t
174176
return data
175177
return out_obj # return with data_t as tensor if get_track_meta() is False
@@ -202,15 +204,14 @@ def track_transform_tensor(
202204
else:
203205
out_obj.push_applied_operation(info)
204206
if isinstance(data, Mapping):
207+
if not isinstance(data, dict):
208+
data = dict(data)
205209
if isinstance(data_t, MetaTensor):
206210
data[key] = data_t.copy_meta_from(out_obj)
207211
else:
208-
# If this is the first, create list
209212
x_k = TraceableTransform.trace_key(key)
210213
if x_k not in data:
211-
if not isinstance(data, dict):
212-
data = dict(data)
213-
data[x_k] = []
214+
data[x_k] = [] # If this is the first, create list
214215
data[x_k].append(info)
215216
return data
216217
return out_obj

0 commit comments

Comments
 (0)