
Commit 4ce9de5

Enhance ONNXRT backend check (#1160)
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
1 parent 775deff commit 4ce9de5

3 files changed: +61 -23 lines changed

neural_compressor/adaptor/onnxrt.py

Lines changed: 18 additions & 5 deletions
@@ -64,17 +64,13 @@ def __init__(self, framework_specific_info):
         self.dynamic = framework_specific_info["approach"] == "post_training_dynamic_quant"
         self.domain = framework_specific_info.get("domain", "auto")
         self.recipes = framework_specific_info.get("recipes", {})
+        self._check_backend_available(framework_specific_info["backend"])
         self.backend = PROVIDERS[framework_specific_info["backend"]]
         self.performance_only = framework_specific_info.get("performance_only", False)
         self.use_bf16 = framework_specific_info.get("use_bf16", False) and \
                         self.backend in ort.get_available_providers()
         self.use_fp16 = framework_specific_info.get("use_fp16", False)

-        if self.backend not in ort.get_all_providers():
-            logger.warning("{} backend is not supported in current environment, "
-                           "supported backends: {}".format(ONNXRT_BACKENDS[self.backend],
-                           [ONNXRT_BACKENDS[i] for i in ort.get_all_providers() if i in ONNXRT_BACKENDS]))
-
         # get quantization format according to framework_specific_info
         if (not self.dynamic and "format" in framework_specific_info and \
             framework_specific_info["format"].lower() == 'qdq') or \

@@ -324,6 +320,23 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
         tmp_model.topological_sort()
         return tmp_model

+    def _check_backend_available(self, backend):
+        """Check backend is available or not."""
+        if backend not in PROVIDERS:
+            assert False, "'{}' backend is not supported, " \
+                "supported backends include {}".format(backend, \
+                [provider for provider in PROVIDERS.keys()])
+
+        if backend in ["onnxrt_trt_ep", "onnxrt_cuda_ep"] and \
+            self.device != "gpu":
+            logger.warning("Backend `{}` requires a GPU device. Reset device to 'gpu'.".format(backend))
+            self.device = "gpu"
+
+        ep = PROVIDERS[backend]
+        if ep not in ort.get_available_providers():
+            logger.warning("Specified provider '{}' is not in available provider names. "\
+                "Fallback to available providers: '{}'".format(ep, ", ".join(ort.get_available_providers())))
+
     def _reset_calib_iter(self, data_loader, cfg_calib_sampling_size, cfg_calib_iter):
         """Check and reset calibration iterations according to calib_sampleing_size and dataloader batch_size."""
         if isinstance(data_loader, BaseDataLoader):
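The heart of this change is the distinction between ort.get_all_providers(), which lists every execution provider ONNX Runtime knows about, and ort.get_available_providers(), which lists only those usable in the current build and environment. A minimal sketch of the difference, with example output for a hypothetical CPU-only install:

    import onnxruntime as ort

    # Every provider name in ONNX Runtime's registry, whether or not
    # it can actually run on this machine.
    print(ort.get_all_providers())
    # e.g. [..., 'CUDAExecutionProvider', ..., 'CPUExecutionProvider', ...]

    # Only the providers usable right now; a CPU-only wheel typically
    # reports just ['CPUExecutionProvider'].
    print(ort.get_available_providers())

The old check therefore passed whenever a backend was merely known to ONNX Runtime; the new _check_backend_available method warns when the mapped provider cannot actually be used, and additionally validates the backend name against PROVIDERS and resets the device for GPU-only backends.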

test/adaptor/onnxrt_adaptor/test_adaptor_onnxrt.py

Lines changed: 27 additions & 0 deletions
@@ -1,6 +1,7 @@
 import os
 import shutil
 import unittest
+from unittest.mock import patch
 import onnxruntime as ort
 import torch
 import torchvision

@@ -1471,6 +1472,32 @@ def test_dataloader_input(self):
         q_model = quantizer.fit()
         self.assertNotEqual(q_model, None)

+    @patch('logging.Logger.warning')
+    def test_backend(self, mock_warning):
+        framework_specific_info = {"device": "cpu",
+                                   "backend": "test_backend",
+                                   "approach": "post_training_static_quant",
+                                   "workspace_path": './nc_workspace'}
+        framework = "onnxrt_qlinearops"
+        with self.assertRaises(AssertionError) as context:
+            adaptor = FRAMEWORKS[framework](framework_specific_info)
+        self.assertEqual(str(context.exception), "'test_backend' backend is not supported, "\
+            "supported backends include ['default', 'onnxrt_trt_ep', 'onnxrt_dnnl_ep', 'onnxrt_cuda_ep']")
+
+        framework_specific_info = {"device": "cpu",
+                                   "backend": "onnxrt_trt_ep",
+                                   "approach": "post_training_static_quant",
+                                   "workspace_path": './nc_workspace'}
+        framework = "onnxrt_qlinearops"
+        adaptor = FRAMEWORKS[framework](framework_specific_info)
+
+        call_args_list = mock_warning.call_args_list
+        first_warning_args = call_args_list[0][0]
+        self.assertEqual(first_warning_args[0], "Backend `onnxrt_trt_ep` requires a GPU device. Reset device to 'gpu'.")
+        second_warning_args = call_args_list[1][0]
+        self.assertIn("not in available provider names. Fallback to available providers", second_warning_args[0])
+
+        self.assertEqual(mock_warning.call_count, 2)


 if __name__ == "__main__":
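The new test patches logging.Logger.warning at the class level, so warnings land on a mock regardless of which logger instance the adaptor uses. A standalone sketch of the pattern, with a hypothetical noisy() standing in for the adaptor constructor:

    import logging
    from unittest.mock import patch

    logger = logging.getLogger("example")

    def noisy():
        # hypothetical function that logs two warnings, as the adaptor does
        logger.warning("first warning")
        logger.warning("second warning")

    with patch('logging.Logger.warning') as mock_warning:
        noisy()

    # call_args_list records one (args, kwargs) pair per call;
    # [0][0] is the positional-argument tuple of the first call.
    assert mock_warning.call_args_list[0][0][0] == "first warning"
    assert mock_warning.call_count == 2

With a plain (non-autospec) patch, self is not passed through, so the first positional argument is the warning message itself, which is what the test asserts on.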

test/adaptor/onnxrt_adaptor/test_onnxrt_operators.py

Lines changed: 16 additions & 18 deletions
@@ -1377,8 +1377,6 @@ def get_fp16_mixed_precision_model(self, model):
         converted_model = fit(model, config)
         return converted_model

-    @unittest.skipIf('CUDAExecutionProvider' not in ort.get_all_providers(),
-                     "skip since CUDAExecutionProvider is not supported")
     def test_fp16(self):
         optypes = ['Sum', 'Sub', 'Div', 'Pow', 'Add']
         for optype in optypes:

@@ -1391,7 +1389,7 @@ def test_fp16(self):
             convert_model = self.get_fp16_mixed_precision_model(model)
             self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
             self.assertTrue(10 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
-            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['CUDAExecutionProvider'])
+            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=ort.get_available_providers())
             outputs = session.run(None, input_data)

         optypes = ['Equal', 'Greater', 'GreaterOrEqual', 'Less', 'LessOrEqual']

@@ -1405,7 +1403,7 @@ def test_fp16(self):
             convert_model = self.get_fp16_mixed_precision_model(model)
             self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
             self.assertTrue(10 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
-            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['CUDAExecutionProvider'])
+            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=ort.get_available_providers())
             outputs = session.run(None, input_data)

         optypes = ['Abs', 'Exp', 'Log', 'Round', 'Sqrt', 'Softmax', 'Exp', 'Tanh', 'Sigmoid', 'LeakyRelu', 'Round']

@@ -1418,7 +1416,7 @@ def test_fp16(self):
             convert_model = self.get_fp16_mixed_precision_model(model)
             self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
             self.assertTrue(10 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
-            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['CUDAExecutionProvider'])
+            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=ort.get_available_providers())
             outputs = session.run(None, input_data)

         optypes = ['ReduceMean', 'ReduceL1', 'ReduceL2', 'ReduceLogSum', 'ReduceLogSumExp', 'ReduceMax', 'ReduceProd', \

@@ -1432,7 +1430,7 @@ def test_fp16(self):
             convert_model = self.get_fp16_mixed_precision_model(model)
             self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
             self.assertTrue(10 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
-            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['CUDAExecutionProvider'])
+            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=ort.get_available_providers())
             outputs = session.run(None, input_data)

         optypes = ['Gelu']

@@ -1445,7 +1443,7 @@ def test_fp16(self):
             convert_model = self.get_fp16_mixed_precision_model(model)
             self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
             self.assertTrue(10 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
-            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['CUDAExecutionProvider'])
+            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=ort.get_available_providers())
             outputs = session.run(None, input_data)

         optypes = ['BiasGelu', 'FastGelu']

@@ -1459,7 +1457,7 @@ def test_fp16(self):
             convert_model = self.get_fp16_mixed_precision_model(model)
             self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
             self.assertTrue(10 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
-            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['CUDAExecutionProvider'])
+            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=ort.get_available_providers())
             outputs = session.run(None, input_data)


@@ -1474,7 +1472,7 @@ def test_fp16(self):
             convert_model = self.get_fp16_mixed_precision_model(model)
             self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
             self.assertTrue(10 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
-            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['CUDAExecutionProvider'])
+            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=ort.get_available_providers())
             outputs = session.run(None, input_data)

         optypes = ['FusedMatMul']

@@ -1489,22 +1487,22 @@ def test_fp16(self):
             convert_model = self.get_fp16_mixed_precision_model(model)
             self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
             self.assertTrue(10 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
-            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['CUDAExecutionProvider'])
+            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=ort.get_available_providers())
             outputs = session.run(None, input_data)

         optypes = ['Gemm']
         for optype in optypes:
             inps = [['input1', TensorProto.FLOAT, (1,2)]]
-            outs = [['output', TensorProto.FLOAT, (1,2)]]
+            outs = [['output', TensorProto.FLOAT, (1,1)]]
             weights = [['input2', TensorProto.FLOAT, (2,1), np.random.random((2))],
-                       ['input3', TensorProto.FLOAT, (1,2), np.random.random((2))]]
+                       ['input3', TensorProto.FLOAT, (1,1), np.random.random((1))]]
             node_infos = [['test', ['input1', 'input2', 'input3'], ['output'], optype]]
             model = self.build_model(inps, outs, weights, node_infos)
             input_data = self.build_test_data(['input1'], [(1,2)], ['float32'])
             convert_model = self.get_fp16_mixed_precision_model(model)
             self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
             self.assertTrue(10 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
-            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['CUDAExecutionProvider'])
+            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=ort.get_available_providers())
             outputs = session.run(None, input_data)

         optypes = ['LayerNormalization']

@@ -1519,7 +1517,7 @@ def test_fp16(self):
             convert_model = self.get_fp16_mixed_precision_model(model)
             self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
             self.assertTrue(10 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
-            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['CUDAExecutionProvider'])
+            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=ort.get_available_providers())
             outputs = session.run(None, input_data)

         optypes = ['BatchNormalization']

@@ -1537,7 +1535,7 @@ def test_fp16(self):
             convert_model = self.get_fp16_mixed_precision_model(model)
             self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
             self.assertTrue(10 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
-            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['CUDAExecutionProvider'])
+            session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=ort.get_available_providers())
             outputs = session.run(None, input_data)

     def get_bf16_mixed_precision_model(self, model):

@@ -1547,7 +1545,7 @@ def get_bf16_mixed_precision_model(self, model):
         converted_model = fit(model, config)
         return converted_model

-    @unittest.skipIf(not CpuInfo().bf16 or 'DnnlExecutionProvider' not in ort.get_all_providers(),
+    @unittest.skipIf(not CpuInfo().bf16 or 'DnnlExecutionProvider' not in ort.get_available_providers(),
                      "skip since DnnlExecutionProvider is not supported")
     def test_bf16(self):
         optypes = ['Sum', 'Sub', 'Div', 'Pow', 'Add']

@@ -1665,9 +1663,9 @@ def test_bf16(self):
         optypes = ['Gemm']
         for optype in optypes:
             inps = [['input1', TensorProto.FLOAT, (1,2)]]
-            outs = [['output', TensorProto.FLOAT, (1,2)]]
+            outs = [['output', TensorProto.FLOAT, (1,1)]]
             weights = [['input2', TensorProto.FLOAT, (2,1), np.random.random((2))],
-                       ['input3', TensorProto.FLOAT, [], np.random.random((1))]]
+                       ['input3', TensorProto.FLOAT, (1,1), np.random.random((1))]]
             node_infos = [['test', ['input1', 'input2', 'input3'], ['output'], optype]]
             model = self.build_model(inps, outs, weights, node_infos)
             input_data = self.build_test_data(['input1'], [(1,2)], ['float32'])
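The Gemm shape fix in both test methods is plain matrix arithmetic: Gemm computes Y = alpha * A @ B + beta * C, so with A of shape (1,2) and B of shape (2,1) the product is (1,1), and both the bias C and the declared output must match. A quick check of the corrected shapes, sketched with numpy:

    import numpy as np

    A = np.random.random((1, 2))  # 'input1' in the test
    B = np.random.random((2, 1))  # 'input2' weight
    C = np.random.random((1, 1))  # 'input3' bias, now (1,1)

    Y = A @ B + C                 # Gemm with default alpha = beta = 1
    assert Y.shape == (1, 1)      # hence the output shape change to (1,1)

Previously the declared output and bias shapes, (1,2) for fp16 and [] for bf16, did not match the (1,1) product, which strict shape checking at session creation could reject.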
