|
| 1 | +# Copyright (c) MONAI Consortium |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | +# Unless required by applicable law or agreed to in writing, software |
| 7 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | +# See the License for the specific language governing permissions and |
| 10 | +# limitations under the License. |
| 11 | + |
| 12 | +from __future__ import annotations |
| 13 | + |
| 14 | +import torch |
| 15 | +from torch import nn |
| 16 | +from torch.utils import model_zoo |
| 17 | + |
# Download locations of pre-computed text embeddings, keyed by the ``encoding``
# name that ``TextEncoder`` looks up when ``pretrained=True``.
# NOTE: the "univeral" spelling appears in both the key and the asset file name;
# keep it as is unless the hosted asset itself is renamed.
url_map = {
    "clip_encoding_univeral_model_32": "https://github.com/Project-MONAI/MONAI-extra-test-data/"
    "releases/download/0.8.1/clip_encoding_univeral_model.pth",
}
| 24 | + |
| 25 | + |
class TextEncoder(nn.Module):
    """
    Text to vision encoding by Contrastive Language-Image Pre-training (CLIP) or random embedding.
    The text to vision encoder loads the pre-trained or random initialized weights with connection to 2D/3D vision models.

    Contrastive Language-Image Pre-training (CLIP), based on: "Radford et al.,
    Learning Transferable Visual Models From Natural Language Supervision <https://arxiv.org/abs/2103.00020>"

    Connecting text and medical 3D image, based on: "Liu et al.,
    CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection <https://arxiv.org/pdf/2301.00785.pdf>"
    """

    def __init__(
        self,
        out_channels: int,
        spatial_dims: int = 3,
        text_dim: int = 512,
        hidden_size: int = 256,
        encoding: str = "clip_encoding_univeral_model_32",
        pretrained: bool = True,
    ) -> None:
        """
        Args:
            out_channels: number of output channels, to control text-based embedding for classes.
            spatial_dims: number of spatial dims.
            text_dim: dimension of text embeddings.
            hidden_size: dimension of hidden features, compatible to different vision feature dimensions.
            encoding: the text embedding type, default to use clip text pretrained weights.
                ``"rand_embedding"`` selects a learnable ``nn.Embedding`` instead of a fixed buffer.
            pretrained: whether to load pretrained weights from e.g., (CLIP) to initialize text embeddings,
                default to True. Ignored when ``encoding`` is ``"rand_embedding"``.

        Raises:
            ValueError: when ``spatial_dims`` is not 2 or 3.
            KeyError: when ``pretrained`` is True and ``encoding`` has no entry in ``url_map``.
        """
        super().__init__()
        self.encoding = encoding

        self.spatial_dims = spatial_dims
        if spatial_dims not in (2, 3):
            raise ValueError("spatial dimension should be 2 or 3.")

        if self.encoding == "rand_embedding":
            # learnable embedding, already of size `hidden_size` (no projection needed in forward)
            self.text_embedding = nn.Embedding(out_channels, hidden_size)
        else:
            # fixed (non-trainable) text embedding; random values act as a placeholder
            # until pretrained or user-provided weights are loaded
            self.register_buffer("text_embedding", torch.randn(out_channels, text_dim))

            if pretrained:
                if self.encoding not in url_map:
                    # same exception type as the plain dict lookup, but with an actionable message
                    raise KeyError(
                        f"no pretrained weights registered for encoding '{self.encoding}', "
                        f"available: {sorted(url_map)}; use pretrained=False and load your own."
                    )
                model_url = url_map[self.encoding]
                pretrain_state_dict = model_zoo.load_url(model_url, map_location="cpu")
                self.text_embedding.data = pretrain_state_dict.float()  # type: ignore
            else:
                print(f"{self.encoding} is not implemented, and can not be downloaded, please load your own")

        # projects the text embedding (text_dim) to the vision feature size (hidden_size)
        self.text_to_vision = nn.Linear(text_dim, hidden_size)

    def forward(self) -> torch.Tensor:
        """
        Compute the per-class text embedding to be combined with vision features.

        Returns:
            the embedding with one trailing singleton axis appended per spatial dimension,
            i.e. two leading axes followed by ``spatial_dims`` axes of size 1.
        """
        if self.encoding == "rand_embedding":
            # text embedding as random initialized 'rand_embedding'
            text_embedding = self.text_embedding.weight
        else:
            text_embedding = nn.functional.relu(self.text_to_vision(self.text_embedding))

        # append singleton spatial axes so the result can broadcast against feature maps
        if self.spatial_dims == 3:
            text_embedding = text_embedding.unsqueeze(2).unsqueeze(2).unsqueeze(2)
        elif self.spatial_dims == 2:
            text_embedding = text_embedding.unsqueeze(2).unsqueeze(2)

        return text_embedding
0 commit comments