Commit 295535a

enhance onnxrt backend setting
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
1 parent: a16332d

3 files changed, +17 −4 lines

neural_compressor/adaptor/onnxrt.py

Lines changed: 6 additions & 2 deletions

@@ -83,6 +83,10 @@ def __init__(self, framework_specific_info):
             self.format = "integerops"
             if "format" in framework_specific_info and framework_specific_info["format"].lower() == "qdq":
                 logger.warning("Dynamic approach doesn't support QDQ format.")
+
+        # do not load TensorRT if backend is not TensorrtExecutionProvider
+        if self.backend != "TensorrtExecutionProvider":
+            os.environ["ORT_TENSORRT_UNAVAILABLE"] = "1"
 
         # get quantization config file according to backend
         config_file = None

@@ -700,9 +704,9 @@ def _detect_domain(self, model):
         # typically, NLP models have multiple inputs,
         # and the dimension of each input is usually 2 (batch_size, max_seq_len)
         if not model.is_large_model:
-            sess = ort.InferenceSession(model.model.SerializeToString(), providers=ort.get_available_providers())
+            sess = ort.InferenceSession(model.model.SerializeToString(), providers=["CPUExecutionProvider"])
         elif model.model_path is not None:  # pragma: no cover
-            sess = ort.InferenceSession(model.model_path, providers=ort.get_available_providers())
+            sess = ort.InferenceSession(model.model_path, providers=["CPUExecutionProvider"])
         else:  # pragma: no cover
             assert False, "Please use model path instead of onnx model object to quantize."
         input_shape_lens = [len(input.shape) for input in sess.get_inputs()]

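For readers skimming the first hunk: the guard simply flags TensorRT as unavailable whenever a non-TensorRT backend is selected so, per the in-diff comment, the TensorRT execution provider is not loaded as a side effect. A minimal standalone sketch of the same logic follows (the helper name _mask_tensorrt is hypothetical and not part of this commit):

import os

def _mask_tensorrt(backend: str) -> None:
    # mirror of the guard added to __init__: unless TensorRT is the selected
    # backend, mark it unavailable before any InferenceSession is created
    if backend != "TensorrtExecutionProvider":
        os.environ["ORT_TENSORRT_UNAVAILABLE"] = "1"

# e.g. choosing the CUDA EP leaves the flag set; choosing TensorRT does not touch it
_mask_tensorrt("CUDAExecutionProvider")
assert os.environ.get("ORT_TENSORRT_UNAVAILABLE") == "1"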
neural_compressor/model/model.py

Lines changed: 2 additions & 2 deletions

@@ -83,9 +83,9 @@ def _is_onnxruntime(model):
 
             so.register_custom_ops_library(get_library_path())
         if isinstance(model, str):
-            ort.InferenceSession(model, so, providers=ort.get_available_providers())
+            ort.InferenceSession(model, so, providers=["CPUExecutionProvider"])
         else:
-            ort.InferenceSession(model.SerializeToString(), so, providers=ort.get_available_providers())
+            ort.InferenceSession(model.SerializeToString(), so, providers=["CPUExecutionProvider"])
     except Exception as e:  # pragma: no cover
         if "Message onnx.ModelProto exceeds maximum protobuf size of 2GB" in str(e):
             logger.warning("Please use model path instead of onnx model object to quantize")

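The model.py change follows the same idea as the second hunk above: these sessions exist only to validate or inspect a model, so they are pinned to the CPU provider instead of whatever ort.get_available_providers() returns, which can include CUDA or TensorRT and would initialize those runtimes just for a metadata check. A small sketch of the pattern, assuming a local model.onnx file (hypothetical path):

import onnxruntime as ort

so = ort.SessionOptions()
# CPUExecutionProvider ships with every onnxruntime build, so an
# inspection-only session never has to bring up GPU or TensorRT runtimes
sess = ort.InferenceSession("model.onnx", so, providers=["CPUExecutionProvider"])
print([(inp.name, inp.shape) for inp in sess.get_inputs()])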
test/adaptor/onnxrt_adaptor/test_adaptor_onnxrt.py

Lines changed: 9 additions & 0 deletions

@@ -1657,6 +1657,15 @@ def test_backend(self, mock_warning):
 
         self.assertEqual(mock_warning.call_count, 2)
 
+    def test_cuda_ep_env_set(self):
+        config = PostTrainingQuantConfig(approach="static", backend="onnxrt_cuda_ep", device="gpu", quant_level=1)
+        q_model = quantization.fit(
+            self.distilbert_model,
+            config,
+            calib_dataloader=DummyNLPDataloader_dict("distilbert-base-uncased-finetuned-sst-2-english")
+        )
+        self.assertEqual(os.environ.get("ORT_TENSORRT_UNAVAILABLE"), "1")
+
 
 if __name__ == "__main__":
     unittest.main()

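A caveat when reading the new test: os.environ is process-global, so a flag left over from an earlier quantization run in the same test process would also satisfy the assertion. A tearDown along these lines (hypothetical, not part of this commit) would keep cases independent:

import os
import unittest

class TestCudaEpEnvIsolation(unittest.TestCase):
    def tearDown(self):
        # drop the flag so later tests start from a clean environment
        os.environ.pop("ORT_TENSORRT_UNAVAILABLE", None)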