Skip to content

Commit 70a5662

Browse files
authored
fix domain detection for large model (#565)
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
1 parent ef928e9 commit 70a5662

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

neural_compressor/adaptor/onnxrt.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,12 @@ def _detect_domain(self, model):
636636
# 2. according to input
637637
# typically, NLP models have multiple inputs,
638638
# and the dimension of each input is usually 2 (batch_size, max_seq_len)
639-
sess = ort.InferenceSession(model.model.SerializeToString())
639+
if not model.is_large_model:
640+
sess = ort.InferenceSession(model.model.SerializeToString())
641+
elif model.model_path is not None: # pragma: no cover
642+
sess = ort.InferenceSession(model.model_path)
643+
else: # pragma: no cover
644+
assert False, "Please use model path instead of onnx model object to quantize."
640645
input_shape_lens = [len(input.shape) for input in sess.get_inputs()]
641646
if len(input_shape_lens) > 1 and all(shape_len == 2 for shape_len in input_shape_lens):
642647
is_nlp = True

0 commit comments

Comments
 (0)