feat: load a model entirely from an onnx file and build circuit at runtime #25
Conversation
jasonmorton commented Sep 25, 2022
- Fully dynamic runtime configuration of a circuit (OnnxCircuit)
- Loop over nodes / layers to configure and lay out each one (a minimal sketch of this loop follows the list)
- Loading of an ONNX file into a data structure (only a few ops implemented): OnnxModel, OnnxNode, and OnnxModelConfig
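For concreteness, here is a minimal sketch of the shape of that configure loop. The types below are simplified stand-ins invented for illustration, not the actual OnnxModel / OnnxNode / OnnxModelConfig in this PR, which hold tract graphs and halo2 configs.

```rust
// Simplified stand-in types for illustration only.
#[derive(Clone, Debug)]
enum SupportedOp {
    Affine { in_dim: usize, out_dim: usize },
    ReLU { length: usize },
}

#[derive(Clone, Debug)]
struct OnnxNode {
    name: String,
    op: SupportedOp,
}

#[derive(Clone, Debug, Default)]
struct OnnxModelConfig {
    // In the real circuit these would be per-layer configs produced by
    // `configure(meta, ...)` calls; here we only record what was configured.
    layers: Vec<String>,
}

// Loop over nodes / layers, producing one configuration entry per supported op.
fn configure(nodes: &[OnnxNode]) -> OnnxModelConfig {
    let mut config = OnnxModelConfig::default();
    for node in nodes {
        match &node.op {
            SupportedOp::Affine { in_dim, out_dim } => config
                .layers
                .push(format!("{}: affine {}x{}", node.name, in_dim, out_dim)),
            SupportedOp::ReLU { length } => config
                .layers
                .push(format!("{}: relu over {} elements", node.name, length)),
        }
    }
    config
}
```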
Can confirm that I am able to reproduce the shape-inference issues for a 3-layer MLP network with ReLU non-linearities between layers. Printing out all the ONNX nodes, there is certainly sufficient information to infer the hidden / inner layer shapes:

```
OnnxNode { node: Node { id: 0, name: "input", inputs: [], op: Source, outputs: [?,4,F32 >7/0] } }
OnnxNode { node: Node { id: 1, name: "hidden3.weight", inputs: [], op: Const(4,4,F32 -0.6978378, -0.8618335, -0.560139, 0.67565507, 0.75508606, -1.0828302, 0.1357851, -0.48875993, 0.5060201, 0.27603242, -1.2781131, -0.18486828...), outputs: [4,4,F32 -0.6978378, -0.8618335, -0.560139, 0.67565507, 0.75508606, -1.0828302, 0.1357851, -0.48875993, 0.5060201, 0.27603242, -1.2781131, -0.18486828... >11/1] } }
OnnxNode { node: Node { id: 2, name: "hidden2.bias", inputs: [], op: Const(4,F32 -0.029132247, 0.27592283, -0.3716687, 0.029936492), outputs: [4,F32 -0.029132247, 0.27592283, -0.3716687, 0.029936492 >9/2] } }
OnnxNode { node: Node { id: 3, name: "hidden1.weight", inputs: [], op: Const(4,4,F32 -0.79067045, -0.1981668, 0.38580352, -1.0893693, -0.35157114, -1.2757933, -0.3365852, 0.36804888, 1.0542545, -0.57707036, 0.59439397, -0.44970307...), outputs: [4,4,F32 -0.79067045, -0.1981668, 0.38580352, -1.0893693, -0.35157114, -1.2757933, -0.3365852, 0.36804888, 1.0542545, -0.57707036, 0.59439397, -0.44970307... >7/1] } }
OnnxNode { node: Node { id: 4, name: "hidden3.bias", inputs: [], op: Const(4,F32 -0.44558668, -0.34988558, -0.03433931, -0.0068350434), outputs: [4,F32 -0.44558668, -0.34988558, -0.03433931, -0.0068350434 >11/2] } }
OnnxNode { node: Node { id: 5, name: "hidden2.weight", inputs: [], op: Const(4,4,F32 0.28882173, 1.0820657, -0.25749648, -0.8242642, -0.7267107, -0.08238286, 1.0038052, -0.6763725, 0.38093626, 0.702172, 0.8628348, 0.7857216...), outputs: [4,4,F32 0.28882173, 1.0820657, -0.25749648, -0.8242642, -0.7267107, -0.08238286, 1.0038052, -0.6763725, 0.38093626, 0.702172, 0.8628348, 0.7857216... >9/1] } }
OnnxNode { node: Node { id: 6, name: "hidden1.bias", inputs: [], op: Const(4,F32 -0.069352865, -0.39855963, -0.4257033, -0.27501678), outputs: [4,F32 -0.069352865, -0.39855963, -0.4257033, -0.27501678 >7/2] } }
OnnxNode { node: Node { id: 7, name: "Gemm_0", inputs: [0/0>, 3/0>, 6/0>], op: Gemm { alpha: 1.0, beta: 1.0, trans_a: false, trans_b: true }, outputs: [..,? >8/0] } }
OnnxNode { node: Node { id: 8, name: "Relu_1", inputs: [7/0>], op: Clip(Some(0.0), None), outputs: [..,? >9/0] } }
OnnxNode { node: Node { id: 9, name: "Gemm_2", inputs: [8/0>, 5/0>, 2/0>], op: Gemm { alpha: 1.0, beta: 1.0, trans_a: false, trans_b: true }, outputs: [..,? >10/0] } }
OnnxNode { node: Node { id: 10, name: "Relu_3", inputs: [9/0>], op: Clip(Some(0.0), None), outputs: [..,? >11/0] } }
OnnxNode { node: Node { id: 11, name: "Gemm_4", inputs: [10/0>, 1/0>, 4/0>], op: Gemm { alpha: 1.0, beta: 1.0, trans_a: false, trans_b: true }, outputs: [..,? >12/0] } }
OnnxNode { node: Node { id: 12, name: "Relu_5", inputs: [11/0>], op: Clip(Some(0.0), None), outputs: [?,4,F32 ] } }
```

It is just that the ordering of the nodes is a bit annoying, and we'll need to do some slightly inelegant linking between the weight / bias shapes and the ReLU / Gemm shapes if there isn't a simple flag / function that can do it for us. Part of the issue is that ReLU layers aren't "layers" per se; they are just element-wise ops, so they don't carry shape information. If I remove the ReLU operations from the 3-layer network (so it is solely affine ops), then we are able to infer the shapes correctly. A sketch of one way to do that linking follows.
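One way to do the linking, sketched here with hypothetical stand-in types rather than this PR's OnnxNode, is to walk a shapeless element-wise node back to its producer and, when that producer is a Gemm, read the output width off the Gemm's weight constant:

```rust
// Hypothetical node representation, just to illustrate the linking idea.
#[derive(Clone, Debug)]
enum Op {
    Const { shape: Vec<usize> },  // e.g. hidden1.weight with shape [4, 4]
    Gemm { weight_node: usize },  // index of the weight Const feeding this Gemm
    Relu { input_node: usize },   // index of the node feeding this ReLU
}

/// Infer the width of a ReLU from the weight constant of the Gemm feeding it.
/// With `trans_b: true` (as in the dump above) the weight's first dimension
/// is the output dimension.
fn relu_width(nodes: &[Op], relu: usize) -> Option<usize> {
    if let Op::Relu { input_node } = &nodes[relu] {
        if let Op::Gemm { weight_node } = &nodes[*input_node] {
            if let Op::Const { shape } = &nodes[*weight_node] {
                return shape.first().copied();
            }
        }
    }
    None
}
```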
Also, some of the inputs to the ReLU layers could potentially be out of range (OOR) for the lookup tables:

```
error: lookup input does not exist in table
(L0, L1) ∉ (F14, F15)
Lookup 'lk' inputs:
L0 = x1 * x0
^
| Cell layout in region 'Elementwise':
| | Offset | A3 | F27|
| +--------+----+----+
| | 0 | x0 | x1 | <--{ Lookup 'lk' inputs queried here
|
| Assigned cell values:
| x0 = 0x40000000000000000000000000000000224698fc094cf91b992d30ecffff1d13
| x1 = 1
L1 = x1 * x0
```

This is something I'm finding on randomly initialized models.
A potential solution to the first problem is to define the `OnnxModel` struct so that it tracks the shape of the most recently configured layer:

```rust
#[derive(Clone, Debug)]
pub struct OnnxModel {
    pub model: Graph<InferenceFact, Box<dyn InferenceOp>>, // The raw tract data structure
    pub onnx_nodes: Vec<OnnxNode>, // Wrapped nodes with additional methods and potentially data (e.g. quantization)
    pub last_state: Vec<usize>,
}
```

When configuring an affine layer we can then set:

```rust
self.last_state = Vec::from([out_dim]);
```

Non-linear layers (which don't contain shape information) can then be configured using that state:

```rust
let length = some_function(&self.last_state);
let conf: EltwiseConfig<F, BITS, ReLu<F>> = EltwiseConfig::configure(
meta,
advices.get_slice(&[0..length], &[length]),
None,
);
```
As discussed @jasonmorton, the best way to infer the shapes might be to just run a forward pass in a fully compatible system (e.g. tract) and trace it. We could do this by creating a custom version of the tract inference function which outputs the nodes plus their respective shapes / dims instead of an output. A single inference pass will probably maximize compatibility with TF / JAX etc.
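Something along these lines might work (a rough sketch only; the exact tract API and method names may differ between versions, and the path and input shape below are placeholders): pin the input shape, let tract type the graph, and then read each node's output fact, which carries a concrete shape after typing.

```rust
use tract_onnx::prelude::*;

fn main() -> TractResult<()> {
    // Load the ONNX graph and pin the input shape so tract can propagate it.
    let model = tract_onnx::onnx()
        .model_for_path("network.onnx")?
        .with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 4)))?
        .into_typed()?;

    // After typing, each node's output fact should carry a concrete shape:
    // exactly the per-node shape information we want to trace.
    for node in model.nodes() {
        for output in &node.outputs {
            println!("{} -> {:?}", node.name, output.fact.shape);
        }
    }
    Ok(())
}
```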
My plan for OOR is to use the OnnxNode struct to store quantization and bounds information, and to automatically insert DivideBy layers (or, even better, convert activation layers to rescale-and-activate layers) when we can statically see that we risk OOR. Basically, as part of the configuration pass, there should be a quantization pass that makes adjustments to the quantization scheme.
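A minimal sketch of what that static check could look like, with hypothetical names and bounds (the actual table layout and quantization bookkeeping may differ): propagate a worst-case magnitude through each affine node and flag the following activation for rescaling when the bound would spill past the lookup table's 2^BITS domain.

```rust
/// Hypothetical static check: the worst-case magnitude of an affine output is
/// bounded by fan_in * input_bound * weight_bound + bias_bound. If that exceeds
/// what a 2^bits lookup table can cover, insert a DivideBy (or switch the next
/// activation to a rescale-and-activate layer) before the lookup.
fn needs_rescale(
    input_bound: u128,
    weight_bound: u128,
    bias_bound: u128,
    fan_in: u128,
    bits: u32,
) -> bool {
    let out_bound = fan_in * input_bound * weight_bound + bias_bound;
    out_bound >= (1u128 << bits)
}
```

The quantization pass mentioned above would then run a check like this per node during configuration, before any lookup-backed layer is configured.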
Force-pushed from 7865b7b to 2eb553b
Partially resolves #16.