Skip to content

Commit 9c11697

Browse files
committed
migrate existing pytorch tests to isolated directory
1 parent ac3df64 commit 9c11697

18 files changed

+190
-176
lines changed
File renamed without changes.
File renamed without changes.
File renamed without changes.
Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
1-
from torch2trt import *
1+
import pytest
2+
import torch2trt
23
import torchvision
34
import torch
4-
from .segmentation import deeplabv3_resnet50
55

66

7-
if __name__ == '__main__':
8-
model = deeplabv3_resnet50().cuda().eval().half()
7+
def test_save_load():
8+
model = torch.nn.Conv2d(3, 3, 1).cuda().eval().half()
99
data = torch.randn((1, 3, 224, 224)).cuda().half()
1010

1111
print('Running torch2trt...')
12-
model_trt = torch2trt(model, [data], fp16_mode=True, max_workspace_size=1<<25)
12+
model_trt = torch2trt.torch2trt(model, [data], fp16_mode=True, max_workspace_size=1<<25)
1313

1414
print('Saving model...')
1515
torch.save(model_trt.state_dict(), '.test_model.pth')
1616

1717
print('Loading model...')
18-
model_trt_2 = TRTModule()
18+
model_trt_2 = torch2trt.TRTModule()
1919
model_trt_2.load_state_dict(torch.load('.test_model.pth'))
2020

2121
assert(model_trt_2.engine is not None)

0 commit comments

Comments
 (0)