Skip to content

Commit af30d79

Browse files
author
tangy5
committed
add skip if downloading fails
Signed-off-by: tangy5 <yucheng.tang@vanderbilt.edu>
1 parent 6b3d8fe commit af30d79

File tree

1 file changed

+17
-13
lines changed

1 file changed

+17
-13
lines changed

tests/test_text_encoding.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,29 +15,33 @@
1515

1616
import torch
1717
from monai.networks.blocks.text_embedding import TextEncoder
18+
from tests.utils import skip_if_downloading_fails
1819

1920
device = "cuda" if torch.cuda.is_available() else "cpu"
2021

2122

2223
class TestTextEncoder(unittest.TestCase):
2324
def test_test_encoding_shape(self):
24-
# test 2D encoder
25-
text_encoder = TextEncoder(spatial_dims=2, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True).to(device)
26-
text_encoding = text_encoder()
27-
print(text_encoding.shape)
28-
self.assertEqual(text_encoding.shape, (32,256,1,1))
29-
30-
# test 3D encoder
31-
text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True).to(device)
25+
with skip_if_downloading_fails():
26+
# test 2D encoder
27+
text_encoder = TextEncoder(spatial_dims=2, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True).to(device)
28+
text_encoding = text_encoder()
29+
self.assertEqual(text_encoding.shape, (32,256,1,1))
30+
31+
# test 3D encoder
32+
text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True).to(device)
33+
text_encoding = text_encoder()
34+
self.assertEqual(text_encoding.shape, (32,256,1,1,1))
35+
36+
# test random enbedding 3D
37+
text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="rand_embedding", pretrained=True).to(device)
3238
text_encoding = text_encoder()
33-
print(text_encoding.shape)
3439
self.assertEqual(text_encoding.shape, (32,256,1,1,1))
3540

36-
# test random enbedding
37-
text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="rand_embedding", pretrained=True).to(device)
41+
# test random enbedding 2D
42+
text_encoder = TextEncoder(spatial_dims=2, out_channels=32, encoding="rand_embedding", pretrained=True).to(device)
3843
text_encoding = text_encoder()
39-
print(text_encoding.shape)
40-
self.assertEqual(text_encoding.shape, (32,256,1,1,1))
44+
self.assertEqual(text_encoding.shape, (32,256,1,1))
4145

4246
if __name__ == "__main__":
4347
unittest.main()

0 commit comments

Comments
 (0)