Skip to content

Commit b097e16

Browse files
KumoLiuericspod
authored andcommitted
Add arm support (Project-MONAI#7500)
Fixes # . ### Description Add arm support ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Yu0610 <612410030@alum.ccu.edu.tw>
1 parent fcff102 commit b097e16

File tree

8 files changed

+41
-14
lines changed

8 files changed

+41
-14
lines changed

Dockerfile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ FROM ${PYTORCH_IMAGE}
1616

1717
LABEL maintainer="monai.contact@gmail.com"
1818

19+
# TODO: remark for issue [revise the dockerfile](https://github.com/zarr-developers/numcodecs/issues/431)
20+
WORKDIR /opt
21+
RUN git clone --recursive https://github.com/zarr-developers/numcodecs.git && pip wheel numcodecs
22+
1923
WORKDIR /opt/monai
2024

2125
# install full deps

requirements-dev.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ mypy>=1.5.0
2626
ninja
2727
torchvision
2828
psutil
29-
cucim>=23.2.0; platform_system == "Linux"
29+
cucim-cu12; platform_system == "Linux" and python_version >= "3.9" and python_version <= "3.10"
3030
openslide-python
3131
imagecodecs; platform_system == "Linux" or platform_system == "Darwin"
3232
tifffile; platform_system == "Linux" or platform_system == "Darwin"
@@ -46,7 +46,7 @@ pynrrd
4646
pre-commit
4747
pydicom
4848
h5py
49-
nni; platform_system == "Linux"
49+
nni; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine
5050
optuna
5151
git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
5252
onnx>=1.13.0

setup.cfg

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ all =
5959
tqdm>=4.47.0
6060
lmdb
6161
psutil
62-
cucim>=23.2.0
62+
cucim-cu12; python_version >= '3.9' and python_version <= '3.10'
6363
openslide-python
6464
tifffile
6565
imagecodecs
@@ -111,7 +111,7 @@ lmdb =
111111
psutil =
112112
psutil
113113
cucim =
114-
cucim>=23.2.0
114+
cucim-cu12
115115
openslide =
116116
openslide-python
117117
tifffile =

tests/test_convert_to_onnx.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __future__ import annotations
1313

1414
import itertools
15+
import platform
1516
import unittest
1617

1718
import torch
@@ -29,6 +30,12 @@
2930
TESTS = list(itertools.product(TORCH_DEVICE_OPTIONS, [True, False], [True, False]))
3031
TESTS_ORT = list(itertools.product(TORCH_DEVICE_OPTIONS, [True]))
3132

33+
ON_AARCH64 = platform.machine() == "aarch64"
34+
if ON_AARCH64:
35+
rtol, atol = 1e-1, 1e-2
36+
else:
37+
rtol, atol = 1e-3, 1e-4
38+
3239
onnx, _ = optional_import("onnx")
3340

3441

@@ -56,8 +63,8 @@ def test_unet(self, device, use_trace, use_ort):
5663
device=device,
5764
use_ort=use_ort,
5865
use_trace=use_trace,
59-
rtol=1e-3,
60-
atol=1e-4,
66+
rtol=rtol,
67+
atol=atol,
6168
)
6269
else:
6370
# https://github.com/pytorch/pytorch/blob/release/1.9/torch/onnx/__init__.py#L182
@@ -72,8 +79,8 @@ def test_unet(self, device, use_trace, use_ort):
7279
device=device,
7380
use_ort=use_ort,
7481
use_trace=use_trace,
75-
rtol=1e-3,
76-
atol=1e-4,
82+
rtol=rtol,
83+
atol=atol,
7784
)
7885
self.assertTrue(isinstance(onnx_model, onnx.ModelProto))
7986

@@ -107,8 +114,8 @@ def test_seg_res_net(self, device, use_ort):
107114
device=device,
108115
use_ort=use_ort,
109116
use_trace=True,
110-
rtol=1e-3,
111-
atol=1e-4,
117+
rtol=rtol,
118+
atol=atol,
112119
)
113120
self.assertTrue(isinstance(onnx_model, onnx.ModelProto))
114121

tests/test_dynunet.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
import platform
1415
import unittest
1516
from typing import Any, Sequence
1617

@@ -24,6 +25,12 @@
2425

2526
InstanceNorm3dNVFuser, _ = optional_import("apex.normalization", name="InstanceNorm3dNVFuser")
2627

28+
ON_AARCH64 = platform.machine() == "aarch64"
29+
if ON_AARCH64:
30+
rtol, atol = 1e-2, 1e-2
31+
else:
32+
rtol, atol = 1e-4, 1e-4
33+
2734
device = "cuda" if torch.cuda.is_available() else "cpu"
2835

2936
strides: Sequence[Sequence[int] | int]
@@ -159,7 +166,7 @@ def test_consistency(self, input_param, input_shape, _):
159166
with eval_mode(net_fuser):
160167
result_fuser = net_fuser(input_tensor)
161168

162-
assert_allclose(result, result_fuser, rtol=1e-4, atol=1e-4)
169+
assert_allclose(result, result_fuser, rtol=rtol, atol=atol)
163170

164171

165172
class TestDynUNetDeepSupervision(unittest.TestCase):

tests/test_rand_affine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def test_rand_affine(self, input_param, input_data, expected_val):
147147
g.set_random_state(123)
148148
result = g(**input_data)
149149
g.rand_affine_grid.affine = torch.eye(4, dtype=torch.float64) # reset affine
150-
test_resampler_lazy(g, result, input_param, input_data, seed=123)
150+
test_resampler_lazy(g, result, input_param, input_data, seed=123, rtol=_rtol)
151151
if input_param.get("cache_grid", False):
152152
self.assertTrue(g._cached_grid is not None)
153153
assert_allclose(result, expected_val, rtol=_rtol, atol=1e-4, type_test="tensor")

tests/test_rand_affined.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,9 @@ def test_rand_affined(self, input_param, input_data, expected_val, track_meta):
234234
lazy_init_param["keys"], lazy_init_param["mode"] = key, mode
235235
resampler = RandAffined(**lazy_init_param).set_random_state(123)
236236
expected_output = resampler(**call_param)
237-
test_resampler_lazy(resampler, expected_output, lazy_init_param, call_param, seed=123, output_key=key)
237+
test_resampler_lazy(
238+
resampler, expected_output, lazy_init_param, call_param, seed=123, output_key=key, rtol=_rtol
239+
)
238240
resampler.lazy = False
239241

240242
if input_param.get("cache_grid", False):

tests/test_spatial_resampled.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
import platform
1415
import unittest
1516

1617
import numpy as np
@@ -23,6 +24,12 @@
2324
from tests.lazy_transforms_utils import test_resampler_lazy
2425
from tests.utils import TEST_DEVICES, assert_allclose
2526

27+
ON_AARCH64 = platform.machine() == "aarch64"
28+
if ON_AARCH64:
29+
rtol, atol = 1e-1, 1e-2
30+
else:
31+
rtol, atol = 1e-3, 1e-4
32+
2633
TESTS = []
2734

2835
destinations_3d = [
@@ -104,7 +111,7 @@ def test_flips_inverse(self, img, device, dst_affine, kwargs, expected_output):
104111

105112
# check lazy
106113
lazy_xform = SpatialResampled(**init_param)
107-
test_resampler_lazy(lazy_xform, output_data, init_param, call_param, output_key="img")
114+
test_resampler_lazy(lazy_xform, output_data, init_param, call_param, output_key="img", rtol=rtol, atol=atol)
108115

109116
# check inverse
110117
inverted = xform.inverse(output_data)["img"]

0 commit comments

Comments
 (0)