-
Notifications
You must be signed in to change notification settings - Fork 350
/
test_models_export.py
303 lines (256 loc) · 9.81 KB
/
test_models_export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
# type: ignore
import importlib
import importlib.util
import unittest
from importlib import metadata

import pytest
import timm
import torch
import torch_tensorrt as torchtrt
import torchvision.models as models
from packaging.version import Version
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
from transformers import BertModel
from transformers.utils.fx import symbolic_trace as transformers_trace
# A standalone TestCase instance lets the plain pytest-style functions below
# reuse unittest's assertion helpers (assertTrue, ...) with custom messages.
assertions = unittest.TestCase(methodName="runTest")
@pytest.mark.unit
def test_resnet18(ir):
    """Compile an FP32 torchvision ResNet-18 with Torch-TensorRT and check accuracy.

    The compiled module's output must reach a cosine similarity above
    ``COSINE_THRESHOLD`` against the eager PyTorch output for one random input.

    Args:
        ir: IR name forwarded to ``torchtrt.compile`` (presumably supplied by a
            pytest fixture defined outside this file).
    """
    # ``weights=ResNet18_Weights.IMAGENET1K_V1`` is the non-deprecated spelling
    # of the legacy ``pretrained=True`` flag and selects the same checkpoint.
    model = (
        models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        .eval()
        .to("cuda")
    )
    input_tensor = torch.randn((1, 3, 224, 224)).to("cuda")

    compile_spec = {
        "inputs": [
            torchtrt.Input(
                input_tensor.shape, dtype=torch.float, format=torch.contiguous_format
            )
        ],
        "device": torchtrt.Device("cuda:0"),
        "enabled_precisions": {torch.float},
        "ir": ir,
        "pass_through_build_failures": True,
        "optimization_level": 1,
        "min_block_size": 8,
        "cache_built_engines": False,
        "reuse_cached_engines": False,
    }

    try:
        trt_mod = torchtrt.compile(model, **compile_spec)
        # NOTE(review): ``[0]`` presumably selects the first batch row of the TRT
        # result; this relies on ``cosine_similarity`` comparing flattened
        # values — confirm.
        cos_sim = cosine_similarity(model(input_tensor), trt_mod(input_tensor)[0])
        assertions.assertTrue(
            cos_sim > COSINE_THRESHOLD,
            msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
        )
    finally:
        # Clean up model env even when compilation or the assertion fails, so
        # Dynamo state does not leak into later tests.
        torch._dynamo.reset()
@pytest.mark.unit
def test_mobilenet_v2(ir):
    """Compile an FP32 torchvision MobileNetV2 with Torch-TensorRT and check accuracy.

    The compiled module's output must reach a cosine similarity above
    ``COSINE_THRESHOLD`` against the eager PyTorch output for one random input.

    Args:
        ir: IR name forwarded to ``torchtrt.compile`` (presumably supplied by a
            pytest fixture defined outside this file).
    """
    # ``weights=MobileNet_V2_Weights.IMAGENET1K_V1`` is the non-deprecated
    # spelling of the legacy ``pretrained=True`` flag (same checkpoint).
    model = (
        models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
        .eval()
        .to("cuda")
    )
    input_tensor = torch.randn((1, 3, 224, 224)).to("cuda")

    compile_spec = {
        "inputs": [
            torchtrt.Input(
                input_tensor.shape, dtype=torch.float, format=torch.contiguous_format
            )
        ],
        "device": torchtrt.Device("cuda:0"),
        "enabled_precisions": {torch.float},
        "ir": ir,
        "pass_through_build_failures": True,
        "optimization_level": 1,
        "min_block_size": 8,
        "cache_built_engines": False,
        "reuse_cached_engines": False,
    }

    try:
        trt_mod = torchtrt.compile(model, **compile_spec)
        # NOTE(review): ``[0]`` presumably selects the first batch row of the TRT
        # result; relies on ``cosine_similarity`` comparing flattened values.
        cos_sim = cosine_similarity(model(input_tensor), trt_mod(input_tensor)[0])
        assertions.assertTrue(
            cos_sim > COSINE_THRESHOLD,
            msg=f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
        )
    finally:
        # Clean up model env even on failure so Dynamo state cannot leak into
        # later tests.
        torch._dynamo.reset()
@pytest.mark.unit
def test_efficientnet_b0(ir):
    """Compile a pretrained timm EfficientNet-B0 with Torch-TensorRT and check accuracy.

    The compiled module's output must reach a cosine similarity above
    ``COSINE_THRESHOLD`` against the eager PyTorch output for one random input.

    Args:
        ir: IR name forwarded to ``torchtrt.compile`` (presumably supplied by a
            pytest fixture defined outside this file).
    """
    model = timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda")
    input_tensor = torch.randn((1, 3, 224, 224)).to("cuda")

    compile_spec = {
        "inputs": [
            torchtrt.Input(
                input_tensor.shape, dtype=torch.float, format=torch.contiguous_format
            )
        ],
        "device": torchtrt.Device("cuda:0"),
        "enabled_precisions": {torch.float},
        "ir": ir,
        "pass_through_build_failures": True,
        "optimization_level": 1,
        "min_block_size": 8,
        "cache_built_engines": False,
        "reuse_cached_engines": False,
    }

    try:
        trt_mod = torchtrt.compile(model, **compile_spec)
        # NOTE(review): ``[0]`` presumably selects the first batch row of the TRT
        # result; relies on ``cosine_similarity`` comparing flattened values.
        cos_sim = cosine_similarity(model(input_tensor), trt_mod(input_tensor)[0])
        assertions.assertTrue(
            cos_sim > COSINE_THRESHOLD,
            msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
        )
    finally:
        # Clean up model env even on failure so Dynamo state cannot leak into
        # later tests.
        torch._dynamo.reset()
@pytest.mark.unit
def test_bert_base_uncased(ir):
    """Compile Hugging Face ``bert-base-uncased`` and compare every output.

    Checks that the compiled module returns the same number of outputs as the
    eager model and that each output pair has a cosine similarity above
    ``COSINE_THRESHOLD``.

    Args:
        ir: IR name forwarded to ``torchtrt.compile`` (presumably supplied by a
            pytest fixture defined outside this file).
    """
    # return_dict=False makes BertModel return a plain tuple, so outputs can be
    # compared positionally below.
    model = (
        BertModel.from_pretrained("bert-base-uncased", return_dict=False).cuda().eval()
    )
    # NOTE(review): torch.randint's upper bound is exclusive, so both tensors are
    # all zeros. Kept unchanged to preserve the existing test inputs — confirm
    # this is intended.
    input_ids = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
    # Passed as BertModel.forward's second positional argument (attention_mask).
    attention_mask = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")

    compile_spec = {
        "inputs": [
            torchtrt.Input(
                input_ids.shape,
                dtype=input_ids.dtype,
                format=torch.contiguous_format,
            ),
            torchtrt.Input(
                attention_mask.shape,
                dtype=attention_mask.dtype,
                format=torch.contiguous_format,
            ),
        ],
        "device": torchtrt.Device("cuda:0"),
        "enabled_precisions": {torch.float},
        "truncate_double": True,
        "ir": ir,
        "min_block_size": 10,
        "cache_built_engines": False,
        "reuse_cached_engines": False,
    }

    try:
        trt_mod = torchtrt.compile(model, **compile_spec)
        model_outputs = model(input_ids, attention_mask)
        trt_model_outputs = trt_mod(input_ids, attention_mask)

        assertions.assertEqual(
            len(model_outputs),
            len(trt_model_outputs),
            msg=f"Number of outputs for BERT model compilation is different with Pytorch {len(model_outputs)} and TensorRT {len(trt_model_outputs)}. Please check the compilation.",
        )
        # Lengths are equal at this point, so zip pairs every output.
        for out, trt_out in zip(model_outputs, trt_model_outputs):
            cos_sim = cosine_similarity(out, trt_out)
            assertions.assertTrue(
                cos_sim > COSINE_THRESHOLD,
                msg=f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
            )
    finally:
        # Clean up model env even on failure so Dynamo state cannot leak into
        # later tests.
        torch._dynamo.reset()
@pytest.mark.unit
def test_resnet18_half(ir):
    """Compile an FP16 torchvision ResNet-18 with Torch-TensorRT and check accuracy.

    Both the model and the input are cast to half precision. The compiled
    module's output must reach a cosine similarity above ``COSINE_THRESHOLD``
    against the eager PyTorch output.

    Args:
        ir: IR name forwarded to ``torchtrt.compile`` (presumably supplied by a
            pytest fixture defined outside this file).
    """
    # ``weights=ResNet18_Weights.IMAGENET1K_V1`` is the non-deprecated spelling
    # of the legacy ``pretrained=True`` flag and selects the same checkpoint.
    model = (
        models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        .eval()
        .to("cuda")
        .half()
    )
    input_tensor = torch.randn((1, 3, 224, 224)).to("cuda").half()

    compile_spec = {
        "inputs": [
            torchtrt.Input(
                input_tensor.shape, dtype=torch.half, format=torch.contiguous_format
            )
        ],
        "device": torchtrt.Device("cuda:0"),
        "enabled_precisions": {torch.half},
        "ir": ir,
        "pass_through_build_failures": True,
        "optimization_level": 1,
        "min_block_size": 8,
        "cache_built_engines": False,
        "reuse_cached_engines": False,
    }

    try:
        trt_mod = torchtrt.compile(model, **compile_spec)
        # NOTE(review): ``[0]`` presumably selects the first batch row of the TRT
        # result; relies on ``cosine_similarity`` comparing flattened values.
        cos_sim = cosine_similarity(model(input_tensor), trt_mod(input_tensor)[0])
        assertions.assertTrue(
            cos_sim > COSINE_THRESHOLD,
            msg=f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
        )
    finally:
        # Clean up model env even on failure so Dynamo state cannot leak into
        # later tests.
        torch._dynamo.reset()
@unittest.skipIf(
    # Short-circuit on hosts without CUDA: querying device properties there
    # raises while the decorator is evaluated at import time, which would break
    # collection of this whole module instead of just skipping this test.
    not torch.cuda.is_available()
    or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
    "FP8 compilation in Torch-TRT is not supported on cards older than Hopper",
)
@unittest.skipIf(
    not importlib.util.find_spec("modelopt"),
    reason="ModelOpt is necessary to run this test",
)
@pytest.mark.unit
def test_base_fp8(ir):
    """Quantize a tiny MLP to FP8 with ModelOpt, compile it, and compare outputs.

    The model is quantized with ``mtq.FP8_DEFAULT_CFG`` and exported with
    ``torch.export.export``. It is then compiled with
    ``torchtrt.dynamo.compile``, and the TRT output must be ``allclose`` to the
    quantized PyTorch output (rtol=1e-3, atol=1e-2).

    Args:
        ir: Not used by this test; kept for signature parity with the other
            tests in this file.
    """
    import modelopt.torch.quantization as mtq
    from modelopt.torch.quantization.utils import export_torch_mode

    class SimpleNetwork(torch.nn.Module):
        """Linear(10 -> 5) -> ReLU -> Linear(5 -> 1)."""

        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(in_features=10, out_features=5)
            self.linear2 = torch.nn.Linear(in_features=5, out_features=1)

        def forward(self, x):
            x = self.linear1(x)
            # Same op that ``torch.nn.ReLU()(x)`` dispatches to.
            x = torch.nn.functional.relu(x)
            return self.linear2(x)

    # Created before the model so the RNG draw order matches the original test.
    input_tensor = torch.randn(1, 10).cuda()

    def calibrate_loop(model):
        """Simple calibration function for testing."""
        model(input_tensor)

    model = SimpleNetwork().eval().cuda()

    try:
        quant_cfg = mtq.FP8_DEFAULT_CFG
        mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
        # model has FP8 qdq nodes at this point
        output_pyt = model(input_tensor)

        with torch.no_grad():
            with export_torch_mode():
                exp_program = torch.export.export(model, (input_tensor,))
                trt_model = torchtrt.dynamo.compile(
                    exp_program,
                    inputs=[input_tensor],
                    enabled_precisions={torch.float8_e4m3fn},
                    min_block_size=1,
                    debug=True,
                    cache_built_engines=False,
                    reuse_cached_engines=False,
                )
                outputs_trt = trt_model(input_tensor)
                assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2)
    finally:
        # Reset Dynamo like the other tests in this file, even on failure.
        torch._dynamo.reset()
@unittest.skipIf(
    not importlib.util.find_spec("modelopt")
    or Version(metadata.version("nvidia-modelopt")) < Version("0.16.1"),
    "modelopt 0.16.1 or later is required Int8 quantization is supported in modelopt since 0.16.1 or later",
)
@pytest.mark.unit
def test_base_int8(ir):
    """Quantize a tiny MLP to INT8 with ModelOpt, compile it, and compare outputs.

    The model is quantized with ``mtq.INT8_DEFAULT_CFG`` and exported with the
    private ``torch.export._trace._export``. It is then compiled with
    ``torchtrt.dynamo.compile``, and the TRT output must be ``allclose`` to the
    quantized PyTorch output (rtol=1e-3, atol=1e-2).

    Args:
        ir: Not used by this test; kept for signature parity with the other
            tests in this file.
    """
    import modelopt.torch.quantization as mtq
    from modelopt.torch.quantization.utils import export_torch_mode

    class SimpleNetwork(torch.nn.Module):
        """Linear(10 -> 5) -> ReLU -> Linear(5 -> 1)."""

        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(in_features=10, out_features=5)
            self.linear2 = torch.nn.Linear(in_features=5, out_features=1)

        def forward(self, x):
            x = self.linear1(x)
            # Same op that ``torch.nn.ReLU()(x)`` dispatches to.
            x = torch.nn.functional.relu(x)
            return self.linear2(x)

    # Created before the model so the RNG draw order matches the original test.
    input_tensor = torch.randn(1, 10).cuda()

    def calibrate_loop(model):
        """Simple calibration function for testing."""
        model(input_tensor)

    model = SimpleNetwork().eval().cuda()

    try:
        quant_cfg = mtq.INT8_DEFAULT_CFG
        mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
        # model has INT8 qdq nodes at this point
        output_pyt = model(input_tensor)

        with torch.no_grad():
            with export_torch_mode():
                # NOTE(review): private torch API; may change between torch
                # releases.
                from torch.export._trace import _export

                exp_program = _export(model, (input_tensor,))
                trt_model = torchtrt.dynamo.compile(
                    exp_program,
                    inputs=[input_tensor],
                    enabled_precisions={torch.int8},
                    min_block_size=1,
                    debug=True,
                    cache_built_engines=False,
                    reuse_cached_engines=False,
                )
                outputs_trt = trt_model(input_tensor)
                assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2)
    finally:
        # Reset Dynamo like the other tests in this file, even on failure.
        torch._dynamo.reset()