Welcome to ONNX Rustime, a Rust-inspired ONNX runtime.
This project aims to provide a robust and efficient ONNX runtime experience with the power and safety of the Rust programming language.
- Features
- How to install & use
- Project Structure
- ONNX Rustime Core Structures
- ONNX Parser: Serialize and Deserialize models & data
- Running the ONNX Network with ONNX Rustime
- Operations in ONNX Rustime
- Automatic Data Preprocessing
- Python-Rust Binding with ONNX Rustime
- JavaScript-Rust Binding with ONNX Rustime
- Contribution
- Rust-inspired ONNX runtime: Experience the power and safety of Rust combined with the flexibility of ONNX.
- Robust parser for ONNX files & test data: Deserialize ONNX files and their associated test data with ease and accuracy.
- Run network inference post-parsing: Once your ONNX files are parsed, run network inference seamlessly.
- Scaffold for adding operations: You can easily extend this runtime with additional operations, so that more networks are supported.
- Demo-ready with multiple CNNs: The project comes with multiple convolutional neural networks ready for demonstration. Feel free to extend and experiment.
- Supports batching for simultaneous inferences: Run multiple inferences simultaneously with the batching feature, leveraging `rayon` for parallelization of batches.
- Serialize models & data with ease: Serialization made easy for both models and data.
- Seamless Python and JavaScript integration via Rust bindings: Integrate with Python and JavaScript effortlessly using the provided Rust bindings.
ONNX Rustime provides support for a set of pre-trained convolutional neural networks (CNNs) from the vision/classification section of the ONNX Model Zoo:
Model Name | Description |
---|---|
bvlcalexnet-12 | A variant of the renowned AlexNet architecture. |
caffenet-12 | Inspired by the AlexNet model, optimized for the Caffe framework. |
mnist-8 | Tailored for the MNIST dataset, specializing in handwritten digit recognition. |
resnet152-v2-7 | A deep model from the ResNet family with 152 layers. |
squeezenet1.0-12 | Lightweight model known for its efficiency while maintaining competitive accuracy. |
zfnet512-12 | An improved model based on the AlexNet architecture. |
All these models, along with their test data, are available for download from the ONNX Model Zoo. They can be easily integrated into ONNX Rustime for swift and accurate inferences.
The test data generally consists of two files:
- `input.pb`: the input of the network
- `output.pb`: the expected output of the network
For simplicity, the inputs and expected outputs of the provided example networks are hard-coded in `display.rs`, which manages the command-line interface. This can easily be extended to other inputs and expected outputs.
Clone this repository and run it:
```
git clone https://github.com/GLorenzo679/PDS-Project.git
cd <your root folder>
cargo run
```
You should be greeted with an interactive menu. After selecting one of the supported networks, you will be shown a series of sub-menus:
- First, you will be asked whether to save the output data.
- If you choose to save it, you can then choose a path for the output. The output data will be saved as a serialized `.pb` file, the same format as the test data provided with the models.

Finally, you will be asked whether to run the network in verbose mode. Compare a non-verbose execution with a verbose execution.
To set up the Python environment for the Rust bindings, follow these steps:

- Create a virtual environment from the root folder of the project:
  ```
  cd <your root folder>
  python3 -m venv ./py_onnx_rustime/rust-binding/
  ```
- Activate the virtual environment:
  - For Linux and macOS:
    ```
    source py_onnx_rustime/rust-binding/bin/activate
    ```
  - For Windows:
    ```
    .\py_onnx_rustime\rust-binding\Scripts\activate
    ```
- Install the `maturin` package:
  ```
  pip install maturin
  ```
- Build and install the Rust bindings, so that they are accessible from Python code:
  ```
  maturin develop --release
  ```
- Run the test demo from your environment. Make sure you are in the root folder of the project: both binding demos use hard-coded paths, so running them from a different folder results in an error.
  ```
  python3 ./py_onnx_rustime/test_onnx_rustime.py
  ```

You will be greeted with a barebones interactive menu, similar to the previous one. After selecting the options, Python will call the Rust-powered functions to load the chosen model and data and to run the model.
To set up the Node.js environment for the Rust bindings, follow these steps:

- Install all the Node module dependencies:
  ```
  cd <your root folder>/js_onnx_rustime/
  npm i
  ```
- Run the test demo from your environment. Make sure you are back in the root folder of the project: both binding demos use hard-coded paths, so running them from a different folder results in an error.
  ```
  cd <your root folder>
  node ./js_onnx_rustime/test_onnx_rustime.js
  ```

You will be greeted with a barebones interactive menu, similar to the previous one.
That's all! Now we will dive into the technical details of the implementation. The following sections could be very useful if you want to contribute to the project or simply understand it.
The following outlines the structure of the ONNX Rustime project. We will delve into the specifics of each part, providing insights into their purpose and how they integrate into the larger framework.
```
.
├── js_onnx_rustime/          // JavaScript demo project
│   ├── package.json
│   └── test_onnx-rustime.js
├── models/                   // pre-trained models with test data
│   ├── bvlcalexnet-12/
│   ├── caffenet-12/
│   ├── mnist-8/
│   ├── resnet18-v2-7/
│   ├── resnet152-v2-7/
│   ├── squeezenet1.0-12/
│   └── zfnet512-12/
├── py_onnx_rustime/          // Python demo project
│   ├── rust-binding/
│   ├── onnx_rustime_lib.pyi
│   └── test_onnx_rustime.py
├── screenshots/
├── src/
│   ├── onnx_rustime/
│   │   ├── backend/          // helper functions, parser, preprocessing, runtime functionalities
│   │   ├── onnx_proto/       // ONNX data structures
│   │   ├── ops/              // supported operations
│   │   ├── mod.rs
│   │   └── shared.rs         // global variable for verbose running
│   ├── display.rs            // display & menu functionalities
│   ├── lib.rs
│   └── main.rs
├── third_party/
│   └── onnx/                 // ONNX .proto files (more on this later!)
├── build.rs
├── Cargo.lock
├── Cargo.toml
├── LICENSE
└── README.md
```
In ONNX Rustime, several Rust structures play pivotal roles in the operation of the runtime. They are defined in `src/onnx_rustime/onnx_proto/onnx_ml_proto3.rs` and provide a layer of abstraction over the underlying data. Here are the most commonly used:
Structure | Description |
---|---|
ModelProto | Represents an entire ONNX model. It contains metadata about the model (like version and domain) and the actual data, which is the GraphProto. |
GraphProto | Describes a computational graph. It contains nodes, inputs, outputs, and initializers (among other things). This is where the actual computation logic of the model is defined. |
TensorProto | Represents a tensor value. It can be an input, output, or an initializer. Contains data type, shape, and actual tensor values. |
NodeProto | Represents a node in the computational graph. Each node has an operator (like add, multiply), inputs, and outputs. |
ValueInfoProto | Contains metadata about a tensor, such as its name, data type, and shape, but not the actual tensor values. |
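To make the relationships in the table concrete, here is a heavily simplified sketch. The structs below are hypothetical stand-ins that keep only a handful of fields of the generated types (the real generated code has many more fields and its own accessors); the two functions show how a runtime can walk model, graph and nodes, and resolve initializers by name.

```rust
use std::collections::HashMap;

// Hypothetical, heavily simplified stand-ins for the generated protobuf types.
#[derive(Debug, Clone)]
pub struct TensorProto {
    pub name: String,
    pub dims: Vec<i64>,
    pub float_data: Vec<f32>,
}

#[derive(Debug, Clone)]
pub struct NodeProto {
    pub op_type: String,     // e.g. "Conv", "Relu"
    pub input: Vec<String>,  // names of the tensors consumed by this node
    pub output: Vec<String>, // names of the tensors produced by this node
}

#[derive(Debug, Clone)]
pub struct GraphProto {
    pub node: Vec<NodeProto>,
    pub initializer: Vec<TensorProto>, // trained weights, referenced by name
}

#[derive(Debug, Clone)]
pub struct ModelProto {
    pub ir_version: i64,
    pub graph: GraphProto,
}

/// Lists the operator type of every node, in the order the graph defines them.
pub fn op_types(model: &ModelProto) -> Vec<&str> {
    model.graph.node.iter().map(|n| n.op_type.as_str()).collect()
}

/// Indexes the graph's initializers by name, so a node input can be resolved to weights.
pub fn initializer_map(graph: &GraphProto) -> HashMap<&str, &TensorProto> {
    graph.initializer.iter().map(|t| (t.name.as_str(), t)).collect()
}
```

A node input that is not found among the initializers is, in this simplified picture, either the graph input or the output of an earlier node.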
These Rust structures are generated from Protocol Buffers (often abbreviated as "protobuf") definitions specific to the ONNX framework. Protocol Buffers, developed by Google, is a method to serialize structured data, akin to XML or JSON.
The protobuf files are stored in `third_party/onnx`. They come in both `.proto` and `.proto3` format; the `.proto3` files use proto3, the more recent version of the Protocol Buffers syntax.
Note: The `onnx.proto` files are sourced from the official ONNX repository, where they can also be downloaded.
Code Generation with build.rs
The Rust structures have been generated using the `build.rs` script, a customary build script in Rust projects. Here's a brief overview of the process:

- Setting Up: The script initializes the environment and verifies the presence of `protoc` (the Protocol Buffers compiler) on the system.
- Defining Source Directory: The `src_dir` macro specifies the directory containing the `.proto` and `.proto3` files.
- Code Generation: Upon detecting `protoc`, the script uses the `protoc_rust` crate to produce Rust code from the designated `.proto3` files. The generated Rust code is stored in the `src/onnx_rustime/onnx_proto` directory. The primary files for this process are:
  - `onnx-ml.proto3`
  - `onnx-operators-ml.proto3`
  - `onnx-data.proto3`
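The "detect `protoc`, then generate" flow described above can be sketched with the standard library alone. This is not the project's `build.rs`: the function names are hypothetical, and the actual code-generation call to `protoc_rust` is deliberately left out.

```rust
use std::path::{Path, PathBuf};
use std::process::Command;

/// True if a `protoc` binary found on PATH runs successfully.
pub fn protoc_available() -> bool {
    Command::new("protoc")
        .arg("--version")
        .output()
        .map(|out| out.status.success())
        .unwrap_or(false)
}

/// The three ONNX schema files listed above, resolved against `src_dir`.
pub fn proto_inputs(src_dir: &Path) -> Vec<PathBuf> {
    ["onnx-ml.proto3", "onnx-operators-ml.proto3", "onnx-data.proto3"]
        .iter()
        .map(|file| src_dir.join(file))
        .collect()
}

/// Tells Cargo to re-run the build script only when one of the inputs changes.
pub fn emit_rerun_hints(inputs: &[PathBuf]) {
    for path in inputs {
        println!("cargo:rerun-if-changed={}", path.display());
    }
}
```

In a build script, `main` would call `protoc_available()` and only then hand `proto_inputs(...)` to the code generator; when `protoc` is missing, the already generated files in `src/onnx_rustime/onnx_proto` are used as they are.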
Note on Compilation
The generated Rust files are already included in the repository, so you don't need to have `protoc` installed for general use. However, if you wish to generate them yourself (for example, to integrate a newer version of the ONNX protobuf definitions), you can:

- Download `protoc` for your system from the official Protocol Buffers GitHub repository.
- Obtain the desired `.proto` files from the official ONNX GitHub repository.
- Modify the paths in the `build.rs` script to point to your new `.proto` files.
- Run `cargo build` to compile.
This flexibility ensures that you can always stay updated with the latest developments in the Protocol Buffers ecosystem.
The ONNX Parser in ONNX Rustime, defined in `src/onnx_rustime/backend/parser.rs`, provides functionality to interpret and handle ONNX data structures. It relies on the Protocol Buffers format to serialize and deserialize ONNX models and data. The `OnnxParser` struct encapsulates the following functionality:

- Model Loading: `load_model` loads an ONNX model from the specified path. It reads the file, deserializes it, and returns a `ModelProto` structure representing the entire ONNX model.
  ```rust
  pub fn load_model(path: String) -> Result<ModelProto, OnnxError>
  ```
- Data Loading: `load_data` loads tensor data from the given path. It returns a `TensorProto` structure, which represents a tensor value in ONNX.
  ```rust
  pub fn load_data(path: String) -> Result<TensorProto, OnnxError>
  ```
- Model and Data Saving: Although not primarily used, the parser can also save a given `ModelProto` or `TensorProto` back to a file. This is useful when modifications to the model or data need to be persisted.
  ```rust
  pub fn save_model(model: &ModelProto, path: String) -> Result<(), OnnxError>
  pub fn save_data(tensor: &TensorProto, path: String) -> Result<(), OnnxError>
  ```
- Data Parsing: The parser can convert raw byte data into 32-bit floating-point values (`parse_raw_data_as_floats`) and 64-bit integers (`parse_raw_data_as_ints64`). This is essential for interpreting tensor data stored as raw bytes inside a `TensorProto`.
  ```rust
  pub fn parse_raw_data_as_floats(raw_data: &[u8]) -> Vec<f32>
  pub fn parse_raw_data_as_ints64(raw_data: &[u8]) -> Vec<i64>
  ```
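The ONNX format stores `raw_data` as fixed-width, little-endian values, so the two parsing helpers can be written with the standard library alone. The following is a minimal sketch with the same signatures, not the project's implementation; silently ignoring trailing bytes that do not form a whole value is a choice made for this sketch.

```rust
/// Interprets `raw_data` as consecutive little-endian f32 values.
/// Trailing bytes that do not form a full 4-byte value are ignored.
pub fn parse_raw_data_as_floats(raw_data: &[u8]) -> Vec<f32> {
    raw_data
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}

/// Interprets `raw_data` as consecutive little-endian i64 values.
/// Trailing bytes that do not form a full 8-byte value are ignored.
pub fn parse_raw_data_as_ints64(raw_data: &[u8]) -> Vec<i64> {
    raw_data
        .chunks_exact(8)
        .map(|c| {
            let mut bytes = [0u8; 8];
            bytes.copy_from_slice(c);
            i64::from_le_bytes(bytes)
        })
        .collect()
}
```

Using `from_le_bytes` explicitly keeps the result correct on big-endian hosts too, instead of reinterpreting the buffer in native byte order.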
ONNX Rustime offers a streamlined way to execute ONNX models. Here's a breakdown of the core functions that facilitate this process:
The `run` function is the primary entry point for executing an ONNX model. It processes the nodes of the model's graph in the order they are defined, manages initializers, and ensures the output of each node is routed as input to the subsequent nodes.
```rust
pub fn run(
    model: &ModelProto,
    input_tensor: TensorProto
) -> TensorProto;
```
- It begins by extracting the graph from the provided model.
- Each node in the graph is executed in sequence through the `run_node` function.
- A progress bar provides a visual representation of the node execution process.
- The function concludes by returning the output tensor of the entire model.
The `run_node` function executes a specific node within the ONNX graph. It identifies the node's operation type and then invokes the corresponding execution function, supplying the necessary inputs and initializers.
```rust
fn run_node(
    node: &NodeProto,
    inputs: &Vec<&TensorProto>,
    initializers: &Vec<&TensorProto>,
) -> Result<TensorProto, OnnxError>;
```
- The node's operation type is mapped to its execution function.
- The node is executed using the provided inputs and initializers.
- If the operation type isn't recognized, an error is returned.
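The interplay between `run` and `run_node` can be sketched as follows. This is a simplified illustration under stated assumptions: `Tensor`, `Node` and the `OnnxError` variants are hypothetical stand-ins for the generated types, only `Relu` and `Add` (without broadcasting) are dispatched, and the progress bar and batching of the real `run` are omitted. Unlike the real `run`, the sketch returns a `Result`.

```rust
use std::collections::HashMap;

// Hypothetical, simplified stand-ins for TensorProto, NodeProto and OnnxError.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    pub data: Vec<f32>,
}

pub struct Node {
    pub op_type: String,
    pub inputs: Vec<String>, // names of the tensors this node consumes
    pub output: String,      // name under which its result is published
}

#[derive(Debug, PartialEq)]
pub enum OnnxError {
    UnsupportedOperation(String),
    InvalidInputs(String),
    MissingInput(String),
}

/// Maps a node's op_type to the code implementing it (two ops only, no broadcasting).
fn run_node(node: &Node, inputs: &[&Tensor]) -> Result<Tensor, OnnxError> {
    match (node.op_type.as_str(), inputs) {
        ("Relu", [x]) => Ok(Tensor {
            data: x.data.iter().map(|&v| v.max(0.0)).collect(),
        }),
        ("Add", [a, b]) if a.data.len() == b.data.len() => Ok(Tensor {
            data: a.data.iter().zip(&b.data).map(|(p, q)| p + q).collect(),
        }),
        ("Relu", _) | ("Add", _) => Err(OnnxError::InvalidInputs(node.op_type.clone())),
        (other, _) => Err(OnnxError::UnsupportedOperation(other.to_string())),
    }
}

/// Runs the nodes in definition order. Every result is stored under its output
/// name, so later nodes find it exactly like an initializer or the graph input.
pub fn run(
    nodes: &[Node],
    initializers: HashMap<String, Tensor>,
    input_name: &str,
    input: Tensor,
    output_name: &str,
) -> Result<Tensor, OnnxError> {
    let mut values = initializers;
    values.insert(input_name.to_string(), input);
    for node in nodes {
        let result = {
            // Resolve the node's inputs by name among the tensors available so far.
            let mut args = Vec::with_capacity(node.inputs.len());
            for name in &node.inputs {
                let t = values
                    .get(name)
                    .ok_or_else(|| OnnxError::MissingInput(name.clone()))?;
                args.push(t);
            }
            run_node(node, &args)?
        };
        values.insert(node.output.clone(), result);
    }
    values
        .remove(output_name)
        .ok_or_else(|| OnnxError::MissingInput(output_name.to_string()))
}
```

Keying every tensor by name is what lets the output of one node become the input of the next without the nodes knowing about each other.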
To execute a network, load your ONNX model and input tensor, then call the `run` function. Make sure your model and input tensor are compatible and that all of the model's operations are implemented.
```rust
let model = OnnxParser::load_model(model_path).unwrap();
let input = OnnxParser::load_data(input_path).unwrap();
// Loaded so the prediction can be compared against the reference output
let expected_output = OnnxParser::load_data(output_path).unwrap();

// Run the model
let predicted_output = run(&model, input);
```
In ONNX Rustime, operations are the backbone of neural network execution. They are housed in the `src/onnx_rustime/ops` directory. At the heart of these operations is the `ndarray` crate, a versatile library for array computations in Rust.
When an operation is invoked, `TensorProto` objects are provided as inputs. These objects are converted into `ndarray` arrays for efficient computation inside the operation function. This conversion is handled by utility functions such as:
```rust
tensor_proto_to_ndarray<T: TensorType>(
    tensor: &TensorProto
) -> Result<ArrayD<T::DataType>, OnnxError>

convert_to_output_tensor(
    node: &NodeProto,
    result: ArrayD<f32>
) -> Result<TensorProto, OnnxError>
```
- Before the operation, `tensor_proto_to_ndarray` takes a `TensorProto` object and converts it into an `ndarray` array of the specified type.
- After the operation, `convert_to_output_tensor` converts the computed `ndarray` array back into a `TensorProto` object, ensuring consistency with the ONNX standard. The node is needed to extract the output name, which subsequent operations use to identify their inputs.

All operations and utility functions return a `Result`, which holds either the desired output or an `OnnxError`. This specialized error type, defined in `src/onnx_rustime/backend/helper.rs`, provides detailed error messages tailored to the various scenarios that may arise while executing ONNX operations.
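A key check in any `TensorProto`-to-array conversion is that the declared `dims` match the amount of data. The sketch below illustrates that validation together with a `Display`-able error type. To stay dependency-free, the hypothetical `DenseArray` stands in for `ndarray`'s `ArrayD<f32>`, and the `OnnxError` variants shown are assumptions, not the project's actual variants.

```rust
use std::fmt;

// Hypothetical error variants, for illustration only.
#[derive(Debug, PartialEq)]
pub enum OnnxError {
    ShapeMismatch { expected: usize, found: usize },
    InvalidDimension(i64),
}

impl fmt::Display for OnnxError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            OnnxError::ShapeMismatch { expected, found } => {
                write!(f, "shape requires {} elements, tensor holds {}", expected, found)
            }
            OnnxError::InvalidDimension(d) => write!(f, "invalid dimension {}", d),
        }
    }
}

/// Dense row-major array standing in for ndarray's ArrayD<f32>.
#[derive(Debug, PartialEq)]
pub struct DenseArray {
    pub shape: Vec<usize>,
    pub data: Vec<f32>,
}

/// Validates ONNX `dims` against the flat data before building an array.
pub fn to_dense(dims: &[i64], data: Vec<f32>) -> Result<DenseArray, OnnxError> {
    let mut shape = Vec::with_capacity(dims.len());
    for &d in dims {
        if d < 0 {
            return Err(OnnxError::InvalidDimension(d));
        }
        shape.push(d as usize);
    }
    // An empty `dims` describes a scalar: the product of no factors is 1.
    let expected: usize = shape.iter().product();
    if expected != data.len() {
        return Err(OnnxError::ShapeMismatch { expected, found: data.len() });
    }
    Ok(DenseArray { shape, data })
}
```

Returning the error instead of panicking lets `run_node` propagate it with `?`, which is the pattern the `Result`-based signatures above are built around.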
Add: Element-wise tensor addition.
```rust
pub fn add(
    inputs: &Vec<&TensorProto>,
    initializers: Option<&Vec<&TensorProto>>,
    node: &NodeProto,
) -> Result<TensorProto, OnnxError>;
```
Batch Normalization: Normalizes the activations of a given input volume.
```rust
pub fn batch_normalization(
    input: &TensorProto,
    initializers: &Vec<&TensorProto>,
    node: &NodeProto,
) -> Result<TensorProto, OnnxError>;
```
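For reference, this is the inference-mode formula from the ONNX BatchNormalization specification, applied independently to each channel $c$. Here $\gamma$ is the scale, $\beta$ the bias `B`, and $\mu$, $\sigma^2$ the running mean and variance, all typically stored as initializers; $\varepsilon$ defaults to `1e-5` in the specification. This is the standard definition, not an excerpt of the project's code.

$$
Y_{n,c,h,w} = \gamma_c \,\frac{X_{n,c,h,w} - \mu_c}{\sqrt{\sigma_c^{2} + \varepsilon}} + \beta_c
$$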
Concat: Concatenates tensors along a specified axis.
```rust
pub fn concat(
    inputs: &Vec<&TensorProto>,
    node: &NodeProto,
) -> Result<TensorProto, OnnxError>;
```
Conv: The fundamental convolution operation of CNNs.
We adapted the original implementation of convolution-rs to support batched convolution, grouped convolution, and dilation.
```rust
pub fn conv(
    input: &TensorProto,
    initializers: &Vec<&TensorProto>,
    node: &NodeProto,
) -> Result<TensorProto, OnnxError>;
```
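The output spatial size follows the standard convolution arithmetic used by the ONNX Conv operator with explicit `pads` (no `auto_pad`). For each spatial axis with input size $H_{\text{in}}$, kernel size $k$, stride $s$, dilation $d$ and padding $p_{\text{begin}}$, $p_{\text{end}}$:

$$
H_{\text{out}} = \left\lfloor \frac{H_{\text{in}} + p_{\text{begin}} + p_{\text{end}} - d\,(k - 1) - 1}{s} \right\rfloor + 1
$$

For example, a 224×224 input with a 7×7 kernel, stride 2, dilation 1 and padding 3 on each side gives floor((224 + 3 + 3 - 6 - 1) / 2) + 1 = floor(223 / 2) + 1 = 112. With `group` = g, the C input channels are split into g groups and the weight tensor has shape M × (C/g) × kH × kW, where M is the number of output channels.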
Dropout: Regularization technique in which randomly selected neurons are ignored during training; at inference time it passes its input through unchanged.
```rust
pub fn dropout(
    input: &TensorProto,
    initializers: Option<&Vec<&TensorProto>>,
    node: &NodeProto,
) -> Result<TensorProto, OnnxError>;
```
Exp: Computes the exponential of the given input tensor.
```rust
pub fn exp(
    input: &TensorProto,
    node: &NodeProto,
) -> Result<TensorProto, OnnxError>;
```
Flatten: Flattens the input tensor into a 2D matrix.
```rust
pub fn flatten(
    input: &TensorProto,
    node: &NodeProto,
) -> Result<TensorProto, OnnxError>;
```
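Flatten only changes the shape, not the order of the data. Following the ONNX definition, the dimensions before `axis` are multiplied into the rows and the remaining ones into the columns (the default `axis` is 1, and negative values count from the end since opset 11). The shape computation can be sketched as follows; `flatten_shape` is a hypothetical helper, not part of the project.

```rust
/// Output shape (rows, cols) of ONNX Flatten, or None if `axis` is out of range.
/// Valid axes are -rank..=rank; a negative axis is interpreted as axis + rank.
pub fn flatten_shape(shape: &[usize], axis: i64) -> Option<(usize, usize)> {
    let rank = shape.len() as i64;
    let axis = if axis < 0 { axis + rank } else { axis };
    if axis < 0 || axis > rank {
        return None;
    }
    let (outer, inner) = shape.split_at(axis as usize);
    // An empty side contributes a dimension of 1 (empty product).
    Some((outer.iter().product(), inner.iter().product()))
}
```

For instance, a 1×256×6×6 feature map flattened with the default axis becomes a 1×9216 matrix, ready for a fully connected layer.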
Gemm: General Matrix Multiplication. Computes a matrix multiplication, optionally followed by an addition.
```rust
pub fn gemm(
    inputs: &Vec<&TensorProto>,
    initializers: Option<&Vec<&TensorProto>>,
    node: &NodeProto,
) -> Result<TensorProto, OnnxError>;
```
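The ONNX Gemm specification defines the computation below, where `alpha` and `beta` are attributes (both default to 1.0), `transA`/`transB` select whether A and B are transposed first, and C is broadcast to the shape of the product:

$$
Y = \alpha\, A' B' + \beta\, C,\qquad
A' = \begin{cases} A^{\top} & \text{if } \mathrm{transA} = 1 \\ A & \text{otherwise} \end{cases},\qquad
B' = \begin{cases} B^{\top} & \text{if } \mathrm{transB} = 1 \\ B & \text{otherwise} \end{cases}
$$

With $A'$ of shape $M \times K$ and $B'$ of shape $K \times N$, the output $Y$ has shape $M \times N$.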
Global Average Pool: Averages each channel of the input tensor over all of its spatial dimensions, producing one value per channel.
```rust
pub fn global_average_pool(
    input: &TensorProto,
    node: &NodeProto,
) -> Result<TensorProto, OnnxError>;
```
LRN: Local Response Normalization, which normalizes each value using the values of neighboring channels.
```rust
pub fn lrn(
    input: &TensorProto,
    node: &NodeProto,
) -> Result<TensorProto, OnnxError>;
```
MatMul: Matrix multiplication operation.
```rust
pub fn matmul(
    inputs: &Vec<&TensorProto>,
    initializers: &Vec<&TensorProto>,
    node: &NodeProto,
) -> Result<TensorProto, OnnxError>;
```
MaxPool: Down-samples an input representation using max pooling.
```rust
pub fn maxpool(
    input: &TensorProto,
    node: &NodeProto,
) -> Result<TensorProto, OnnxError>;
```
Reduce Sum: Computes the sum of the input tensor's elements along all axes or along selected axes.
```rust
pub fn reduce_sum(
    input: &TensorProto,
    node: &NodeProto,
) -> Result<TensorProto, OnnxError>;
```
ReLU: Rectified Linear Unit activation function.
```rust
pub fn relu(
    input: &TensorProto,
    node: &NodeProto,
) -> Result<TensorProto, OnnxError>;
```
Reshape: Reshapes the input tensor to a new shape.
```rust
pub fn reshape(
    input: Option<&TensorProto>,
    initializers: &Vec<&TensorProto>,
    node: &NodeProto,
) -> Result<TensorProto, OnnxError>;
```
Softmax: Computes the softmax activations for the input tensor.
```rust
pub fn softmax(
    input: &TensorProto,
    node: &NodeProto,
) -> Result<TensorProto, OnnxError>;
```
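Softmax is usually computed in its numerically stable form: subtracting the row maximum before exponentiating avoids overflow without changing the result. The sketch below applies it along the last axis of a row-major buffer, which matches the ONNX default for opset 13 (`axis = -1`); earlier opsets default to `axis = 1` after coercing the input to 2D, which gives the same per-row result for an N × classes classifier output. The function names are hypothetical.

```rust
/// Numerically stable softmax over one row of values.
pub fn softmax_row(row: &[f32]) -> Vec<f32> {
    // Subtracting the maximum keeps exp() from overflowing; the ratios are unchanged.
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Applies softmax to each contiguous row of length `row_len` (the last axis).
/// `row_len` must be non-zero.
pub fn softmax_last_axis(data: &[f32], row_len: usize) -> Vec<f32> {
    data.chunks(row_len).flat_map(softmax_row).collect()
}
```

Without the max subtraction, logits around 1000 would overflow `f32` to infinity and produce NaN probabilities.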
For developers keen on extending ONNX Rustime's capabilities, adding new operations is straightforward:

- Implement the operation: Create a new file inside the `src/onnx_rustime/ops` directory and define your operation. Make sure to use the utility functions for consistent tensor conversions.
- Integrate it with the execution flow: Modify the `run_node` function to recognize and execute your newly added operation.
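The two steps can be sketched with a hypothetical `Sigmoid` operation (not currently in the supported list). The `Tensor` and `OnnxError` types are simplified stand-ins for `TensorProto` and the project's error type, and the module name and file path in the comments are assumptions for illustration only.

```rust
// Hypothetical, simplified stand-ins for TensorProto and OnnxError.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    pub dims: Vec<usize>,
    pub data: Vec<f32>,
}

#[derive(Debug, PartialEq)]
pub enum OnnxError {
    UnsupportedOperation(String),
    MissingInput(String),
}

// Step 1: the new operation, e.g. in a hypothetical src/onnx_rustime/ops/sigmoid.rs.
pub mod sigmoid {
    use super::{OnnxError, Tensor};

    /// Element-wise logistic function 1 / (1 + e^-x); the shape is preserved.
    pub fn sigmoid(input: Option<&Tensor>) -> Result<Tensor, OnnxError> {
        let input = input.ok_or_else(|| OnnxError::MissingInput("Sigmoid".to_string()))?;
        let data = input.data.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect();
        Ok(Tensor { dims: input.dims.clone(), data })
    }
}

// Step 2: teach the dispatcher about the new op_type.
pub fn run_node(op_type: &str, inputs: &[&Tensor]) -> Result<Tensor, OnnxError> {
    match op_type {
        "Sigmoid" => sigmoid::sigmoid(inputs.first().cloned()),
        other => Err(OnnxError::UnsupportedOperation(other.to_string())),
    }
}
```

Keeping each operation in its own module and touching the dispatcher only through a single match arm is what makes the scaffold easy to extend.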
This modular and developer-friendly design ensures that ONNX Rustime remains extensible, catering to evolving neural network architectures and operations.
ONNX Rustime simplifies data preprocessing for the ImageNet dataset.
You can select any image (in any common format) and produce a serialized `.pb` file, which represents the preprocessed image.
The corresponding code is in `src/onnx_rustime/backend/pre_processing.rs`.
You can then use this serialized `.pb` file as the input of your network to test it. Currently, this preprocessing step works best for the ResNet models.
Our preprocessing pipeline for the ImageNet dataset involves the following steps:
- Rescaling: We begin by rescaling the image to a size of 256 pixels while maintaining the original image's proportions.
- Center Cropping: Next, we perform center cropping, reducing the image size to a standardized 224 by 224 dimensions.
- Data Type Conversion: Following cropping, we convert the pixel values to a floating-point data type and normalize them.
- Channel Separation: The RGB image is split into three separate channels to create a tensor of shape 3x224x224.
- Protobuf Conversion: Finally, we convert the preprocessed image into a protobuf file. This file can be used as input for further executions without requiring additional preprocessing steps.
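The geometry and layout steps of this pipeline can be sketched on already decoded pixels. Two assumptions are made here that the description above does not pin down: "rescaling to 256 pixels" is read as scaling the shorter side to 256, and normalization uses the commonly cited ImageNet channel mean and standard deviation. Image decoding, the actual resampling, and the protobuf step are omitted, and all names are hypothetical.

```rust
/// Scales (width, height) so the shorter side equals `short_side`, keeping the aspect ratio.
pub fn resize_dims(width: u32, height: u32, short_side: u32) -> (u32, u32) {
    let scale = short_side as f64 / width.min(height) as f64;
    ((width as f64 * scale).round() as u32, (height as f64 * scale).round() as u32)
}

/// Top-left corner of a centered `crop` x `crop` window (expects both sides >= crop).
pub fn center_crop_origin(width: u32, height: u32, crop: u32) -> (u32, u32) {
    ((width - crop) / 2, (height - crop) / 2)
}

// ASSUMPTION: the commonly used ImageNet channel statistics.
const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const STD: [f32; 3] = [0.229, 0.224, 0.225];

/// Converts interleaved RGB pixels (row-major, H x W) into a normalized
/// planar tensor laid out as 3 x H x W (all R values, then G, then B).
pub fn to_chw_normalized(rgb: &[[u8; 3]], width: usize, height: usize) -> Vec<f32> {
    assert_eq!(rgb.len(), width * height, "pixel count must match width * height");
    let plane = width * height;
    let mut out = vec![0.0f32; 3 * plane];
    for (i, px) in rgb.iter().enumerate() {
        for c in 0..3 {
            let v = px[c] as f32 / 255.0;
            out[c * plane + i] = (v - MEAN[c]) / STD[c];
        }
    }
    out
}
```

For a 640×480 photo, the resize step gives 341×256 and the 224×224 crop starts at (58, 16); running `to_chw_normalized` on the cropped pixels yields the 3×224×224 layout described above.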
This automatic data preprocessing enhances the efficiency and consistency of working with the ImageNet dataset, making it easy to feed real images to your ResNet model for inference.
Here is an example of the preprocessing execution:
The ONNX Rustime project provides a basic integration between Python and Rust, allowing users to harness the power of Rust's performance and safety while working within the Python ecosystem. This integration is achieved using the PyO3 library, which facilitates the creation of Python modules and native extensions in Rust.
The binding exposes several functions that allow Python users to interact with the ONNX runtime implemented in Rust. Key data structures from the ONNX specification, such as `ModelProto` and `TensorProto`, are wrapped in IDs. These IDs act as opaque handles, hiding the underlying Rust details from the Python side. This design keeps a clean separation between the two languages and hides the intricacies of the Rust implementation.
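The README does not show how IDs are mapped back to Rust values. One simple way to do it is a registry keyed by an incrementing counter, sketched below; `Registry` is a hypothetical name, and in a real binding an instance would typically live in a global guarded by a `Mutex` so that successive calls from Python or JavaScript can reach it.

```rust
use std::collections::HashMap;

/// Hands out opaque integer IDs for values that must stay on the Rust side.
/// Foreign code only ever sees the `usize`, never the value itself.
pub struct Registry<T> {
    next_id: usize,
    items: HashMap<usize, T>,
}

impl<T> Registry<T> {
    pub fn new() -> Self {
        Registry { next_id: 0, items: HashMap::new() }
    }

    /// Stores a value and returns the ID to hand to the caller.
    pub fn insert(&mut self, value: T) -> usize {
        let id = self.next_id;
        self.next_id += 1;
        self.items.insert(id, value);
        id
    }

    /// Looks a value up by ID; an unknown ID yields `None` instead of a dangling pointer.
    pub fn get(&self, id: usize) -> Option<&T> {
        self.items.get(&id)
    }
}
```

Because lookups go through the map, a stale or mistyped ID from the foreign side becomes a recoverable error rather than memory unsafety.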
Here's a brief overview of the Python functions generated by the binding:

- `py_load_model(path: str) -> int`
  - Loads an ONNX model from the provided path.
  - Returns an ID corresponding to the loaded `ModelProto`.
- `py_load_data(path: str) -> int`
  - Loads ONNX data (`TensorProto`) from the provided path.
  - Returns an ID corresponding to the loaded `TensorProto`.
- `py_print_data(data_id: int) -> None`
  - Prints the ONNX data (`TensorProto`) associated with the provided ID.
- `py_run(model_id: int, input_data_id: int, verbose: bool) -> int`
  - Runs the ONNX model with the provided input data.
  - Returns an ID corresponding to the output `TensorProto`.
- `py_display_outputs(predicted_data_id: int, expected_data_id: int) -> None`
  - Displays the predicted and expected outputs in a user-friendly format.
Within `lib.rs`, the Rust functions are exposed to Python through PyO3's `#[pymodule]` mechanism. The module is named `onnx_rustime_lib`, matching the name specified in the `[lib]` section of `Cargo.toml`.

Here's the module definition:
```rust
use pyo3::prelude::*;

#[pymodule]
fn onnx_rustime_lib(_py: Python, m: &PyModule) -> PyResult<()> {
    // Register each Rust function so it can be called from Python
    m.add_function(wrap_pyfunction!(py_load_model, m)?)?;
    m.add_function(wrap_pyfunction!(py_load_data, m)?)?;
    m.add_function(wrap_pyfunction!(py_print_data, m)?)?;
    m.add_function(wrap_pyfunction!(py_run, m)?)?;
    m.add_function(wrap_pyfunction!(py_display_outputs, m)?)?;
    Ok(())
}
```
When using this library in Python, import it as:
```python
import onnx_rustime_lib
```
This gives you direct access to the Rust functions from your Python environment.
To use the Python-Rust binding, first set up the Python environment as described in the How to install & use section. Once the environment is set up, you can call the provided Python functions directly to interact with the ONNX runtime.
For example, to load an ONNX model, run it with some input data, and compare the result with the expected output:
```python
import onnx_rustime_lib as onnx_rustime

model_id = onnx_rustime.py_load_model("path_to_model.onnx")
data_id = onnx_rustime.py_load_data("path_to_input_data.pb")
output_id = onnx_rustime.py_run(model_id, data_id, verbose=True)

# Compare the prediction with the expected output shipped with the test data
expected_id = onnx_rustime.py_load_data("path_to_expected_output.pb")
onnx_rustime.py_display_outputs(output_id, expected_id)
```
As discussed in the usage section, a test demo is provided in `./py_onnx_rustime/test_onnx_rustime.py`.
For more details on how Rust and Python integration works, and to dive deeper into the capabilities of PyO3, see the official PyO3 documentation.
The ONNX Rustime project provides a basic integration between JavaScript and Rust, allowing users to harness the power of Rust's performance and safety while working within the JavaScript ecosystem. This integration is achieved using the neon library, which facilitates the creation of JavaScript modules and native extensions in Rust.
The binding exposes several functions that allow JavaScript users to interact with the ONNX runtime implemented in Rust. Key data structures from the ONNX specification, such as `ModelProto` and `TensorProto`, are wrapped in IDs. These IDs act as opaque handles, hiding the underlying Rust details from the JavaScript side. This design keeps a clean separation between the two languages and hides the intricacies of the Rust implementation.
All function arguments are passed through a `FunctionContext`, and return values are wrapped in a `JsResult`.
Here's a brief overview of the JavaScript functions generated by the binding. For each one, the Rust function signature is listed first, followed by the corresponding exposed JavaScript function.
- Rust: `fn js_load_model(mut cx: FunctionContext) -> JsResult<JsNumber>`; JavaScript: `js_load_model(path: string) -> number`
  - Loads an ONNX model from the provided path.
  - Returns an ID corresponding to the loaded `ModelProto`.
- Rust: `fn js_load_data(mut cx: FunctionContext) -> JsResult<JsNumber>`; JavaScript: `js_load_data(path: string) -> number`
  - Loads ONNX data (`TensorProto`) from the provided path.
  - Returns an ID corresponding to the loaded `TensorProto`.
- Rust: `fn js_print_data(mut cx: FunctionContext) -> JsResult<JsUndefined>`; JavaScript: `js_print_data(data_id: number) -> undefined`
  - Prints the ONNX data (`TensorProto`) associated with the provided ID.
- Rust: `fn js_run(mut cx: FunctionContext) -> JsResult<JsNumber>`; JavaScript: `js_run(model_id: number, input_data_id: number, verbose: boolean) -> number`
  - Runs the ONNX model with the provided input data.
  - Returns an ID corresponding to the output `TensorProto`.
- Rust: `fn js_display_outputs(mut cx: FunctionContext) -> JsResult<JsUndefined>`; JavaScript: `js_display_outputs(predicted_data_id: number, expected_data_id: number) -> undefined`
  - Displays the predicted and expected outputs in a user-friendly format.
The module is named `onnx_rustime_lib`, matching the name specified in the `[lib]` section of `Cargo.toml`.

Here's the module definition:
```rust
use neon::prelude::*;

#[neon::main]
fn main(mut cx: ModuleContext) -> NeonResult<()> {
    // Export each Rust function under the name JavaScript will call
    cx.export_function("js_load_model", js_load_model)?;
    cx.export_function("js_load_data", js_load_data)?;
    cx.export_function("js_print_data", js_print_data)?;
    cx.export_function("js_run", js_run)?;
    cx.export_function("js_display_outputs", js_display_outputs)?;
    Ok(())
}
```
When using this library in JavaScript, import it as:
```javascript
const onnx_rustime = require(".");
```
This gives you direct access to the Rust functions from your JavaScript environment.
To use the JavaScript-Rust binding, first install all the required Node modules, as described in the How to install & use section. Once the environment is set up, you can call the provided JavaScript functions directly to interact with the ONNX runtime.
For example, to load an ONNX model and run it with some input data:
```javascript
const onnx_rustime = require(".");

const model_id = onnx_rustime.js_load_model("path_to_model.onnx");
const data_id = onnx_rustime.js_load_data("path_to_input_data.pb");
// The third argument enables verbose mode
const output_id = onnx_rustime.js_run(model_id, data_id, true);
```
As discussed in the usage section, a test demo is provided in `./js_onnx_rustime/test_onnx_rustime.js`.
For more details on how Rust and JavaScript integration works, and to dive deeper into the capabilities of Neon, see the official Neon documentation.
Feel free to contribute to this project by extending its functionalities, adding more operations, or improving the existing ones. Your contributions are always welcome!