|
15 | 15 |
|
16 | 16 | import torch |
17 | 17 | from monai.networks.blocks.text_embedding import TextEncoder |
| 18 | +from tests.utils import skip_if_downloading_fails |
18 | 19 |
|
19 | 20 | device = "cuda" if torch.cuda.is_available() else "cpu" |
20 | 21 |
|
21 | 22 |
|
22 | 23 | class TestTextEncoder(unittest.TestCase): |
23 | 24 | def test_test_encoding_shape(self): |
24 | | - # test 2D encoder |
25 | | - text_encoder = TextEncoder(spatial_dims=2, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True).to(device) |
26 | | - text_encoding = text_encoder() |
27 | | - print(text_encoding.shape) |
28 | | - self.assertEqual(text_encoding.shape, (32,256,1,1)) |
29 | | - |
30 | | - # test 3D encoder |
31 | | - text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True).to(device) |
| 25 | + with skip_if_downloading_fails(): |
| 26 | + # test 2D encoder |
| 27 | + text_encoder = TextEncoder(spatial_dims=2, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True).to(device) |
| 28 | + text_encoding = text_encoder() |
| 29 | + self.assertEqual(text_encoding.shape, (32,256,1,1)) |
| 30 | + |
| 31 | + # test 3D encoder |
| 32 | + text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True).to(device) |
| 33 | + text_encoding = text_encoder() |
| 34 | + self.assertEqual(text_encoding.shape, (32,256,1,1,1)) |
| 35 | + |
| 36 | + # test random enbedding 3D |
| 37 | + text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="rand_embedding", pretrained=True).to(device) |
32 | 38 | text_encoding = text_encoder() |
33 | | - print(text_encoding.shape) |
34 | 39 | self.assertEqual(text_encoding.shape, (32,256,1,1,1)) |
35 | 40 |
|
36 | | - # test random enbedding |
37 | | - text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="rand_embedding", pretrained=True).to(device) |
| 41 | + # test random enbedding 2D |
| 42 | + text_encoder = TextEncoder(spatial_dims=2, out_channels=32, encoding="rand_embedding", pretrained=True).to(device) |
38 | 43 | text_encoding = text_encoder() |
39 | | - print(text_encoding.shape) |
40 | | - self.assertEqual(text_encoding.shape, (32,256,1,1,1)) |
| 44 | + self.assertEqual(text_encoding.shape, (32,256,1,1)) |
41 | 45 |
|
42 | 46 | if __name__ == "__main__": |
43 | 47 | unittest.main() |
0 commit comments