From d3561bca654a5f8b0d21616b2140388515089728 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Thu, 14 Jan 2021 19:40:42 -0500 Subject: [PATCH] Renaming and updating of code docs (#5) * Renaming and updating of code docs: - nmie -> deepsparse - Model -> Engine - forward -> run * - address comments - refactor helper create_engine to compile_model * address review comments to fix: - more explanatory import error messages - improper benchmarking constraints * - add in mapped_run function --- setup.cfg | 2 +- setup.py | 36 ++- src/.gitignore | 4 +- src/deepsparse/__init__.py | 8 + src/deepsparse/engine.py | 513 +++++++++++++++++++++++++++++++++++++ src/nmie/__init__.py | 1 - src/nmie/model.py | 475 ---------------------------------- 7 files changed, 548 insertions(+), 491 deletions(-) create mode 100644 src/deepsparse/__init__.py create mode 100644 src/deepsparse/engine.py delete mode 100644 src/nmie/__init__.py delete mode 100644 src/nmie/model.py diff --git a/setup.cfg b/setup.cfg index 64fb473de5..19e6200ab2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,7 +3,7 @@ default_section = FIRSTPARTY ensure_newline_before_comments = True force_grid_wrap = 0 include_trailing_comma = True -known_first_party = nmie +known_first_party = deepsparse known_third_party = line_length = 88 diff --git a/setup.py b/setup.py index 12bb2c9722..99592d6c5d 100644 --- a/setup.py +++ b/setup.py @@ -8,11 +8,13 @@ # File regexes for binaries to include in package_data binary_regexes = ["*/*.so", "*/*.so.*", "*.bin", "*/*.bin"] + class OverrideInstall(install): """ This class adds a hook that runs after regular install that changes the permissions of all the binary files to 0755. """ + def run(self): install.run(self) mode = 0o755 @@ -28,20 +30,16 @@ def _setup_package_dir() -> Dict: def _setup_packages() -> List: return find_packages( - "src", include=["nmie", "nmie.*"], exclude=["*.__pycache__.*"] + "src", include=["deepsparse", "deepsparse.*"], exclude=["*.__pycache__.*"] ) def _setup_package_data() -> Dict: - return {"nmie": binary_regexes} + return {"deepsparse": binary_regexes} def _setup_install_requires() -> List: - return [ - "numpy>=1.16.3", - "onnx>=1.5.0,<1.8.0", - "requests>=2.0.0" - ] + return ["numpy>=1.16.3", "onnx>=1.5.0,<1.8.0", "requests>=2.0.0"] def _setup_extras() -> Dict: @@ -57,12 +55,12 @@ def _setup_long_description() -> Tuple[str, str]: setup( - name="nmie", + name="deepsparse", version="0.1.0", - author="Bill Nell, Michael Goin, Mark Kurtz", + author="Bill Nell, Michael Goin, Mark Kurtz, Kevin Rodriguez, Benjamin Fineran", author_email="support@neuralmagic.com", - description="The high performance Neural Magic Inference Engine designed " - "for running deep learning on X86 CPU architectures", + description="The high performance DeepSparse Engine designed to achieve " + "GPU class performance for Neural Networks on commodity CPUs.", long_description=_setup_long_description()[0], long_description_content_type=_setup_long_description()[1], keywords="inference machine learning x86 x86_64 avx2 avx512 neural network", @@ -77,7 +75,21 @@ def _setup_long_description() -> Tuple[str, str]: entry_points=_setup_entry_points(), python_requires=">=3.6.0", classifiers=[ - "[TODO]" + "Development Status :: 3 - Alpha", + "Environment :: Console", + "Programming Language :: Python :: 3", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Information Technology", + "Intended Audience :: Science/Research", + "Operating System :: POSIX :: Linux", + "Programming Language 
:: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries :: Python Modules", ], cmdclass={"install": OverrideInstall}, ) diff --git a/src/.gitignore b/src/.gitignore index 964cde91f8..7a889cfaa4 100644 --- a/src/.gitignore +++ b/src/.gitignore @@ -1,5 +1,5 @@ -nmie/avx2/ -nmie/avx512/ +deepsparse/avx2/ +deepsparse/avx512/ arch.bin cpu.py version.py diff --git a/src/deepsparse/__init__.py b/src/deepsparse/__init__.py new file mode 100644 index 0000000000..fd2c757aa2 --- /dev/null +++ b/src/deepsparse/__init__.py @@ -0,0 +1,8 @@ +""" +The DeepSparse package used to achieve GPU class performance +for Neural Networks on commodity CPUs. +""" + +# flake8: noqa + +from .engine import * diff --git a/src/deepsparse/engine.py b/src/deepsparse/engine.py new file mode 100644 index 0000000000..d3c0c60599 --- /dev/null +++ b/src/deepsparse/engine.py @@ -0,0 +1,513 @@ +""" +Code related to interfacing with a Neural Network in the DeepSparse Engine using python +""" + +from typing import List, Dict, Optional, Iterable, Tuple +import os +import numpy +import importlib +import time + +try: + from deepsparse.cpu import cpu_details + from deepsparse.version import * +except ImportError: + raise ImportError( + "Unable to import deepsparse python apis. " + "Please contact support@neuralmagic.com" + ) + + +__all__ = ["Engine", "compile_model", "analyze_model"] + + +CORES_PER_SOCKET, AVX_TYPE, VNNI = cpu_details() + + +def _import_ort_nm(): + try: + nm_package_dir = os.path.dirname(os.path.abspath(__file__)) + onnxruntime_neuralmagic_so_path = os.path.join( + nm_package_dir, AVX_TYPE, "neuralmagic_onnxruntime_engine.so" + ) + spec = importlib.util.spec_from_file_location( + "deepsparse.{}.neuralmagic_onnxruntime_engine".format(AVX_TYPE), + onnxruntime_neuralmagic_so_path, + ) + engine = importlib.util.module_from_spec(spec) + spec.loader.exec_module(engine) + + return engine + except ImportError: + raise ImportError( + "Unable to import deepsparse engine binaries. " + "Please contact support@neuralmagic.com" + ) + + +ENGINE_ORT = _import_ort_nm() + + +class Engine(object): + """ + Create a new DeepSparse Engine that compiles the given onnx file + for GPU class performance on commodity CPUs. + + Note 1: Engines are compiled for a specific batch size and + for a specific number of CPU cores. + + Note 2: multi socket support is not yet built in to the Engine, + all execution assumes single socket + + | Example: + | # create an engine for batch size 1 on all available cores + | engine = Engine("path/to/onnx", batch_size=1, num_cores=-1) + + :param onnx_file_path: The local path to the onnx file for the + neural network graph definition to instantiate + :param batch_size: The batch size of the inputs to be used with the engine + :param num_cores: The number of physical cores to run the model on. 
+ Pass -1 to run on the max number of cores in one socket for the current machine + """ + + def __init__(self, onnx_file_path: str, batch_size: int, num_cores: int): + if num_cores == -1: + num_cores = CORES_PER_SOCKET + + if num_cores < 1: + raise ValueError( + "num_cores must be greater than 0: given {}".format(num_cores) + ) + + if not os.path.exists(onnx_file_path): + raise ValueError( + "onnx_file_path must exist: given {}".format(onnx_file_path) + ) + + self._onnx_file_path = onnx_file_path + self._batch_size = batch_size + self._num_cores = num_cores + self._num_sockets = 1 # only single socket is supported currently + self._cpu_avx_type = AVX_TYPE + self._cpu_vnni = VNNI + self._eng_net = ENGINE_ORT.neuralmagic_onnxruntime_engine( + self._onnx_file_path, self._batch_size, self._num_cores, self._num_sockets + ) + + def __call__( + self, inp: List[numpy.ndarray], val_inp: bool = True, + ) -> List[numpy.ndarray]: + """ + Convenience function for Engine.run(), see @run for more details + + | Example: + | engine = Engine("path/to/onnx", batch_size=1, num_cores=-1) + | inp = [numpy.random.rand(1, 3, 224, 224).astype(numpy.float32)] + | out = engine(inp) + | assert isinstance(out, List) + + :param inp: The list of inputs to pass to the engine for inference. + The expected order is the inputs order as defined in the ONNX graph. + :param val_inp: Validate the input to the model to ensure numpy array inputs + are setup correctly for the DeepSparse Engine + :return: The list of outputs from the model after executing over the inputs + """ + return self.run(inp, val_inp) + + def __repr__(self): + """ + :return: Unambiguous representation of the current model instance + """ + return "{}({})".format(self.__class__, self._properties_dict()) + + def __str__(self): + """ + :return: Human readable form of the current model instance + """ + formatted_props = [ + "\t{}: {}".format(key, val) for key, val in self._properties_dict().items() + ] + + return "{}.{}:\n{}".format( + self.__class__.__module__, + self.__class__.__name__, + "\n".join(formatted_props), + ) + + @property + def onnx_file_path(self) -> str: + """ + :return: The local path to the onnx file the current instance was compiled from + """ + return self._onnx_file_path + + @property + def batch_size(self) -> int: + """ + :return: The batch size of the inputs to be used with the model + """ + return self._batch_size + + @property + def num_cores(self) -> int: + """ + :return: The number of physical cores the current instance is running on + """ + return self._num_cores + + @property + def num_sockets(self) -> int: + """ + :return: The number of sockets the engine is compiled to run on; + only current support is 1 + """ + return self._num_sockets + + @property + def cpu_avx_type(self) -> str: + """ + :return: The detected cpu avx type that neural magic is running with. + One of {avx2, avx512}. AVX instructions give significant execution speedup + with avx512 > avx2. + """ + return self._cpu_avx_type + + @property + def cpu_vnni(self) -> bool: + """ + :return: True if vnni support was detected on the cpu, False otherwise. + VNNI gives performance benefits for quantized networks. + """ + return self._cpu_vnni + + def run( + self, inp: List[numpy.ndarray], val_inp: bool = True, + ) -> List[numpy.ndarray]: + """ + Run given inputs through the model for inference. + Returns the result as a list of numpy arrays corresponding to + the outputs of the model as defined in the ONNX graph. 
+
+        Note 1: the input dimensions must match what is defined in the ONNX graph.
+        To avoid extra time in memory shuffles, the best use case
+        is to format both the onnx and the input into channels first format;
+        ex: [batch, height, width, channels] => [batch, channels, height, width]
+
+        Note 2: the input type for the numpy arrays must match
+        what is defined in the ONNX graph.
+        Generally float32 is most common,
+        but int8 and int16 are used for certain layer and input types
+        such as with quantized models.
+
+        Note 3: the numpy arrays must be contiguous in memory,
+        use numpy.ascontiguousarray(array) to fix if not.
+
+        | Example:
+        | engine = Engine("path/to/onnx", batch_size=1, num_cores=-1)
+        | inp = [numpy.random.rand(1, 3, 224, 224).astype(numpy.float32)]
+        | out = engine.run(inp)
+        | assert isinstance(out, List)
+
+        :param inp: The list of inputs to pass to the engine for inference.
+            The expected order is the inputs order as defined in the ONNX graph.
+        :param val_inp: Validate the input to the model to ensure numpy array inputs
+            are setup correctly for the DeepSparse Engine
+        :return: The list of outputs from the model after executing over the inputs
+        """
+        if val_inp:
+            self._validate_inputs(inp)
+
+        return self._eng_net.execute_list_out(inp)
+
+    def timed_run(
+        self, inp: List[numpy.ndarray], val_inp: bool = True
+    ) -> Tuple[List[numpy.ndarray], float]:
+        """
+        Convenience method for timing a model inference run.
+        Returns the result as a tuple containing (the outputs from @run, time taken)
+
+        See @run for more details.
+
+        | Example:
+        | engine = Engine("path/to/onnx", batch_size=1, num_cores=-1)
+        | inp = [numpy.random.rand(1, 3, 224, 224).astype(numpy.float32)]
+        | out, time = engine.timed_run(inp)
+        | assert isinstance(out, List)
+        | assert isinstance(time, float)
+
+        :param inp: The list of inputs to pass to the engine for inference.
+            The expected order is the inputs order as defined in the ONNX graph.
+        :param val_inp: Validate the input to the model to ensure numpy array inputs
+            are setup correctly for the DeepSparse Engine
+        :return: A tuple of the list of outputs from the model after executing
+            over the inputs and the time taken for the run in seconds
+        """
+        start = time.time()
+        out = self.run(inp, val_inp)
+        end = time.time()
+
+        return out, end - start
+
+    def mapped_run(
+        self, inp: List[numpy.ndarray], val_inp: bool = True,
+    ) -> Dict[str, numpy.ndarray]:
+        """
+        Run given inputs through the model for inference.
+        Returns the result as a dictionary of numpy arrays corresponding to
+        the output names of the model as defined in the ONNX graph.
+
+        Note 1: this function can add a performance hit in certain cases.
+        If using, please validate that you do not incur a performance hit
+        by comparing with the regular run function
+
+        See @run for more details on specific setup for the inputs.
+
+        | Example:
+        | engine = Engine("path/to/onnx", batch_size=1, num_cores=-1)
+        | inp = [numpy.random.rand(1, 3, 224, 224).astype(numpy.float32)]
+        | out = engine.mapped_run(inp)
+        | assert isinstance(out, Dict)
+
+        :param inp: The list of inputs to pass to the engine for inference.
+            The expected order is the inputs order as defined in the ONNX graph.
+        :param val_inp: Validate the input to the model to ensure numpy array inputs
+            are setup correctly for the DeepSparse Engine
+        :return: The dictionary of outputs from the model after executing
+            over the inputs
+        """
+        if val_inp:
+            self._validate_inputs(inp)
+
+        out = self._eng_net.execute(inp)
+
+        return out
+
+    def benchmark_batched(
+        self,
+        batched_data: Iterable[List[numpy.ndarray]],
+        num_iterations: int = 20,
+        num_warmup_iterations: int = 5,
+        include_outputs: bool = False,
+    ) -> Dict[str, float]:
+        """
+        A convenience function for quickly benchmarking the instantiated model
+        on a given DataLoader in the DeepSparse Engine.
+        batched_data must already be shaped into the proper batch sizes
+        for use with benchmarking.
+        After executing, will return the summary statistics for benchmarking.
+
+        :param batched_data: An iterator of input batches to be used for benchmarking.
+        :param num_iterations: The number of iterations to run benchmarking for.
+            Default is 20
+        :param num_warmup_iterations: The number of iterations to warm up engine before
+            benchmarking. These executions will not be counted in the benchmark
+            results that are returned. Useful and recommended to bring
+            the system to a steady state. Default is 5
+        :param include_outputs: If True, outputs from forward passes during benchmarking
+            will be returned under the 'outputs' key. Default is False
+        :return: Dictionary of benchmark results including keys batch_stats_ms,
+            batch_times_ms, and items_per_sec
+        """
+        assert num_iterations >= 1 and num_warmup_iterations >= 0, (
+            "num_iterations must be greater than 0 and num_warmup_iterations must "
+            "be non negative for benchmarking."
+        )
+        completed_iterations = 0
+        batch_times = []
+        outputs = []
+
+        while completed_iterations < num_warmup_iterations + num_iterations:
+            for batch in batched_data:
+                # run benchmark
+                output, batch_time = self.timed_run(batch, val_inp=False)
+
+                # update results
+                batch_times.append(batch_time)
+                if include_outputs:
+                    outputs.append(output)
+
+                # update loop
+                completed_iterations += 1
+                if completed_iterations >= num_warmup_iterations + num_iterations:
+                    break
+
+        batch_times = batch_times[num_warmup_iterations:]  # remove warmup times
+        batch_times_ms = [batch_time * 1000 for batch_time in batch_times]
+        items_per_sec = self.batch_size / numpy.mean(batch_times).item()
+
+        batch_stats_ms = {
+            "median": numpy.median(batch_times_ms),
+            "mean": numpy.mean(batch_times_ms),
+            "std": numpy.std(batch_times_ms),
+        }
+
+        benchmark_dict = {
+            "batch_stats_ms": batch_stats_ms,
+            "batch_times_ms": batch_times_ms,
+            "items_per_sec": items_per_sec,
+        }
+
+        if include_outputs:
+            benchmark_dict["outputs"] = outputs
+
+        return benchmark_dict
+
+    def benchmark(
+        self,
+        data: Iterable[List[numpy.ndarray]],
+        num_iterations: int = 20,
+        num_warmup_iterations: int = 5,
+        include_outputs: bool = False,
+    ) -> Dict[str, float]:
+        """
+        A convenience function for quickly benchmarking the instantiated model
+        on a given Dataset in the DeepSparse Engine.
+        The data param must be individual items; the code will batch
+        these items into the proper shape for the model for use with benchmarking.
+        After executing, will return the summary statistics for benchmarking.
+
+        :param data: An iterator of input items to be used for benchmarking.
+            These items will be stacked to create batches of the proper batch_size.
+            Items will be stacked in order. Will infinitely loop over the number
+            of items to create the proper batch size and number of batches.
+        :param num_iterations: The number of iterations to run benchmarking for.
+            Default is 20
+        :param num_warmup_iterations: The number of iterations to warm up engine before
+            benchmarking. These executions will not be counted in the benchmark
+            results that are returned. Useful and recommended to bring
+            the system to a steady state. Default is 5
+        :param include_outputs: If True, outputs from forward passes during benchmarking
+            will be returned under the 'outputs' key. Default is False
+        :return: Dictionary of benchmark results including keys batch_stats_ms,
+            batch_times_ms, and items_per_sec
+        """
+        assert num_iterations >= 1 and num_warmup_iterations >= 0, (
+            "num_iterations must be greater than 0 and num_warmup_iterations must "
+            "be non negative for benchmarking."
+        )
+
+        # define data loader
+        def infinite_data_batcher():
+            batch = []
+            while True:
+                for inputs in data:
+                    batch.append(inputs)
+                    if len(batch) == self.batch_size:
+                        # concatenate batch inputs
+                        batch_inputs = []
+                        for input_idx in range(len(inputs)):
+                            batch_input = [batch_val[input_idx] for batch_val in batch]
+                            batch_inputs.append(numpy.stack(batch_input))
+                        # yield and reset
+                        yield batch_inputs
+                        batch = []
+
+        return self.benchmark_batched(
+            batched_data=infinite_data_batcher(),
+            num_iterations=num_iterations,
+            num_warmup_iterations=num_warmup_iterations,
+            include_outputs=include_outputs,
+        )
+
+    def _validate_inputs(self, inp: List[numpy.ndarray]):
+        if isinstance(inp, str) or not isinstance(inp, List):
+            raise ValueError("inp must be a list, given {}".format(type(inp)))
+
+        for arr in inp:
+            if arr.shape[0] != self._batch_size:
+                raise ValueError(
+                    (
+                        "array batch size of {} must match the batch size "
+                        "the model was instantiated with {}"
+                    ).format(arr.shape[0], self._batch_size)
+                )
+
+            if not arr.flags["C_CONTIGUOUS"]:
+                raise ValueError(
+                    "array must be passed in as C contiguous, "
+                    "call numpy.ascontiguousarray(array)"
+                )
+
+    def _properties_dict(self) -> Dict:
+        return {
+            "onnx_file_path": self._onnx_file_path,
+            "batch_size": self._batch_size,
+            "num_cores": self._num_cores,
+            "num_sockets": self._num_sockets,
+            "cpu_avx_type": self._cpu_avx_type,
+            "cpu_vnni": self._cpu_vnni,
+        }
+
+
+def compile_model(
+    onnx_file_path: str, batch_size: int = 1, num_cores: int = -1
+) -> Engine:
+    """
+    Convenience function to compile a model in the DeepSparse Engine
+    from an ONNX file for inference.
+    Gives defaults of batch_size == 1 and num_cores == -1
+    (will use all physical cores available on a single socket).
+
+    :param onnx_file_path: The local path to the onnx file for the
+        neural network graph definition to instantiate
+    :param batch_size: The batch size of the inputs to be used with the model
+    :param num_cores: The number of physical cores to run the model on.
+        Pass -1 to run on the max number of cores in one socket for the current machine
+    :return: The created Engine after compiling the model
+    """
+    return Engine(onnx_file_path, batch_size, num_cores)
+
+
+def analyze_model(
+    onnx_file_path: str,
+    input: List[numpy.ndarray],
+    batch_size: int = 1,
+    num_cores: int = -1,
+    num_iterations: int = 20,
+    num_warmup_iterations: int = 5,
+    optimization_level: int = 1,
+    imposed_as: Optional[float] = None,
+    imposed_ks: Optional[float] = None,
+) -> dict:
+    """
+    Function to analyze a model's performance in the DeepSparse Engine.
+    The model must be defined in an ONNX graph and stored in a local file.
+    Gives defaults of batch_size == 1 and num_cores == -1
+    (will use all physical cores available on a single socket).
+ + :param onnx_file_path: The local path to the onnx file for the + neural network graph definition to instantiate + :param batch_size: The batch size of the inputs to be used with the model + :param num_cores: The number of physical cores to run the model on. + Pass -1 to run on the max number of cores in one socket for the current machine + :param num_iterations: The number of times to repeat execution of the model + while analyzing, default is 20 + :param num_warmup_iterations: The number of times to repeat execution of the model + before analyzing, default is 5 + :param optimization_level: The amount of graph optimizations to perform. + The current choices are either 0 (minimal) or 1 (all), default is 1 + :param imposed_as: Imposed activation sparsity, defaults to None. + Will force the activation sparsity from all ReLu layers in the graph + to match this desired sparsity level (percentage of 0's in the tensor). + Beneficial for seeing how AS affects the performance of the model. + :param imposed_ks: Imposed kernel sparsity, defaults to None. + Will force all prunable layers in the graph to have weights with + this desired sparsity level (percentage of 0's in the tensor). + Beneficial for seeing how pruning affects the performance of the model. + :return: the analysis structure containing the performance details of each layer + """ + if num_cores == -1: + num_cores = CORES_PER_SOCKET + num_sockets = 1 # only single socket is supported currently + + eng_net = ENGINE_ORT.neuralmagic_onnxruntime_engine( + onnx_file_path, batch_size, num_cores, num_sockets + ) + + return eng_net.benchmark( + input, + num_iterations, + num_warmup_iterations, + optimization_level, + imposed_as, + imposed_ks, + ) diff --git a/src/nmie/__init__.py b/src/nmie/__init__.py deleted file mode 100644 index 87cb7367b0..0000000000 --- a/src/nmie/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .model import * diff --git a/src/nmie/model.py b/src/nmie/model.py deleted file mode 100644 index 9f10b938bb..0000000000 --- a/src/nmie/model.py +++ /dev/null @@ -1,475 +0,0 @@ -""" -code related to interfacing with the Neural Magic inference engine using python -""" - -from typing import List, Dict, Optional, Iterable -import os -import numpy -import importlib -import ctypes -import time -import warnings - -try: - from .cpu import cpu_details - from .version import * -except ImportError: - raise ImportError("Unable to import engine backend.") - - -__all__ = ["Model", "create_model", "analyze_model", "benchmark_model"] - - -CORES_PER_SOCKET, AVX_TYPE, VNNI = cpu_details() - - -def _import_ort_nm(): - try: - nm_package_dir = os.path.dirname(os.path.abspath(__file__)) - onnxruntime_neuralmagic_so_path = os.path.join( - nm_package_dir, AVX_TYPE, "neuralmagic_onnxruntime_engine.so" - ) - spec = importlib.util.spec_from_file_location( - "nmie.{}.neuralmagic_onnxruntime_engine".format(AVX_TYPE), - onnxruntime_neuralmagic_so_path, - ) - engine = importlib.util.module_from_spec(spec) - spec.loader.exec_module(engine) - - return engine - except ImportError: - raise ImportError("Unable to import engine backend.") - - -ENGINE_ORT = _import_ort_nm() - - -class Model(object): - """ - Python Model api for interfacing with the Neural Magic inference engine - """ - - def __init__(self, onnx_file_path: str, batch_size: int, num_cores: int): - """ - Create a new model from the given onnx file in the Neural Magic engine for performance on CPUs. - - Note 1: models are created for a specific batch size and for a specific number of CPU cores. 
- - Note 2: multi socket support is not yet built in to the inference engine, - all execution assumes single socket - - Example: - # create model for batch size 1 on all available cores - model = Model("path/to/onnx", batch_size=1, num_cores=-1) - - :param onnx_file_path: the local path to the onnx file for the neural network definition to instantiate - :param batch_size: the batch size of the inputs to be used with the model - :param num_cores: the number of physical cores to run the model on, - pass -1 to run on the max number of cores in one socket for current machine - """ - - if num_cores == -1: - num_cores = CORES_PER_SOCKET - - if num_cores < 1: - raise ValueError( - "num_cores must be greater than 0: given {}".format(num_cores) - ) - - if not os.path.exists(onnx_file_path): - raise ValueError( - "onnx_file_path must exist: given {}".format(onnx_file_path) - ) - - self._onnx_file_path = onnx_file_path - self._batch_size = batch_size - self._num_cores = num_cores - self._num_sockets = 1 # only single socket is supported currently - self._cpu_avx_type = AVX_TYPE - self._cpu_vnni = VNNI - self._eng_net = ENGINE_ORT.neuralmagic_onnxruntime_engine( - self._onnx_file_path, self._batch_size, self._num_cores, self._num_sockets - ) - - def __call__( - self, inp: List[numpy.ndarray], val_inp: bool = True, - ) -> List[numpy.ndarray]: - """ - Convenience function for model.forward(), see @forward for more details - - Example: - model = Model("path/to/onnx", batch_size=1, num_cores=-1) - inp = [numpy.random.rand(1, 3, 224, 224).astype(numpy.float32)] - out = model(inp) - assert isinstance(out, List) - - :param inp: the list of inputs to pass to the model for inference - :param val_inp: validate the input to the model to make sure the numpy array is setup correctly for the engine - :return: the list of outputs from the model after executing over the inputs - """ - return self.forward(inp, val_inp) - - def __repr__(self): - """ - :return: unambiguous representation of the current model instance - """ - return "{}({})".format(self.__class__, self._properties_dict()) - - def __str__(self): - """ - :return: human readable form of the current model instance - """ - - formatted_props = [ - "\t{}: {}".format(key, val) for key, val in self._properties_dict().items() - ] - - return "{}.{}:\n{}".format( - self.__class__.__module__, - self.__class__.__name__, - "\n".join(formatted_props), - ) - - @property - def onnx_file_path(self) -> str: - """ - :return: the local path to the onnx file the model was created from - """ - return self._onnx_file_path - - @property - def batch_size(self) -> int: - """ - :return: the batch size of the inputs to be used with the model - """ - return self._batch_size - - @property - def num_cores(self) -> int: - """ - :return: the number of physical cores to run the model on - """ - return self._num_cores - - @property - def num_sockets(self) -> int: - """ - :return: the number of sockets the engine is compiled to run on, only current support is 1 - """ - return self._num_sockets - - @property - def cpu_avx_type(self) -> str: - """ - :return: the detected cpu avx type that neural magic is running with. 
One of {avx2, avx512} - """ - return self._cpu_avx_type - - @property - def cpu_vnni(self) -> bool: - """ - :return: True if vnni support was detected on the cpu, False otherwise - """ - return self._cpu_vnni - - def forward( - self, inp: List[numpy.ndarray], val_inp: bool = True, - ) -> List[numpy.ndarray]: - """ - Execute a forward pass through the model for the given input (inference) - Returns the result as a list of numpy arrays corresponding to the outputs of the model as defined in the onnx. - - Note 1: the input dimensions must match what is given in the onnx. - To avoid extra time in memory shuffles, best use case is to format both the onnx and the input into - channels first format; ex: [batch, height, width, channels] => [batch, channels, height, width] - - Note 2: the input type for the numpy arrays must match what is given in the onnx. - Generally float32 is most common, but int8 and int16 are used for certain layer and input types. - - Note 3: the numpy arrays must be contiguous in memory, use numpy.ascontiguousarray(array) to fix if not - - Example: - model = Model("path/to/onnx", batch_size=1, num_cores=-1) - inp = [numpy.random.rand(1, 3, 224, 224).astype(numpy.float32)] - out = model.forward(inp) - assert isinstance(out, List) - - :param inp: the list of inputs to pass to the model for inference - :param val_inp: validate the input to the model to make sure the numpy array is setup correctly for the engine. - Generally should be left at the default = True - :return: the list of outputs from the model after executing over the inputs - """ - if val_inp: - self._validate_inputs(inp) - - return self._eng_net.execute_list_out(inp) - - def mapped_forward( - self, inp: List[numpy.ndarray], val_inp: bool = True, - ) -> Dict[str, numpy.ndarray]: - """ - Execute a forward pass through the model for the given input (inference) and return a dictionary result. - Each resulting tensor returned will be stored as a mapping in the dictionary. - The keys are strings equal to the name as defined in onnx, the values are the output arrays. - - Note 1: this function can add some a performance hit in certain cases (it involves an extra memory copy) - So, if using, please validate that you do not incur a performance hit by comparing with the regular forward func - - See @forward for more details on specific setup for the inputs. - - Example: - model = Model("path/to/onnx", batch_size=1, num_cores=-1) - inp = [numpy.random.rand(1, 3, 224, 224).astype(numpy.float32)] - out = model.mapped_forward(inp) - assert isinstance(out, Dict) - - :param inp: the list of inputs to pass to the model for inference - :param val_inp: validate the input to the model to make sure the numpy array is setup correctly for the engine. - Generally should be left at the default = True - :return: the dictionary of outputs from the model after executing over the inputs - """ - if val_inp: - self._validate_inputs(inp) - - out = self._eng_net.execute(inp) - - return out - - def benchmark_batched( - self, - batched_data: Iterable[List[numpy.ndarray]], - num_iterations: int = 20, - num_warmup_iterations: int = 5, - include_outputs: bool = False, - ) -> Dict[str, float]: - """ - :param batched_data: an iterator of input batches to be used for benchmarking. - :param num_iterations: Number of iterations to run benchmarking for. - Default is 20 - :param num_warmup_iterations: Number of iterations to warm up engine before - benchmarking. 
Default is 5 - :param include_outputs: if True, outputs from forward passes during benchmarking - will be returned under the 'outputs' key. Default is False - :return: Dictionary of benchmark results including keys batch_stats_ms, - batch_times_ms, and items_per_sec - """ - assert num_iterations >= 0 and num_warmup_iterations >= 0, ( - "num_iterations and num_warmup_iterations must be non negative for " - "benchmarking." - ) - completed_iterations = 0 - batch_times = [] - outputs = [] - while completed_iterations < num_warmup_iterations + num_iterations: - for batch in batched_data: - # run benchmark - batch_time = time.time() - output = self.forward(batch) - batch_time = time.time() - batch_time - # update results - batch_times.append(batch_time) - if include_outputs: - outputs.append(output) - # update loop - completed_iterations += 1 - if completed_iterations >= num_warmup_iterations + num_iterations: - break - batch_times = batch_times[num_warmup_iterations:] # remove warmup times - batch_times_ms = [batch_time * 1000 for batch_time in batch_times] - items_per_sec = self.batch_size / numpy.mean(batch_times).item() - - batch_stats_ms = { - "median": numpy.median(batch_times_ms), - "mean": numpy.mean(batch_times_ms), - "std": numpy.std(batch_times_ms), - } - - benchmark_dict = { - "batch_stats_ms": batch_stats_ms, - "batch_times_ms": batch_times_ms, - "items_per_sec": items_per_sec, - } - - if include_outputs: - benchmark_dict["outputs"] = outputs - - return benchmark_dict - - def benchmark( - self, - data: Iterable[List[numpy.ndarray]], - num_iterations: int = 20, - num_warmup_iterations: int = 5, - include_outputs: bool = False, - ) -> Dict[str, float]: - """ - :param data: an iterator of input data to be used for benchmarking. Should - be single data points with no batch dimension - :param num_iterations: Number of iterations to run benchmarking for. - Default is 20 - :param num_warmup_iterations: Number of iterations to warm up engine before - benchmarking. Default is 5 - :param include_outputs: if True, outputs from forward passes during benchmarking - will be returned under the 'outputs' key. Default is False - :return: Dictionary of benchmark results including keys batch_stats_ms, - batch_times_ms, and items_per_sec - """ - assert num_iterations >= 0 and num_warmup_iterations >= 0, ( - "num_iterations and num_warmup_iterations must be non negative for " - "benchmarking." 
- ) - - # define data loader - def infinite_data_batcher(): - batch = [] - while True: - for inputs in data: - batch.append(inputs) - if len(batch) == self.batch_size: - # concatenate batch inputs - batch_inputs = [] - for input_idx in range(len(inputs)): - batch_input = [batch_val[input_idx] for batch_val in batch] - batch_inputs.append(numpy.stack(batch_input)) - # yield and reset - yield batch_inputs - batch = [] - - return self.benchmark_batched( - batched_data=infinite_data_batcher(), - num_iterations=num_iterations, - num_warmup_iterations=num_warmup_iterations, - include_outputs=include_outputs, - ) - - def _validate_inputs(self, inp: List[numpy.ndarray]): - if isinstance(inp, str) or not isinstance(inp, List): - raise ValueError("inp must be a list, given {}".format(type(inp))) - - for arr in inp: - if arr.shape[0] != self._batch_size: - raise ValueError( - "array batch size of {} must match the batch size the model was instantiated with {}".format( - arr.shape[0], self._batch_size - ) - ) - - if not arr.flags["C_CONTIGUOUS"]: - raise ValueError( - "array must be passed in as C contiguous, call numpy.ascontiguousarray(array)" - ) - - def _properties_dict(self) -> Dict: - return { - "onnx_file_path": self._onnx_file_path, - "batch_size": self._batch_size, - "num_cores": self._num_cores, - "num_sockets": self._num_sockets, - "cpu_avx_type": self._cpu_avx_type, - "cpu_vnni": self._cpu_vnni, - } - - -def create_model( - onnx_file_path: str, batch_size: int = 1, num_cores: int = -1 -) -> Model: - """ - Convenience function to create a model in the Neural Magic engine from an onnx file for inference. - Gives defaults of batch_size == 1 and num_cores == -1 (will use all physical cores available on a single socket) - - :param onnx_file_path: the local path to the onnx file for the neural network definition to instantiate - :param batch_size: the batch size of the inputs to be used with the model, default is 1 - :param num_cores: the number of physical cores to run the model on, default is -1 (detect physical cores num) - :return: the created model - """ - return Model(onnx_file_path, batch_size, num_cores) - - -def analyze_model( - onnx_file_path: str, - input: List[numpy.ndarray], - batch_size: int = 1, - num_cores: int = -1, - num_iterations: int = 1, - num_warmup_iterations: int = 0, - optimization_level: int = 1, - imposed_as: Optional[float] = None, - imposed_ks: Optional[float] = None, -) -> dict: - """ - Function to analyze a model's performance using the Neural Magic engine from an onnx file for inference. 
- Gives defaults of batch_size == 1 and num_cores == -1 (will use all physical cores available on a single socket) - - :param onnx_file_path: the local path to the onnx file for the neural network definition to instantiate - :param input: the list of inputs to pass to the model for benchmarking - :param batch_size: the batch size of the inputs to be used with the model, default is 1 - :param num_cores: the number of physical cores to run the model on, default is -1 (detect physical cores num) - :param num_iterations: number of times to repeat execution, default is 1 - :param num_warmup_iterations: number of times to repeat unrecorded before starting actual benchmarking iterations - :param optimization_level: how much optimization to perform?, default is 1 - :param imposed_as: imposed activation sparsity, defaults to None - :param imposed_ks: imposed kernel sparsity, defaults to None - - :return: the analysis structure containing the performance details of each layer - """ - if num_cores == -1: - num_cores = CORES_PER_SOCKET - num_sockets = 1 # only single socket is supported currently - - eng_net = ENGINE_ORT.neuralmagic_onnxruntime_engine( - onnx_file_path, batch_size, num_cores, num_sockets - ) - - return eng_net.benchmark( - input, - num_iterations, - num_warmup_iterations, - optimization_level, - imposed_as, - imposed_ks, - ) - - -def benchmark_model( - onnx_file_path: str, - input: List[numpy.ndarray], - batch_size: int = 1, - num_cores: int = -1, - num_iterations: int = 1, - num_warmup_iterations: int = 0, - optimization_level: int = 1, - imposed_as: Optional[float] = None, - imposed_ks: Optional[float] = None, -) -> dict: - """ - DEPRECATED: Use nmie.analyze_model instead - Function to analyze a model's performance using the Neural Magic engine from an onnx file for inference. - Gives defaults of batch_size == 1 and num_cores == -1 (will use all physical cores available on a single socket) - - :param onnx_file_path: the local path to the onnx file for the neural network definition to instantiate - :param input: the list of inputs to pass to the model for benchmarking - :param batch_size: the batch size of the inputs to be used with the model, default is 1 - :param num_cores: the number of physical cores to run the model on, default is -1 (detect physical cores num) - :param num_iterations: number of times to repeat execution, default is 1 - :param num_warmup_iterations: number of times to repeat unrecorded before starting actual benchmarking iterations - :param optimization_level: how much optimization to perform?, default is 1 - :param imposed_as: imposed activation sparsity, defaults to None - :param imposed_ks: imposed kernel sparsity, defaults to None - - :return: the analysis structure containing the performance details of each layer - """ - warnings.warn( - "Use of nmie.benchmark_model is deprecated. " - "Use nmie.analyze_model instead." - ) - return analyze_model( - onnx_file_path, - input, - batch_size, - num_cores, - num_iterations, - num_warmup_iterations, - optimization_level, - imposed_as, - imposed_ks, - )
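
As a quick orientation for reviewers, below is a minimal usage sketch of the renamed API added in this patch (compile_model, Engine.run, Engine.timed_run, Engine.benchmark). The ONNX path and the (3, 224, 224) input shape are illustrative placeholders, not files or models included in this change.

import numpy

from deepsparse import compile_model

# compile the model for batch size 1 on all physical cores of a single socket
# "path/to/model.onnx" is a placeholder path
engine = compile_model("path/to/model.onnx", batch_size=1, num_cores=-1)

# inputs follow the input order defined in the ONNX graph and must be C-contiguous
inp = [
    numpy.ascontiguousarray(
        numpy.random.rand(1, 3, 224, 224).astype(numpy.float32)
    )
]

# single inference run returning a list of output arrays
out = engine.run(inp)

# timed run returning (outputs, seconds)
out, seconds = engine.timed_run(inp)

# benchmark over unbatched items; items are stacked into batches of batch_size
# results include batch_stats_ms, batch_times_ms, and items_per_sec
data = [[numpy.random.rand(3, 224, 224).astype(numpy.float32)] for _ in range(8)]
results = engine.benchmark(data, num_iterations=20, num_warmup_iterations=5)
print(results["batch_stats_ms"], results["items_per_sec"])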