
Commit 4ba43a9

Author: tangy5

loadable options

Signed-off-by: tangy5 <yucheng.tang@vanderbilt.edu>
1 parent fefa8d1 · commit 4ba43a9

File tree

1 file changed: +15, -14 lines


monai/networks/blocks/text_embedding.py

Lines changed: 15 additions & 14 deletions
```diff
@@ -16,7 +16,7 @@
 from torch.utils import model_zoo
 
 url_map = {
-    "clip_encoding_univeral_model_31": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/clip_encoding_univeral_model.pth",
+    "clip_encoding_univeral_model_32": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/clip_encoding_univeral_model.pth",
 }
 
 
```
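With the key renamed to match the string that callers pass as `encoding`, the constructor can resolve the checkpoint URL by plain dict lookup (see the third hunk). Registering another loadable option then only needs a new entry; a minimal sketch with a hypothetical key and URL:

```python
from monai.networks.blocks.text_embedding import url_map

# Hypothetical entry, invented for illustration: the key must equal the
# string passed as `encoding`, and the URL must point at a .pth tensor of
# shape (out_channels, text_dim).
url_map["my_word_embedding"] = "https://example.com/my_word_embedding.pth"
```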
```diff
@@ -37,7 +37,7 @@ def __init__(
         spatial_dims: int = 3,
         text_dim: int = 512,
         hidden_size: int = 256,
-        encoding: str = "clip_embedding",
+        encoding: str = "rand_embedding",
         pretrained: bool = False
     ) -> None:
         """
```
```diff
@@ -58,23 +58,24 @@ def __init__(
 
         if self.encoding == 'rand_embedding':
             self.text_embedding = nn.Embedding(out_channels, hidden_size)
-        elif self.encoding == 'clip_embedding':
-            self.register_buffer('text_embedding', torch.randn(out_channels, text_dim))
-            if pretrained:
-                model_url = url_map["clip_encoding_univeral_model_31"]
-                pretrain_state_dict = model_zoo.load_url(model_url)
-                self.text_embedding.data = pretrain_state_dict.float()
-            print('load word embedding: {}'.format(self.encoding))
-            self.text_to_vision = nn.Linear(text_dim, hidden_size)
         else:
-            raise Exception(f'{self.encoding} is not implemented, please add your own')
+            if self.encoding in url_map:
+                self.register_buffer('text_embedding', torch.randn(out_channels, text_dim))
+                if pretrained:
+                    model_url = url_map[self.encoding]
+                    pretrain_state_dict = model_zoo.load_url(model_url)
+                    self.text_embedding.data = pretrain_state_dict.float()
+                print('load text embedding: {}'.format(self.encoding))
+                self.text_to_vision = nn.Linear(text_dim, hidden_size)
+            else:
+                raise Exception(f'{self.encoding} is not implemented, please add your own')
 
     def forward(self):
-        if self.encoding == 'clip_embedding':
-            test_encoding = nn.functional.relu(self.text_to_vision(self.text_embedding))
-        else:
+        if self.encoding == 'rand_embedding':
             # text embedding as random initialized 'rand_embedding'
             test_encoding = self.text_embedding.weight
+        else:
+            test_encoding = nn.functional.relu(self.text_to_vision(self.text_embedding))
 
         if self.spatial_dims == 3:
             test_encoding = test_encoding.unsqueeze(2).unsqueeze(2).unsqueeze(2)
```
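After the refactor, every key in `url_map` is a loadable option: passing it as `encoding` registers a random `(out_channels, text_dim)` buffer, optionally overwrites it with the downloaded checkpoint, and `forward()` projects it through `text_to_vision`. A hedged usage sketch, under the same `TextEncoder` assumption as above:

```python
from monai.networks.blocks.text_embedding import TextEncoder  # assumed class name

encoder = TextEncoder(
    out_channels=32,  # assumed to match the checkpoint's first dimension
    spatial_dims=3,
    text_dim=512,
    hidden_size=256,
    encoding="clip_encoding_univeral_model_32",  # any url_map key now works
    pretrained=True,  # fetches the weights via model_zoo.load_url
)

# forward() takes no inputs; assuming it returns test_encoding, the result
# is (out_channels, hidden_size) unsqueezed to (32, 256, 1, 1, 1) for 3D.
emb = encoder()
```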
