16
16
from torch .utils import model_zoo
17
17
18
18
# Maps a text-encoding identifier (the value accepted by the model's
# `encoding` argument) to the download URL of its pretrained embedding
# weights. Looked up in __init__ via `url_map[self.encoding]` when
# `pretrained=True`, so the key must exactly match the encoding string
# (no stray whitespace). NOTE: "univeral" is intentionally kept — the
# misspelling is part of the published URL/asset name.
url_map = {
    "clip_encoding_univeral_model_32": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/clip_encoding_univeral_model.pth",
}
21
21
22
22
@@ -37,7 +37,7 @@ def __init__(
37
37
spatial_dims : int = 3 ,
38
38
text_dim : int = 512 ,
39
39
hidden_size : int = 256 ,
40
- encoding : str = "clip_embedding " ,
40
+ encoding : str = "rand_embedding " ,
41
41
pretrained : bool = False
42
42
) -> None :
43
43
"""
@@ -58,23 +58,24 @@ def __init__(
58
58
59
59
if self .encoding == 'rand_embedding' :
60
60
self .text_embedding = nn .Embedding (out_channels , hidden_size )
61
- elif self .encoding == 'clip_embedding' :
62
- self .register_buffer ('text_embedding' , torch .randn (out_channels , text_dim ))
63
- if pretrained :
64
- model_url = url_map ["clip_encoding_univeral_model_31" ]
65
- pretrain_state_dict = model_zoo .load_url (model_url )
66
- self .text_embedding .data = pretrain_state_dict .float ()
67
- print ('load word embedding: {}' .format (self .encoding ))
68
- self .text_to_vision = nn .Linear (text_dim , hidden_size )
69
61
else :
70
- raise Exception (f'{ self .encoding } is not implemented, please add your own' )
62
+ if self .encoding in url_map :
63
+ self .register_buffer ('text_embedding' , torch .randn (out_channels , text_dim ))
64
+ if pretrained :
65
+ model_url = url_map [self .encoding ]
66
+ pretrain_state_dict = model_zoo .load_url (model_url )
67
+ self .text_embedding .data = pretrain_state_dict .float ()
68
+ print ('load text embedding: {}' .format (self .encoding ))
69
+ self .text_to_vision = nn .Linear (text_dim , hidden_size )
70
+ else :
71
+ raise Exception (f'{ self .encoding } is not implemented, please add your own' )
71
72
72
73
def forward (self ):
73
- if self .encoding == 'clip_embedding' :
74
- test_encoding = nn .functional .relu (self .text_to_vision (self .text_embedding ))
75
- else :
74
+ if self .encoding == 'rand_embedding' :
76
75
# text embedding as random initialized 'rand_embedding'
77
76
test_encoding = self .text_embedding .weight
77
+ else :
78
+ test_encoding = nn .functional .relu (self .text_to_vision (self .text_embedding ))
78
79
79
80
if self .spatial_dims == 3 :
80
81
test_encoding = test_encoding .unsqueeze (2 ).unsqueeze (2 ).unsqueeze (2 )
0 commit comments