
Commit 57c618c

Add text to vision embedding (#6282)
Adds a text-to-vision encoder for medical image analysis. Supports CLIP pre-trained text embeddings and randomly initialized text embeddings. Linked to issue #6177. --------- Signed-off-by: tangy5 <yucheng.tang@vanderbilt.edu>
1 parent d8d887f commit 57c618c
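
For orientation, a minimal usage sketch of the new block (not part of this commit; the parameter values are illustrative) showing the two supported embedding modes and the shape of the tensor they return:

from monai.networks.blocks.text_embedding import TextEncoder

# CLIP pre-trained text embeddings: downloads the weights listed in url_map
clip_encoder = TextEncoder(
    out_channels=32, spatial_dims=3, encoding="clip_encoding_univeral_model_32", pretrained=True
)
print(clip_encoder().shape)  # torch.Size([32, 256, 1, 1, 1])

# randomly initialized, learnable text embeddings: no download needed
rand_encoder = TextEncoder(out_channels=32, spatial_dims=2, encoding="rand_embedding")
print(rand_encoder().shape)  # torch.Size([32, 256, 1, 1])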

File tree

2 files changed: +139 −0 lines

monai/networks/blocks/text_embedding.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch
from torch import nn
from torch.utils import model_zoo

url_map = {
    "clip_encoding_univeral_model_32": (
        "https://github.com/Project-MONAI/MONAI-extra-test-data/"
        "releases/download/0.8.1/clip_encoding_univeral_model.pth"
    )
}


class TextEncoder(nn.Module):
    """
    Text to vision encoding by Contrastive Language-Image Pre-training (CLIP) or random embedding.
    The text to vision encoder loads pre-trained or randomly initialized weights and connects to 2D/3D vision models.

    Contrastive Language-Image Pre-training (CLIP), based on: "Radford et al.,
    Learning Transferable Visual Models From Natural Language Supervision <https://arxiv.org/abs/2103.00020>"

    Connecting text and medical 3D images, based on: "Liu et al.,
    CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection <https://arxiv.org/pdf/2301.00785.pdf>"
    """

    def __init__(
        self,
        out_channels: int,
        spatial_dims: int = 3,
        text_dim: int = 512,
        hidden_size: int = 256,
        encoding: str = "clip_encoding_univeral_model_32",
        pretrained: bool = True,
    ) -> None:
        """
        Args:
            out_channels: number of output channels, controls the number of text-based class embeddings.
            spatial_dims: number of spatial dims.
            text_dim: dimension of text embeddings.
            hidden_size: dimension of hidden features, compatible with different vision feature dimensions.
            encoding: the text embedding type, defaults to the CLIP pre-trained text weights.
            pretrained: whether to load pre-trained weights (e.g. CLIP) to initialize text embeddings, defaults to True.
        """
        super().__init__()
        self.encoding = encoding

        self.spatial_dims = spatial_dims
        if spatial_dims not in (2, 3):
            raise ValueError("spatial dimension should be 2 or 3.")

        if self.encoding == "rand_embedding":
            # learnable embedding table, one hidden_size vector per output channel
            self.text_embedding = nn.Embedding(out_channels, hidden_size)
        else:
            # fixed buffer, overwritten below by the pre-trained CLIP text embeddings when available
            self.register_buffer("text_embedding", torch.randn(out_channels, text_dim))

            if pretrained:
                model_url = url_map[self.encoding]
                pretrain_state_dict = model_zoo.load_url(model_url, map_location="cpu")
                self.text_embedding.data = pretrain_state_dict.float()  # type: ignore
            else:
                print(f"{self.encoding} is not implemented, and can not be downloaded, please load your own")

        self.text_to_vision = nn.Linear(text_dim, hidden_size)

    def forward(self):
        if self.encoding == "rand_embedding":
            # text embedding from the randomly initialized 'rand_embedding' table
            text_embedding = self.text_embedding.weight
        else:
            # project the CLIP text embeddings into the vision feature space
            text_embedding = nn.functional.relu(self.text_to_vision(self.text_embedding))

        # append singleton spatial dims so the embedding broadcasts against 2D/3D feature maps
        if self.spatial_dims == 3:
            text_embedding = text_embedding.unsqueeze(2).unsqueeze(2).unsqueeze(2)
        elif self.spatial_dims == 2:
            text_embedding = text_embedding.unsqueeze(2).unsqueeze(2)

        return text_embedding
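
The returned tensor carries one hidden_size vector per class with trailing singleton spatial dims, so it broadcasts cleanly against 2D/3D vision feature maps. A minimal sketch of one way it could be consumed (illustrative only: the feature-map shape and the dot-product reduction below are assumptions, not something this commit prescribes; the linked CLIP-Driven Universal Model conditions a segmentation head on the text embedding instead):

import torch

from monai.networks.blocks.text_embedding import TextEncoder

encoder = TextEncoder(out_channels=32, spatial_dims=3, encoding="rand_embedding", pretrained=False)
text_embedding = encoder()                       # (32, 256, 1, 1, 1): one 256-d vector per class

vision_features = torch.randn(2, 256, 8, 8, 8)   # hypothetical vision features (batch, hidden_size, D, H, W)

# per-class dot product between text vectors and vision features via broadcasting
class_maps = (vision_features.unsqueeze(1) * text_embedding.unsqueeze(0)).sum(dim=2)
print(class_maps.shape)  # torch.Size([2, 32, 8, 8, 8])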

tests/test_text_encoding.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

from monai.networks.blocks.text_embedding import TextEncoder
from tests.utils import skip_if_downloading_fails


class TestTextEncoder(unittest.TestCase):
    def test_text_encoding_shape(self):
        with skip_if_downloading_fails():
            # test 2D encoder
            text_encoder = TextEncoder(
                spatial_dims=2, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True
            )
            text_encoding = text_encoder()
            self.assertEqual(text_encoding.shape, (32, 256, 1, 1))

            # test 3D encoder
            text_encoder = TextEncoder(
                spatial_dims=3, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True
            )
            text_encoding = text_encoder()
            self.assertEqual(text_encoding.shape, (32, 256, 1, 1, 1))

            # test random embedding 3D
            text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="rand_embedding", pretrained=True)
            text_encoding = text_encoder()
            self.assertEqual(text_encoding.shape, (32, 256, 1, 1, 1))

            # test random embedding 2D
            text_encoder = TextEncoder(spatial_dims=2, out_channels=32, encoding="rand_embedding", pretrained=True)
            text_encoding = text_encoder()
            self.assertEqual(text_encoding.shape, (32, 256, 1, 1))


if __name__ == "__main__":
    unittest.main()
