Add text to vision embedding #6282

Merged on Apr 13, 2023
Commits (28)
7f255ad  Add text to vision embedding (Apr 4, 2023)
2dd5566  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 4, 2023)
1481992  Merge branch 'Project-MONAI:dev' into textembedding (tangy5, Apr 4, 2023)
81c2965  update parameters (Apr 4, 2023)
f84ac5d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 4, 2023)
d737121  update encoding (Apr 5, 2023)
ae7c2fe  change file mode (Apr 5, 2023)
bd4dc37  fix flake8 format (Apr 5, 2023)
3fd70e3  fix flake8 format2 (Apr 5, 2023)
a65d6c1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 5, 2023)
1ffd1b2  Merge branch 'Project-MONAI:dev' into textembedding (tangy5, Apr 6, 2023)
79a8f85  update var name (Apr 6, 2023)
2cc0ce4  update var name (Apr 6, 2023)
88f392f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 6, 2023)
b5604df  Merge branch 'dev' into textembedding (tangy5, Apr 11, 2023)
fefa8d1  update 2d case, pretrain option, release CLIP weights (Apr 12, 2023)
4ba43a9  loadable options (Apr 12, 2023)
6b3d8fe  remove print (Apr 12, 2023)
af30d79  add skip if downloading fails (Apr 12, 2023)
900729d  update pretrained load logic (Apr 12, 2023)
5e8de2c  Merge branch 'dev' into textembedding (tangy5, Apr 12, 2023)
2be5867  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 12, 2023)
38696a5  Merge branch 'Project-MONAI:dev' into textembedding (tangy5, Apr 13, 2023)
c2a1755  fix cpu only test and others (Apr 13, 2023)
4392a46  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 13, 2023)
447509b  [MONAI] code formatting (monai-bot, Apr 13, 2023)
bd6c4e3  fixes (wyli, Apr 13, 2023)
72bf74a  fixes (wyli, Apr 13, 2023)
90 changes: 90 additions & 0 deletions monai/networks/blocks/text_embedding.py
@@ -0,0 +1,90 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch
from torch import nn
from torch.utils import model_zoo

url_map = {
    "clip_encoding_univeral_model_32": (
        "https://github.com/Project-MONAI/MONAI-extra-test-data/"
        "releases/download/0.8.1/clip_encoding_univeral_model.pth"
    )
}


class TextEncoder(nn.Module):
    """
    Text-to-vision encoding by Contrastive Language-Image Pre-training (CLIP) or random embedding.
    The encoder loads pre-trained or randomly initialized text embeddings and projects them to a
    hidden size compatible with 2D/3D vision models.

    Contrastive Language-Image Pre-training (CLIP), based on: "Radford et al.,
    Learning Transferable Visual Models From Natural Language Supervision <https://arxiv.org/abs/2103.00020>"

    Connecting text and medical 3D images, based on: "Liu et al.,
    CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection <https://arxiv.org/pdf/2301.00785.pdf>"
    """

    def __init__(
        self,
        out_channels: int,
        spatial_dims: int = 3,
        text_dim: int = 512,
        hidden_size: int = 256,
        encoding: str = "clip_encoding_univeral_model_32",
        pretrained: bool = True,
    ) -> None:
        """
        Args:
            out_channels: number of output channels, i.e. the number of classes covered by the text-based embedding.
            spatial_dims: number of spatial dims.
            text_dim: dimension of text embeddings.
            hidden_size: dimension of hidden features, compatible with different vision feature dimensions.
            encoding: the text embedding type, defaults to the CLIP text pretrained weights.
            pretrained: whether to load pretrained weights (e.g. from CLIP) to initialize the text embeddings, defaults to True.
        """
        super().__init__()
        self.encoding = encoding

        self.spatial_dims = spatial_dims
        if spatial_dims not in (2, 3):
            raise ValueError("spatial dimension should be 2 or 3.")

        if self.encoding == "rand_embedding":
            self.text_embedding = nn.Embedding(out_channels, hidden_size)
        else:
            self.register_buffer("text_embedding", torch.randn(out_channels, text_dim))

            if pretrained:
                model_url = url_map[self.encoding]
                pretrain_state_dict = model_zoo.load_url(model_url, map_location="cpu")
                self.text_embedding.data = pretrain_state_dict.float()  # type: ignore
            else:
                print(f"pretrained weights for {self.encoding} are not loaded, please load your own text embeddings.")

        self.text_to_vision = nn.Linear(text_dim, hidden_size)

    def forward(self):
        if self.encoding == "rand_embedding":
            # text embedding as random initialized 'rand_embedding'
            text_embedding = self.text_embedding.weight
        else:
            text_embedding = nn.functional.relu(self.text_to_vision(self.text_embedding))

        if self.spatial_dims == 3:
            text_embedding = text_embedding.unsqueeze(2).unsqueeze(2).unsqueeze(2)
        elif self.spatial_dims == 2:
            text_embedding = text_embedding.unsqueeze(2).unsqueeze(2)

        return text_embedding
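
For reference, a minimal usage sketch of the block above, kept download-free by using the `rand_embedding` mode. The vision feature tensor and the channel-wise dot-product fusion are illustrative assumptions, not part of this PR:

```python
import torch

from monai.networks.blocks.text_embedding import TextEncoder

# "rand_embedding" avoids the CLIP weight download; the defaults give hidden_size=256
encoder = TextEncoder(out_channels=32, spatial_dims=3, encoding="rand_embedding", pretrained=False)
class_emb = encoder()  # shape: (32, 256, 1, 1, 1)

# hypothetical 3D vision feature map: batch 2, 256 channels, 8x8x8 voxels
vision_feat = torch.randn(2, 256, 8, 8, 8)

# one illustrative fusion: per-class similarity maps via a channel-wise dot product
# (2, 1, 256, 8, 8, 8) * (1, 32, 256, 1, 1, 1) -> sum over the channel dim -> (2, 32, 8, 8, 8)
logits = (vision_feat.unsqueeze(1) * class_emb.unsqueeze(0)).sum(dim=2)
print(logits.shape)  # torch.Size([2, 32, 8, 8, 8])
```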
49 changes: 49 additions & 0 deletions tests/test_text_encoding.py
@@ -0,0 +1,49 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

from monai.networks.blocks.text_embedding import TextEncoder
from tests.utils import skip_if_downloading_fails


class TestTextEncoder(unittest.TestCase):
    def test_test_encoding_shape(self):
        with skip_if_downloading_fails():
            # test 2D encoder
            text_encoder = TextEncoder(
                spatial_dims=2, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True
            )
            text_encoding = text_encoder()
            self.assertEqual(text_encoding.shape, (32, 256, 1, 1))

            # test 3D encoder
            text_encoder = TextEncoder(
                spatial_dims=3, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True
            )
            text_encoding = text_encoder()
            self.assertEqual(text_encoding.shape, (32, 256, 1, 1, 1))

            # test random embedding 3D
            text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="rand_embedding", pretrained=True)
            text_encoding = text_encoder()
            self.assertEqual(text_encoding.shape, (32, 256, 1, 1, 1))

            # test random embedding 2D
            text_encoder = TextEncoder(spatial_dims=2, out_channels=32, encoding="rand_embedding", pretrained=True)
            text_encoding = text_encoder()
            self.assertEqual(text_encoding.shape, (32, 256, 1, 1))


if __name__ == "__main__":
    unittest.main()
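
The `pretrained=False` branch of `TextEncoder` prints a note asking callers to load their own text features. A minimal sketch of that workflow, where the custom feature tensor is purely illustrative and would normally come from a CLIP text encoder:

```python
import torch

from monai.networks.blocks.text_embedding import TextEncoder

# skip the download; the text_embedding buffer starts as random values
encoder = TextEncoder(out_channels=32, spatial_dims=2, encoding="clip_encoding_univeral_model_32", pretrained=False)

# hypothetical pre-computed text features, one 512-dim vector per class
my_text_features = torch.randn(32, 512)
encoder.text_embedding.data = my_text_features.float()

out = encoder()
print(out.shape)  # torch.Size([32, 256, 1, 1])
```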