Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions Wrapping/Generators/Python/itk/support/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#
# ==========================================================================*/

import importlib
from importlib.metadata import metadata
import os
import re
import functools
Expand All @@ -24,17 +26,17 @@

_HAVE_XARRAY = False
try:
import xarray as xr
metadata('xarray')

_HAVE_XARRAY = True
except ImportError:
pass
_HAVE_TORCH = False
try:
import torch
metadata('torch')

_HAVE_TORCH = True
except ImportError:
except importlib.metadata.PackageNotFoundError:
pass


Expand Down Expand Up @@ -84,6 +86,10 @@ def accept_array_like_xarray_torch(image_filter):
If a xarray DataArray is passed as an input, output itk.Image's are converted to xarray.DataArray's."""
import numpy as np
import itk
if _HAVE_XARRAY:
import xarray as xr
if _HAVE_TORCH:
import torch

@functools.wraps(image_filter)
def image_filter_wrapper(*args, **kwargs):
Expand Down
16 changes: 9 additions & 7 deletions Wrapping/Generators/Python/itk/support/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#
# ==========================================================================*/

import importlib
from importlib.metadata import metadata
from typing import Union, Optional, Tuple, TYPE_CHECKING
import os

Expand All @@ -26,17 +28,17 @@

_HAVE_XARRAY = False
try:
import xarray as xr
metadata('xarray')

_HAVE_XARRAY = True
except ImportError:
except importlib.metadata.PackageNotFoundError:
pass
_HAVE_TORCH = False
try:
import torch
metadata('torch')

_HAVE_TORCH = True
except ImportError:
except importlib.metadata.PackageNotFoundError:
pass

# noinspection PyPep8Naming
Expand Down Expand Up @@ -218,11 +220,11 @@ def initialize_c_types_once() -> (
ImageOrImageSource = Union[ImageBase, ImageSource]
# Can be coerced into an itk.ImageBase
if _HAVE_XARRAY and _HAVE_TORCH:
ImageLike = Union[ImageBase, ArrayLike, xr.DataArray, torch.Tensor]
ImageLike = Union[ImageBase, ArrayLike, "xr.DataArray", "torch.Tensor"]
elif _HAVE_XARRAY:
ImageLike = Union[ImageBase, ArrayLike, xr.DataArray]
ImageLike = Union[ImageBase, ArrayLike, "xr.DataArray"]
elif _HAVE_TORCH:
ImageLike = Union[ImageBase, ArrayLike, torch.Tensor]
ImageLike = Union[ImageBase, ArrayLike, "torch.Tensor"]
else:
ImageLike = Union[ImageBase, ArrayLike]

Expand Down