From 922b7cc387ab361cfc9105a256bdaf54d7796601 Mon Sep 17 00:00:00 2001 From: Wenbing Li <10278425+wenbingl@users.noreply.github.com> Date: Wed, 2 Aug 2023 14:01:36 -0700 Subject: [PATCH] Add Bert tokenizer in the supported model list and code refinement (#503) * Add Bert tokenizer in the supported model list and the related code refinement * utest fix --- onnxruntime_extensions/__init__.py | 36 +++++-- onnxruntime_extensions/_cuops.py | 4 +- onnxruntime_extensions/_extensions_pydll.pyi | 4 + onnxruntime_extensions/_hf_cvt.py | 98 +++++++++++++------- onnxruntime_extensions/_ocos.py | 2 +- onnxruntime_extensions/_ortapi2.py | 34 ++++--- onnxruntime_extensions/_torch_cvt.py | 2 +- test/test_audio_codec.py | 3 +- test/test_audio_signal.py | 1 - test/test_autotokenizer.py | 97 ++++--------------- test/test_bert_tokenizer.py | 1 + test/test_bert_tokenizer_op.py | 1 - test/test_whisper.py | 88 ++++++++++++++++++ 13 files changed, 229 insertions(+), 142 deletions(-) create mode 100644 test/test_whisper.py diff --git a/onnxruntime_extensions/__init__.py b/onnxruntime_extensions/__init__.py index 10e815b01..3554ad334 100644 --- a/onnxruntime_extensions/__init__.py +++ b/onnxruntime_extensions/__init__.py @@ -4,24 +4,42 @@ ############################################################################### """ -The entry point to onnxruntime-extensions package. +The `onnxruntime-extensions` Python package offers an API that allows users to generate models for pre-processing and +post-processing tasks. In addition, it also provides an API to register custom operations implemented in Python. +This enables more flexibility and control over model execution, thus expanding the functionality of the ONNX Runtime. """ __author__ = "Microsoft" +__all__ = [ + 'gen_processing_models', + 'get_library_path', + 'Opdef', 'onnx_op', 'PyCustomOpDef', 'PyOp', + 'enable_py_op', + 'expand_onnx_inputs', + 'hook_model_op', + 'default_opset_domain', + 'OrtPyFunction', 'PyOrtFunction', + 'optimize_model', + 'make_onnx_model', + 'ONNXRuntimeError', + 'hash_64', + '__version__', +] from ._version import __version__ -from ._ocos import get_library_path # noqa -from ._ocos import Opdef, PyCustomOpDef # noqa -from ._ocos import hash_64 # noqa -from ._ocos import enable_py_op # noqa -from ._ocos import expand_onnx_inputs # noqa -from ._ocos import hook_model_op # noqa -from ._ocos import default_opset_domain # noqa -from ._cuops import * # noqa +from ._ocos import get_library_path +from ._ocos import Opdef, PyCustomOpDef +from ._ocos import hash_64 +from ._ocos import enable_py_op +from ._ocos import expand_onnx_inputs +from ._ocos import hook_model_op +from ._ocos import default_opset_domain +from ._cuops import * # noqa from ._ortapi2 import OrtPyFunction as PyOrtFunction # backward compatibility from ._ortapi2 import OrtPyFunction, optimize_model, make_onnx_model, ONNXRuntimeError from .cvt import gen_processing_models +# rename the implementation with a more formal name onnx_op = Opdef.declare PyOp = PyCustomOpDef diff --git a/onnxruntime_extensions/_cuops.py b/onnxruntime_extensions/_cuops.py index 330071032..8e57a6a63 100644 --- a/onnxruntime_extensions/_cuops.py +++ b/onnxruntime_extensions/_cuops.py @@ -253,7 +253,9 @@ def get_outputs(cls): def serialize_attr(cls, attrs): attrs_data = {} for k_, v_ in attrs.items(): - if k_ == 'vocab_file': + if k_ == 'vocab': + attrs_data['vocab_file'] = v_ + elif k_ == 'vocab_file': with open(v_, "r", encoding='utf-8') as model_file: lines = model_file.readlines() attrs_data[k_] = 
'\n'.join(lines) diff --git a/onnxruntime_extensions/_extensions_pydll.pyi b/onnxruntime_extensions/_extensions_pydll.pyi index baf714721..595313fca 100644 --- a/onnxruntime_extensions/_extensions_pydll.pyi +++ b/onnxruntime_extensions/_extensions_pydll.pyi @@ -2,6 +2,8 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. ############################################################################### +from typing import Callable + class PyCustomOpDef: undefined: int = ... @@ -21,6 +23,8 @@ class PyCustomOpDef: dt_complex64: int = ... dt_complex128: int = ... dt_bfloat16: int = ... + def install_hooker(self, invocation_handler: Callable) -> None: + ... ... diff --git a/onnxruntime_extensions/_hf_cvt.py b/onnxruntime_extensions/_hf_cvt.py index f2a5fe4b8..198d9b386 100644 --- a/onnxruntime_extensions/_hf_cvt.py +++ b/onnxruntime_extensions/_hf_cvt.py @@ -9,9 +9,9 @@ import json import onnx -import numpy as np +from numpy import array as nparray from functools import partial -from collections import namedtuple +from collections import namedtuple, OrderedDict from ._cuops import CustomOpConverter, SingleOpGraph from .util import read_file @@ -32,6 +32,25 @@ def bpe_tokenizer(self, **kwargs): attrs.update(**kwargs) return attrs + def bert_tokenizer(self, **kwargs): + hf_bert_tokenizer = self.tokenizer + # has to be sorted since the id of token was generated automatically. + ordered_vocab = OrderedDict(sorted(hf_bert_tokenizer.vocab.items(), key=lambda item: int(item[1]))) + vocab = '\n'.join(ordered_vocab.keys()) + attrs = dict(vocab=vocab) + init_kwargs = hf_bert_tokenizer.init_kwargs + attrs['do_lower_case'] = 1 if 'do_lower_case' in init_kwargs and init_kwargs.get('do_lower_case') else 0 + attrs['strip_accents'] = 1 if 'strip_accents' in init_kwargs and init_kwargs.get('strip_accents') else 0 + attrs.update(**kwargs) + return attrs + + def bert_decoder(self, **kwargs): + hf_bert_tokenizer = self.tokenizer + attrs = {'vocab': json.dumps( + hf_bert_tokenizer.ids_to_tokens, separators=(',', ':'))} + attrs.update(**kwargs) + return attrs + def bpe_decoder(self, **kwargs): decoder = self.tokenizer.decoder id_vocab = "\n".join([decoder[_idx] for _idx in sorted(decoder)]) @@ -95,22 +114,28 @@ def spm_decoder(self, **kwargs): "default_inputs"], defaults=(None, None, None, None, None)) -# fmt: off +# @formatter:off _PROCESSOR_DICT = { - "GPT2Tokenizer": TokenOpParam('Gpt2Tokenizer', HFTokenizerConverter.bpe_tokenizer, - 'BpeDecoder', HFTokenizerConverter.bpe_decoder), - "ClipTokenizer": TokenOpParam('ClipTokenizer', HFTokenizerConverter.clip_tokenizer, - 'BpeDecoder', HFTokenizerConverter.bpe_decoder), - "RobertaTokenizer": TokenOpParam("RobertaTokenizer", HFTokenizerConverter.roberta_tokenizer, - None, None), - "T5Tokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer, - "SentencepieceDecoder", HFTokenizerConverter.spm_decoder, + "BertTokenizer": TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer, + 'BertDecoder', HFTokenizerConverter.bpe_decoder, None), + "DistilBertTokenizer": + TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer, + 'BertDecoder', HFTokenizerConverter.bpe_decoder, None), + "GPT2Tokenizer": TokenOpParam('Gpt2Tokenizer', HFTokenizerConverter.bpe_tokenizer, + 'BpeDecoder', HFTokenizerConverter.bpe_decoder, None), + "ClipTokenizer": TokenOpParam('ClipTokenizer', HFTokenizerConverter.clip_tokenizer, + 'BpeDecoder', HFTokenizerConverter.bpe_decoder, None), + "RobertaTokenizer": 
TokenOpParam("RobertaTokenizer", HFTokenizerConverter.roberta_tokenizer, + None, None, None), + "T5Tokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer, + "SentencepieceDecoder", HFTokenizerConverter.spm_decoder, default_inputs={'add_eos': [True]}), - "LlamaTokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer, - "SentencepieceDecoder", HFTokenizerConverter.spm_decoder, + "LlamaTokenizer": TokenOpParam("SentencepieceTokenizer", HFTokenizerConverter.spm_tokenizer, + "SentencepieceDecoder", HFTokenizerConverter.spm_decoder, default_inputs={'add_bos': [True]}), } -# fmt: on +# @formatter:on + class HFTokenizerOnnxGraph: @@ -137,31 +162,34 @@ def pre_processing(self, **kwargs): _cvt_func = self.cvt_quadruple.pre_attribute_cvt cvt = partial(_cvt_func, self.cvt_obj) g = SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs) + default_inputs = [] if with_default_inputs: op_class = SingleOpGraph.get_op_class(_cvt_op) default_inputs = op_class.input_default_values() if default_inputs is None: - raise ValueError("The op {} doesn't define default inputs".format(_cvt_op)) - n_inputs = len(default_inputs) - if self.cvt_quadruple.default_inputs is not None: - default_inputs.update(self.cvt_quadruple.default_inputs) - if len(default_inputs) != n_inputs: - raise ValueError("Op: {} does have the inputs from its TokenOpParam.".format(_cvt_op)) - - new_initializers = [] - - for k, v in default_inputs.items(): - input_value_info = next((i for i in g.input if i.name == k), None) - if input_value_info is None: - raise ValueError("The input {} is not found in the graph".format(k)) - - np_dtype = onnx.helper.tensor_dtype_to_np_dtype(input_value_info.type.tensor_type.elem_type) - value = np.array(v, np_dtype) - new_initializers.append(onnx.numpy_helper.from_array(value, k)) - g.initializer.extend(new_initializers) - new_inputs = [i for i in g.input if i.name not in default_inputs] - g.ClearField("input") - g.input.extend(new_inputs) + return g + + # add default_inputs into initializers to simplify the model input + n_inputs = len(default_inputs) + if self.cvt_quadruple.default_inputs is not None: + default_inputs.update(self.cvt_quadruple.default_inputs) + if len(default_inputs) != n_inputs: + raise ValueError("Op: {} does have the inputs from its TokenOpParam.".format(_cvt_op)) + + new_initializers = [] + + for k, v in default_inputs.items(): + input_value_info = next((i for i in g.input if i.name == k), None) + if input_value_info is None: + raise ValueError("The input {} is not found in the graph".format(k)) + + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(input_value_info.type.tensor_type.elem_type) + value = nparray(v, np_dtype) + new_initializers.append(onnx.numpy_helper.from_array(value, k)) + g.initializer.extend(new_initializers) + new_inputs = [i for i in g.input if i.name not in default_inputs] + g.ClearField("input") + g.input.extend(new_inputs) return g def post_processing(self, **kwargs): diff --git a/onnxruntime_extensions/_ocos.py b/onnxruntime_extensions/_ocos.py index 6f4184ecb..952363336 100644 --- a/onnxruntime_extensions/_ocos.py +++ b/onnxruntime_extensions/_ocos.py @@ -17,7 +17,7 @@ def get_library_path(): """ The custom operator library binary path - :return: A string of the this library path. + :return: A string of this library path. 
""" mod = sys.modules['onnxruntime_extensions._extensions_pydll'] return mod.__file__ diff --git a/onnxruntime_extensions/_ortapi2.py b/onnxruntime_extensions/_ortapi2.py index fb9d99956..3c54b383b 100644 --- a/onnxruntime_extensions/_ortapi2.py +++ b/onnxruntime_extensions/_ortapi2.py @@ -11,11 +11,11 @@ from ._ocos import default_opset_domain, get_library_path # noqa from ._cuops import onnx, onnx_proto, SingleOpGraph - _ort_check_passed = False try: from packaging import version as _ver import onnxruntime as _ort + if _ver.parse(_ort.__version__) >= _ver.parse("1.10.0"): _ort_check_passed = True except ImportError: @@ -37,6 +37,7 @@ def get_opset_version_from_ort(): "1.12": 17, "1.13": 17, "1.14": 18, + "1.15": 18 } ort_ver_string = '.'.join(_ort.__version__.split('.')[0:2]) @@ -59,6 +60,13 @@ def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(), class OrtPyFunction: + """ + OrtPyFunction is a convenience class that serves as a wrapper around the ONNXRuntime InferenceSession, + equipped with registered onnxruntime-extensions. This allows execution of an ONNX model as if it were a + standard Python function. The order of the function arguments correlates directly with + the sequence of the input/output in the ONNX graph. + """ + def get_ort_session_options(self): so = _ort.SessionOptions() for k, v in self.extra_session_options.items(): @@ -66,7 +74,7 @@ def get_ort_session_options(self): so.register_custom_ops_library(get_library_path()) return so - def __init__(self, cpu_only=None): + def __init__(self, path_or_model=None, cpu_only=None): self._onnx_model = None self.ort_session = None self.default_inputs = {} @@ -75,6 +83,14 @@ def __init__(self, cpu_only=None): if _ort.get_device() == 'GPU': self.execution_providers = ['CUDAExecutionProvider'] self.extra_session_options = {} + mpath = None + if isinstance(path_or_model, str): + oxml = onnx.load_model(path_or_model) + mpath = path_or_model + else: + oxml = path_or_model + if path_or_model is not None: + self._bind(oxml, mpath) def create_from_customop(self, op_type, *args, **kwargs): graph = SingleOpGraph.build_graph(op_type, *args, **kwargs) @@ -130,17 +146,13 @@ def _get_kwarg_device(kwargs): @classmethod def from_customop(cls, op_type, *args, **kwargs): - return cls(cls._get_kwarg_device(kwargs)).create_from_customop(op_type, *args, **kwargs) + return (cls(cpu_only=cls._get_kwarg_device(kwargs)) + .create_from_customop(op_type, *args, **kwargs)) @classmethod def from_model(cls, path_or_model, *args, **kwargs): - mpath = None - if isinstance(path_or_model, str): - oxml = onnx.load_model(path_or_model) - mpath = path_or_model - else: - oxml = path_or_model - return cls(cls._get_kwarg_device(kwargs))._bind(oxml, mpath) + fn = cls(path_or_model, cls._get_kwarg_device(kwargs)) + return fn def _argument_map(self, *args, **kwargs): idx = 0 @@ -169,7 +181,7 @@ def __call__(self, *args, **kwargs): def optimize_model(model_or_file, output_file): - sess_options = OrtPyFunction.get_ort_session_options() + sess_options = OrtPyFunction().get_ort_session_options() sess_options.graph_optimization_level = _ort.GraphOptimizationLevel.ORT_ENABLE_BASIC sess_options.optimized_model_filepath = output_file _ort.InferenceSession(model_or_file if isinstance(model_or_file, str) diff --git a/onnxruntime_extensions/_torch_cvt.py b/onnxruntime_extensions/_torch_cvt.py index eb2ce7217..113900939 100644 --- a/onnxruntime_extensions/_torch_cvt.py +++ b/onnxruntime_extensions/_torch_cvt.py @@ -241,6 +241,6 @@ def post_processing(self, 
**kwargs): inputs = [onnx.helper.make_tensor_value_info("sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])] del g.input[:] g.input.extend(inputs) - g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'seq_len', 'text'])) + g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'text'])) return make_onnx_model(g, opset_version=self.opset_version) diff --git a/test/test_audio_codec.py b/test/test_audio_codec.py index 6aa829c3a..15acbe114 100644 --- a/test/test_audio_codec.py +++ b/test/test_audio_codec.py @@ -50,7 +50,8 @@ def test_mp3_decoder(self): def test_decoder_resampling(self): test_file = util.get_test_data_file('data', 'jfk.flac') blob = bytearray(util.read_file(test_file, mode='rb')) - decoder = PyOrtFunction.from_customop('AudioDecoder', cpu_only=True, downsampling_rate=16000, stereo_to_mono=1) + decoder = PyOrtFunction.from_customop( + 'AudioDecoder', cpu_only=True, downsampling_rate=16000, stereo_to_mono=1) pcm_tensor = decoder(np.expand_dims(np.asarray(blob), axis=(0,))) self.assertEqual(pcm_tensor.shape, (1, 176000)) diff --git a/test/test_audio_signal.py b/test/test_audio_signal.py index 795d813f7..23e919dc2 100644 --- a/test/test_audio_signal.py +++ b/test/test_audio_signal.py @@ -8,7 +8,6 @@ import onnx from onnx import onnx_pb as onnx_proto - _is_torch_available = False try: import torch diff --git a/test/test_autotokenizer.py b/test/test_autotokenizer.py index 26c95708a..6cec0b927 100644 --- a/test/test_autotokenizer.py +++ b/test/test_autotokenizer.py @@ -1,106 +1,41 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys import unittest import numpy as np -import onnxruntime as _ort -from packaging import version -from transformers import AutoTokenizer, WhisperProcessor -from onnxruntime_extensions import OrtPyFunction, util, gen_processing_models +from transformers import AutoTokenizer +from onnxruntime_extensions import OrtPyFunction, gen_processing_models -@unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0") class TestAutoTokenizer(unittest.TestCase): + def test_bert_tokenizer(self): + tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased') + text = "Replace me by any text you'd like." + encoded_input = tokenizer(text, return_tensors='np') + ort_tok = OrtPyFunction(gen_processing_models(tokenizer, pre_kwargs={})[0]) + actual_ids = ort_tok([text])[0] + np.testing.assert_array_equal(encoded_input['input_ids'][0], actual_ids) + def test_llama_tokenizer(self): # replace the official model name after the model is not gated anymore tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - ids = tokenizer.encode("I was born in 92000, and this is falsé.", return_tensors="np") + text = "I was born in 92000, and this is falsé." + ids = tokenizer.encode(text, return_tensors="np") ort_tok = OrtPyFunction.from_model(gen_processing_models( tokenizer, pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0]) - actual_ids = ort_tok(["I was born in 92000, and this is falsé."])[0] + actual_ids = ort_tok([text])[0] np.testing.assert_array_equal(ids[0], actual_ids) def test_t5_tokenizer(self): tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512) - ids = tokenizer.encode("best hotel in bay area.", return_tensors="np") + text = "best hotel in bay area." 
+ ids = tokenizer.encode(text, return_tensors="np") ort_tok = OrtPyFunction.from_model(gen_processing_models(tokenizer, pre_kwargs={})[0]) - actual_ids = ort_tok(["best hotel in bay area."])[0] + actual_ids = ort_tok([text])[0] np.testing.assert_array_equal(ids[0], actual_ids) - def test_whisper_overall(self): - processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") - pre_m, post_m = gen_processing_models(processor, - pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False}, - post_kwargs={}) - - fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0}) - t = np.linspace(0, 2 * np.pi, 480000).astype(np.float32) - simaudio = np.expand_dims(np.sin(2 * np.pi * 100 * t), axis=0) - log_mel = fn_pre(simaudio) - - self.assertEqual(log_mel.shape, (1, 80, 3000)) - - fn_post = OrtPyFunction.from_model(post_m) - rel = fn_post(np.asarray([3, 4, 5], dtype=np.int32)[np.newaxis, np.newaxis, :]) - self.assertEqual(rel[0], "$%&") - - def test_whisper_audio_decoder(self): - processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") - pre_m, _ = gen_processing_models(processor, - pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True}) - - fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0}) - test_flac_file = util.get_test_data_file('data', '1272-141231-0002.flac') - audio_data = np.fromfile(test_flac_file, dtype=np.uint8) - log_mel = fn_pre(np.expand_dims(audio_data, axis=0)) - - self.assertEqual(log_mel.shape, (1, 80, 3000)) - - @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.") - def test_ort_stft_consistency(self): - processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") - pre_m, _ = gen_processing_models(processor, - pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": True}) - - test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3') - test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0) - raw_audio = OrtPyFunction.from_customop( - "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data) - - input_features = processor([raw_audio[0]], sampling_rate=16000) - expected = input_features['input_features'][0] - - log_mel = OrtPyFunction.from_model(pre_m)(raw_audio) - actual = log_mel[0] - - num_mismatched = np.sum(~np.isclose(expected, actual, rtol=1e-03, atol=1e-05)) - # ORT STFT has a few more mismatched values than HuggingFace's WhisperProcessor, around 1.5%. 
- self.assertTrue(num_mismatched / np.size(expected) < 0.02) - self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05) - - @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.") - def test_stft_norm_consistency(self): - processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") - pre_m, _ = gen_processing_models(processor, - pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False}) - - test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3') - test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0) - raw_audio = OrtPyFunction.from_customop( - "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data) - - input_features = processor([raw_audio[0]], sampling_rate=16000) - expected = input_features['input_features'][0] - - log_mel = OrtPyFunction.from_model(pre_m)(raw_audio) - actual = log_mel[0] - - np.testing.assert_allclose(expected, actual, rtol=1e-03, atol=1e-05) - self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05) - if __name__ == '__main__': unittest.main() diff --git a/test/test_bert_tokenizer.py b/test/test_bert_tokenizer.py index 20067b3cb..0dead685e 100644 --- a/test/test_bert_tokenizer.py +++ b/test/test_bert_tokenizer.py @@ -114,5 +114,6 @@ def test_text_to_case1(self): print("\n*** Offset mapping tests complete. ***\n") + if __name__ == "__main__": unittest.main() diff --git a/test/test_bert_tokenizer_op.py b/test/test_bert_tokenizer_op.py index f3e982436..8e6b4247b 100644 --- a/test/test_bert_tokenizer_op.py +++ b/test/test_bert_tokenizer_op.py @@ -55,7 +55,6 @@ def test_text_to_case1_with_vocab_file(self): input="cat isnot playing toyssss" ) - def test_text_to_case1_with_hf_tok(self): ort_tok = pnp.PreHuggingFaceBert(hf_tok=_bert_cased_tokenizer) model = pnp.export(pnp.SequentialProcessingModule(ort_tok), ["whatever"], opset_version=12) diff --git a/test/test_whisper.py b/test/test_whisper.py new file mode 100644 index 000000000..042c74e6d --- /dev/null +++ b/test/test_whisper.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+import sys +import unittest + +import numpy as np +import onnxruntime as _ort +from packaging import version +from transformers import WhisperProcessor +from onnxruntime_extensions import OrtPyFunction, util, gen_processing_models + + +@unittest.skipIf(version.parse(_ort.__version__) < version.parse("1.14.0"), "skip for onnxruntime < 1.14.0") +class TestHuggingfaceWhisper(unittest.TestCase): + def test_whisper_overall(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + pre_m, post_m = gen_processing_models(processor, + pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False}, + post_kwargs={}) + + fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0}) + t = np.linspace(0, 2 * np.pi, 480000).astype(np.float32) + simaudio = np.expand_dims(np.sin(2 * np.pi * 100 * t), axis=0) + log_mel = fn_pre(simaudio) + + self.assertEqual(log_mel.shape, (1, 80, 3000)) + + fn_post = OrtPyFunction.from_model(post_m) + rel = fn_post(np.asarray([3, 4, 5], dtype=np.int32)[np.newaxis, np.newaxis, :]) + self.assertEqual(rel[0], "$%&") + + def test_whisper_audio_decoder(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + pre_m, _ = gen_processing_models(processor, + pre_kwargs={"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True}) + + fn_pre = OrtPyFunction.from_model(pre_m, session_options={"graph_optimization_level": 0}) + test_flac_file = util.get_test_data_file('data', '1272-141231-0002.flac') + audio_data = np.fromfile(test_flac_file, dtype=np.uint8) + log_mel = fn_pre(np.expand_dims(audio_data, axis=0)) + + self.assertEqual(log_mel.shape, (1, 80, 3000)) + + @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.") + def test_ort_stft_consistency(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + pre_m, _ = gen_processing_models(processor, + pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": True}) + + test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3') + test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0) + raw_audio = OrtPyFunction.from_customop( + "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data) + + input_features = processor([raw_audio[0]], sampling_rate=16000) + expected = input_features['input_features'][0] + + log_mel = OrtPyFunction.from_model(pre_m)(raw_audio) + actual = log_mel[0] + + num_mismatched = np.sum(~np.isclose(expected, actual, rtol=1e-03, atol=1e-05)) + # ORT STFT has a few more mismatched values than HuggingFace's WhisperProcessor, around 1.5%. 
+ self.assertTrue(num_mismatched / np.size(expected) < 0.02) + self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05) + + @unittest.skipIf(sys.platform.startswith('win'), "Huggingface Processor crashed on Windows.") + def test_stft_norm_consistency(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + pre_m, _ = gen_processing_models(processor, + pre_kwargs={"USE_AUDIO_DECODER": False, "USE_ONNX_STFT": False}) + + test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3') + test_data = np.expand_dims(np.fromfile(test_mp3_file, dtype=np.uint8), axis=0) + raw_audio = OrtPyFunction.from_customop( + "AudioDecoder", cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)(test_data) + + input_features = processor([raw_audio[0]], sampling_rate=16000) + expected = input_features['input_features'][0] + + log_mel = OrtPyFunction.from_model(pre_m)(raw_audio) + actual = log_mel[0] + + np.testing.assert_allclose(expected, actual, rtol=1e-03, atol=1e-05) + self.assertAlmostEqual(expected.min(), actual.min(), delta=1e-05) + + +if __name__ == "__main__": + unittest.main()
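
For reviewers, a minimal end-to-end sketch of the new BertTokenizer path exercised by test_autotokenizer.py: the Hugging Face tokenizer is converted into an ONNX pre-processing graph by gen_processing_models and executed through the reworked OrtPyFunction(path_or_model) constructor. This is a sketch only; it assumes the transformers package is installed and the distilbert-base-uncased checkpoint can be downloaded.

import numpy as np
from transformers import AutoTokenizer
from onnxruntime_extensions import gen_processing_models, OrtPyFunction

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
text = "Replace me by any text you'd like."

# gen_processing_models returns a (pre-processing, post-processing) model pair;
# pre_kwargs={} requests the tokenization graph with its default settings.
onnx_tokenizer = gen_processing_models(tokenizer, pre_kwargs={})[0]

# The constructor now accepts an onnx.ModelProto or a model file path directly,
# so the converted graph runs like an ordinary Python function.
ort_tokenize = OrtPyFunction(onnx_tokenizer)
actual_ids = ort_tokenize([text])[0]

# The BertTokenizer custom op is expected to reproduce the Hugging Face ids.
expected_ids = tokenizer(text, return_tensors="np")["input_ids"][0]
np.testing.assert_array_equal(expected_ids, actual_ids)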
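
Besides model generation, the new package docstring also advertises registering custom operations implemented in Python; onnx_op (the alias for Opdef.declare exported in __all__) is the entry point for that. A minimal sketch following the library's documented decorator pattern; the Inverse op and the decorator arguments shown here are illustrative assumptions, not something introduced by this change.

import numpy as np
from onnxruntime_extensions import onnx_op, PyCustomOpDef

# Declare a Python-backed custom op taking a single float tensor; the decorated
# function body is invoked as the op's kernel at inference time.
@onnx_op(op_type="Inverse", inputs=[PyCustomOpDef.dt_float])
def inverse(x):
    return np.linalg.inv(x)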
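
Models produced this way can also run in a plain onnxruntime InferenceSession; get_library_path, re-exported in __all__, feeds register_custom_ops_library exactly as OrtPyFunction.get_ort_session_options does above. A short sketch, where "bert_tokenizer.onnx" is a hypothetical file name for a saved pre-processing model.

import onnxruntime as ort
from onnxruntime_extensions import get_library_path

# Register the extensions binary so the session can resolve the custom
# tokenizer operators contained in the generated model.
so = ort.SessionOptions()
so.register_custom_ops_library(get_library_path())
session = ort.InferenceSession("bert_tokenizer.onnx", so, providers=["CPUExecutionProvider"])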