
Commit 6243031

Authored by myron, pre-commit-ci[bot], and KumoLiu
Adding a network CellSamWrapper (#7981)
Adding a network CellSamWrapper, a thin wrapper around SAM, which can be used for 2D segmentation tasks.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: am <am>
Signed-off-by: myron <amyronenko@nvidia.com>
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: am <am>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
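A minimal usage sketch of the new wrapper (not part of this commit), assuming `segment-anything` is installed and the SAM ViT-B checkpoint `sam_vit_b_01ec64.pth` has been downloaded locally; the import path follows the test added in this commit:

```python
import torch

from monai.networks.nets.cell_sam_wrapper import CellSamWrapper

# Inputs are resized to SAM's 1024x1024 internally and the output is resized
# back to the original spatial size, so arbitrary 2D input sizes can be used.
net = CellSamWrapper(
    auto_resize_inputs=True,
    network_resize_roi=(1024, 1024),
    checkpoint="sam_vit_b_01ec64.pth",  # pretrained SAM ViT-B weights
)
net.eval()

x = torch.randn(1, 3, 256, 256)  # batch of one 3-channel 2D image
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 3, 256, 256])
```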
1 parent f848002 commit 6243031

File tree

- docs/source/installation.md
- monai/networks/nets/cell_sam_wrapper.py
- requirements-dev.txt
- setup.cfg
- tests/test_cell_sam_wrapper.py

5 files changed: +157 −3 lines changed

docs/source/installation.md

Lines changed: 2 additions & 2 deletions
@@ -254,10 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
 - The options are

 ```
-[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub]
+[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub, segment-anything]
 ```

 which correspond to `nibabel`, `scikit-image`,`scipy`, `pillow`, `tensorboard`,
-`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub` and `pyamg` respectively.
+`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub`, `pyamg` and `segment-anything` respectively.

 - `pip install 'monai[all]'` installs all the optional dependencies.
monai/networks/nets/cell_sam_wrapper.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch
from torch import nn
from torch.nn import functional as F

from monai.utils import optional_import

build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b")

__all__ = ["CellSamWrapper"]


class CellSamWrapper(torch.nn.Module):
    """
    CellSamWrapper is a thin wrapper around the SAM model https://github.com/facebookresearch/segment-anything
    with an image-only decoder, which can be used for segmentation tasks.

    Args:
        auto_resize_inputs: whether to resize inputs before passing them to the network.
            (usually they need to be resized, unless they are already at the expected size)
        network_resize_roi: expected input size for the network.
            (currently SAM expects 1024x1024)
        checkpoint: checkpoint file to load the SAM weights from.
            (this can be downloaded from the SAM repo https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth)
        return_features: whether to return features from the SAM encoder
            (without using the decoder/upsampling to the original input size)
    """

    def __init__(
        self,
        auto_resize_inputs=True,
        network_resize_roi=(1024, 1024),
        checkpoint="sam_vit_b_01ec64.pth",
        return_features=False,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)

        self.network_resize_roi = network_resize_roi
        self.auto_resize_inputs = auto_resize_inputs
        self.return_features = return_features

        if not has_sam:
            raise ValueError(
                "SAM is not installed, please run: pip install git+https://github.com/facebookresearch/segment-anything.git"
            )

        model = build_sam_vit_b(checkpoint=checkpoint)

        # drop the prompt encoder and replace the mask decoder with a small image-only upsampling head
        model.prompt_encoder = None
        model.mask_decoder = None

        model.mask_decoder = nn.Sequential(
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True),
        )

        self.model = model

    def forward(self, x):
        sh = x.shape[2:]

        if self.auto_resize_inputs:
            x = F.interpolate(x, size=self.network_resize_roi, mode="bilinear")

        x = self.model.image_encoder(x)

        if not self.return_features:
            x = self.model.mask_decoder(x)
            if self.auto_resize_inputs:
                x = F.interpolate(x, size=sh, mode="bilinear")

        return x
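For reference, a rough sketch (not part of the diff) of how the two output modes behave; `checkpoint=None` builds the SAM ViT-B architecture without loading pretrained weights (as in the tests below), and the 256x64x64 feature size assumes the standard ViT-B image encoder at the 1024x1024 input resolution:

```python
import torch

from monai.networks.nets.cell_sam_wrapper import CellSamWrapper

x = torch.randn(1, 3, 512, 512)

# Default mode: encoder + image-only decoder; the output is resized
# back to the input's spatial size.
net = CellSamWrapper(checkpoint=None)
net.eval()
with torch.no_grad():
    seg = net(x)
print(seg.shape)  # torch.Size([1, 3, 512, 512])

# Feature mode: return the raw SAM encoder features, skipping the decoder
# and the final resize.
feat_net = CellSamWrapper(checkpoint=None, return_features=True)
feat_net.eval()
with torch.no_grad():
    feats = feat_net(x)
print(feats.shape)  # torch.Size([1, 256, 64, 64])
```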

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
@@ -59,3 +59,4 @@ nvidia-ml-py
 huggingface_hub
 pyamg>=5.0.0
 git+https://github.com/Project-MONAI/GenerativeModels.git@7428fce193771e9564f29b91d29e523dd1b6b4cd
+git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588

setup.cfg

Lines changed: 4 additions & 1 deletion
@@ -85,6 +85,7 @@ all =
     nvidia-ml-py
     huggingface_hub
     pyamg>=5.0.0
+    segment-anything
 nibabel =
     nibabel
 ninja =
@@ -162,11 +163,13 @@ pynvml =
     nvidia-ml-py
 # # workaround https://github.com/Project-MONAI/MONAI/issues/5882
 # MetricsReloaded =
-# MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
+# MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
 huggingface_hub =
     huggingface_hub
 pyamg =
     pyamg>=5.0.0
+segment-anything =
+    segment-anything @ git+https://github.com/facebookresearch/segment-anything@6fdee8f2727f4506cfbbe553e23b895e27956588#egg=segment-anything

 [flake8]
 select = B,C,E,F,N,P,T4,W,B9

tests/test_cell_sam_wrapper.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets.cell_sam_wrapper import CellSamWrapper
from monai.utils import optional_import

build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b")

device = "cuda" if torch.cuda.is_available() else "cpu"
TEST_CASE_CELLSEGWRAPPER = []
for dims in [128, 256, 512, 1024]:
    test_case = [
        {"auto_resize_inputs": True, "network_resize_roi": [1024, 1024], "checkpoint": None},
        (1, 3, *([dims] * 2)),
        (1, 3, *([dims] * 2)),
    ]
    TEST_CASE_CELLSEGWRAPPER.append(test_case)


@unittest.skipUnless(has_sam, "Requires SAM installation")
class TestCellSamWrapper(unittest.TestCase):

    @parameterized.expand(TEST_CASE_CELLSEGWRAPPER)
    def test_shape(self, input_param, input_shape, expected_shape):
        net = CellSamWrapper(**input_param).to(device)
        with eval_mode(net):
            result = net(torch.randn(input_shape).to(device))
            self.assertEqual(result.shape, expected_shape, msg=str(input_param))

    def test_ill_arg0(self):
        # without auto-resizing, a 256x256 input does not match SAM's expected 1024x1024 size
        with self.assertRaises(RuntimeError):
            net = CellSamWrapper(auto_resize_inputs=False, checkpoint=None).to(device)
            net(torch.randn([1, 3, 256, 256]).to(device))

    def test_ill_arg1(self):
        # resizing to 256x256 does not match the encoder's expected 1024x1024 input, so the forward pass fails
        with self.assertRaises(RuntimeError):
            net = CellSamWrapper(network_resize_roi=[256, 256], checkpoint=None).to(device)
            net(torch.randn([1, 3, 1024, 1024]).to(device))


if __name__ == "__main__":
    unittest.main()
