
feat: load a model entirely from an onnx file and build circuit at runtime #25

Merged
merged 35 commits into main from feat/onnx_file
Oct 4, 2022

Conversation

jasonmorton
Member

  • Fully runtime dynamic configuration of a circuit (OnnxCircuit)
  • Loop over nodes / layers to configure and layout
  • Loading of an ONNX file into data structures (only a few ops implemented so far): OnnxModel, OnnxNode, and OnnxModelConfig
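The "loop over nodes to configure and layout" idea can be sketched roughly as follows. This is a hypothetical, simplified illustration, not the PR's actual API: `NodeConfig`, `configure_nodes`, and the string op tags are all made up for the example.

```rust
// Hypothetical sketch: walk the loaded node list once at runtime and
// build a per-op circuit config. Names are illustrative only.
#[derive(Debug, PartialEq)]
enum NodeConfig {
    Affine { out_dim: usize },
    Eltwise { length: usize },
    Passthrough, // Source / Const nodes need no circuit config
}

fn configure_nodes(ops: &[&str], dims: &[usize]) -> Vec<NodeConfig> {
    ops.iter().zip(dims).map(|(op, &d)| match *op {
        "Gemm" => NodeConfig::Affine { out_dim: d },
        "Relu" => NodeConfig::Eltwise { length: d },
        _ => NodeConfig::Passthrough,
    }).collect()
}

fn main() {
    let configs = configure_nodes(&["Source", "Gemm", "Relu"], &[4, 4, 4]);
    println!("{}", configs.len()); // 3
}
```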

@alexander-camuto alexander-camuto marked this pull request as draft September 26, 2022 09:54
@alexander-camuto
Collaborator

alexander-camuto commented Sep 26, 2022

Can confirm that I am able to reproduce the shape inference issues for a 3-layer MLP with ReLU non-linearities between layers.

Printing out all the ONNX nodes shows there is certainly sufficient information to infer the hidden / inner layer shapes:

OnnxNode { node: Node { id: 0, name: "input", inputs: [], op: Source, outputs: [?,4,F32 >7/0] } }
OnnxNode { node: Node { id: 1, name: "hidden3.weight", inputs: [], op: Const(4,4,F32 -0.6978378, -0.8618335, -0.560139, 0.67565507, 0.75508606, -1.0828302, 0.1357851, -0.48875993, 0.5060201, 0.27603242, -1.2781131, -0.18486828...), outputs: [4,4,F32 -0.6978378, -0.8618335, -0.560139, 0.67565507, 0.75508606, -1.0828302, 0.1357851, -0.48875993, 0.5060201, 0.27603242, -1.2781131, -0.18486828... >11/1] } }
OnnxNode { node: Node { id: 2, name: "hidden2.bias", inputs: [], op: Const(4,F32 -0.029132247, 0.27592283, -0.3716687, 0.029936492), outputs: [4,F32 -0.029132247, 0.27592283, -0.3716687, 0.029936492 >9/2] } }
OnnxNode { node: Node { id: 3, name: "hidden1.weight", inputs: [], op: Const(4,4,F32 -0.79067045, -0.1981668, 0.38580352, -1.0893693, -0.35157114, -1.2757933, -0.3365852, 0.36804888, 1.0542545, -0.57707036, 0.59439397, -0.44970307...), outputs: [4,4,F32 -0.79067045, -0.1981668, 0.38580352, -1.0893693, -0.35157114, -1.2757933, -0.3365852, 0.36804888, 1.0542545, -0.57707036, 0.59439397, -0.44970307... >7/1] } }
OnnxNode { node: Node { id: 4, name: "hidden3.bias", inputs: [], op: Const(4,F32 -0.44558668, -0.34988558, -0.03433931, -0.0068350434), outputs: [4,F32 -0.44558668, -0.34988558, -0.03433931, -0.0068350434 >11/2] } }
OnnxNode { node: Node { id: 5, name: "hidden2.weight", inputs: [], op: Const(4,4,F32 0.28882173, 1.0820657, -0.25749648, -0.8242642, -0.7267107, -0.08238286, 1.0038052, -0.6763725, 0.38093626, 0.702172, 0.8628348, 0.7857216...), outputs: [4,4,F32 0.28882173, 1.0820657, -0.25749648, -0.8242642, -0.7267107, -0.08238286, 1.0038052, -0.6763725, 0.38093626, 0.702172, 0.8628348, 0.7857216... >9/1] } }
OnnxNode { node: Node { id: 6, name: "hidden1.bias", inputs: [], op: Const(4,F32 -0.069352865, -0.39855963, -0.4257033, -0.27501678), outputs: [4,F32 -0.069352865, -0.39855963, -0.4257033, -0.27501678 >7/2] } }
OnnxNode { node: Node { id: 7, name: "Gemm_0", inputs: [0/0>, 3/0>, 6/0>], op: Gemm { alpha: 1.0, beta: 1.0, trans_a: false, trans_b: true }, outputs: [..,? >8/0] } }
OnnxNode { node: Node { id: 8, name: "Relu_1", inputs: [7/0>], op: Clip(Some(0.0), None), outputs: [..,? >9/0] } }
OnnxNode { node: Node { id: 9, name: "Gemm_2", inputs: [8/0>, 5/0>, 2/0>], op: Gemm { alpha: 1.0, beta: 1.0, trans_a: false, trans_b: true }, outputs: [..,? >10/0] } }
OnnxNode { node: Node { id: 10, name: "Relu_3", inputs: [9/0>], op: Clip(Some(0.0), None), outputs: [..,? >11/0] } }
OnnxNode { node: Node { id: 11, name: "Gemm_4", inputs: [10/0>, 1/0>, 4/0>], op: Gemm { alpha: 1.0, beta: 1.0, trans_a: false, trans_b: true }, outputs: [..,? >12/0] } }
OnnxNode { node: Node { id: 12, name: "Relu_5", inputs: [11/0>], op: Clip(Some(0.0), None), outputs: [?,4,F32 ] } }

It is just that the ordering of nodes is a bit annoying, and we'll need to do some slightly inelegant linking between weight / bias shapes and the ReLU / Gemm output shapes if there isn't a simple flag / function that can do it for us.

Part of the issue is that ReLU layers aren't "layers" per se — they are just element-wise ops, so they don't carry shape information. If I remove the ReLU operations from the 3-layer network (so it is solely affine ops), then the shapes are inferred correctly.
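The linking described above can be sketched as a single pass over the node list: element-wise ops inherit the shape of their input, while Gemm outputs take the weight's leading dimension (since `trans_b: true` in the dump above). This is a hypothetical std-only illustration — `Op` and `infer_shapes` are invented for the example and are not the PR's types.

```rust
// Hypothetical sketch: propagate shapes through the node list.
// Element-wise ops (ReLU) carry no shape, so they inherit their
// input's shape; Gemm with trans_b = true outputs the weight's
// first dimension. Names are illustrative.
#[derive(Clone, Debug)]
enum Op {
    Source(Vec<usize>),                   // input with known shape
    Const(Vec<usize>),                    // weight / bias with known shape
    Gemm { input: usize, weight: usize }, // node ids of input and weight
    Relu { input: usize },                // element-wise: no shape of its own
}

fn infer_shapes(nodes: &[Op]) -> Vec<Vec<usize>> {
    let mut shapes: Vec<Vec<usize>> = Vec::with_capacity(nodes.len());
    for node in nodes {
        let s = match node {
            Op::Source(s) | Op::Const(s) => s.clone(),
            // out_dim is the weight's first dim (trans_b = true)
            Op::Gemm { weight, .. } => vec![shapes[*weight][0]],
            // ReLU just inherits its input's shape
            Op::Relu { input } => shapes[*input].clone(),
        };
        shapes.push(s);
    }
    shapes
}

fn main() {
    // Mirror the node dump above: ids 0..=6 are the input and consts,
    // ids 7..=12 alternate Gemm / Relu.
    let nodes = vec![
        Op::Source(vec![4]),               // 0: input
        Op::Const(vec![4, 4]),             // 1: hidden3.weight
        Op::Const(vec![4]),                // 2: hidden2.bias
        Op::Const(vec![4, 4]),             // 3: hidden1.weight
        Op::Const(vec![4]),                // 4: hidden3.bias
        Op::Const(vec![4, 4]),             // 5: hidden2.weight
        Op::Const(vec![4]),                // 6: hidden1.bias
        Op::Gemm { input: 0, weight: 3 },  // 7: Gemm_0
        Op::Relu { input: 7 },             // 8: Relu_1
        Op::Gemm { input: 8, weight: 5 },  // 9: Gemm_2
        Op::Relu { input: 9 },             // 10: Relu_3
        Op::Gemm { input: 10, weight: 1 }, // 11: Gemm_4
        Op::Relu { input: 11 },            // 12: Relu_5
    ];
    let shapes = infer_shapes(&nodes);
    println!("{:?}", shapes[12]); // [4]
}
```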

@alexander-camuto
Collaborator

Also, some of the inputs to the ReLU layers could potentially be out of range (OOR) for the lookup tables:

error: lookup input does not exist in table
  (L0, L1)(F14, F15)

  Lookup 'lk' inputs:
    L0 = x1 * x0
    ^
    | Cell layout in region 'Elementwise':
    |   | Offset | A3 | F27|
    |   +--------+----+----+
    |   |    0   | x0 | x1 | <--{ Lookup 'lk' inputs queried here
    |
    | Assigned cell values:
    |   x0 = 0x40000000000000000000000000000000224698fc094cf91b992d30ecffff1d13
    |   x1 = 1

    L1 = x1 * x0

This is something I'm finding on randomly initialized models.

@alexander-camuto
Collaborator

A potential solution to the first problem is to define the OnnxModel struct such that it has a last_state attribute that can be accessed when calling configure_node. This could represent the output size of the last configured node or layer; it could also include information about the type of the last configured layer (if that ever becomes relevant).
For instance:

#[derive(Clone, Debug)]
pub struct OnnxModel {
    pub model: Graph<InferenceFact, Box<dyn InferenceOp>>, // The raw Tract data structure
    pub onnx_nodes: Vec<OnnxNode>, // Wrapped nodes with additional methods and potentially data (e.g. quantization)
    pub last_state: Vec<usize>,
}

When configuring an affine layer we can then set:

self.last_state = Vec::from([out_dim]);

and non-linear layers (which don't contain shape information) can then be configured using that state information:

let length = some_function(self.last_state);

let conf: EltwiseConfig<F, BITS, ReLu<F>> = EltwiseConfig::configure(
    meta,
    advices.get_slice(&[0..length], &[length]),
    None,
);

@alexander-camuto
Collaborator

As discussed @jasonmorton, the best way to infer the shapes might be just to run a forward pass in a fully compatible system (e.g. tract) and trace it. We could do this by creating a custom version of the tract inference function which outputs nodes plus their respective shapes / dims instead of an output. A single inference pass will probably maximize compatibility with TF / JAX etc.
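The trace-a-forward-pass approach boils down to: execute the computation once on dummy data and record each intermediate's shape, instead of reasoning about the graph statically. A minimal std-only sketch (not using tract — the matrices, `gemm`, and `relu` here are toy stand-ins):

```rust
// Hypothetical sketch of shape inference by tracing: run a dummy
// forward pass and record each intermediate length. Row-major toy
// matrices; trans_b = true convention as in the Gemm nodes above.
fn gemm(x: &[f32], w: &[Vec<f32>]) -> Vec<f32> {
    // y[i] = sum_j w[i][j] * x[j]
    w.iter()
        .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum::<f32>())
        .collect()
}

fn relu(x: &[f32]) -> Vec<f32> {
    x.iter().map(|v| v.max(0.0)).collect()
}

fn main() {
    let w1 = vec![vec![0.5f32; 4]; 4]; // hypothetical 4x4 weights
    let w2 = vec![vec![0.5f32; 4]; 4];
    let x = vec![1.0f32; 4];

    // Trace: record the shape (here just the length) of each activation.
    let mut traced_shapes = Vec::new();
    let h = relu(&gemm(&x, &w1));
    traced_shapes.push(h.len());
    let y = relu(&gemm(&h, &w2));
    traced_shapes.push(y.len());
    println!("{:?}", traced_shapes); // [4, 4]
}
```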

@jasonmorton
Member Author

jasonmorton commented Sep 26, 2022

My plan for OOR is to use the OnnxNode struct to store quantization and bounds information, and to automatically insert DivideBy layers (or, even better, convert activation layers into rescale-and-activate layers) when we can statically see we risk OOR. Basically, as part of the configuration pass, there should be a quantization pass that makes adjustments to the quantization scheme.
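The static check could look something like the following sketch. Everything here is an assumption for illustration: the `BITS` constant, the worst-case magnitude model (input bound times the weight row's L1 norm, bias omitted), and `needs_rescale` are all hypothetical, not the PR's code.

```rust
// Hypothetical sketch of the static OOR check: given a bound on the
// layer's input magnitude and the weights' L1 norm, decide whether a
// DivideBy (rescale) must be inserted before the next lookup-based
// activation. BITS and the bound model are illustrative assumptions.
const BITS: u32 = 14; // lookup table covers [-2^(BITS-1), 2^(BITS-1))

fn needs_rescale(max_abs_input: i64, weight_l1_norm: i64) -> bool {
    // worst-case affine output magnitude: |x|_max * ||W||_1 (bias omitted)
    let worst_case = max_abs_input.saturating_mul(weight_l1_norm);
    worst_case >= 1 << (BITS - 1)
}

fn main() {
    println!("{}", needs_rescale(127, 64)); // 8128 < 8192  -> false
    println!("{}", needs_rescale(127, 65)); // 8255 >= 8192 -> true
}
```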

@alexander-camuto
Collaborator

alexander-camuto commented Oct 3, 2022

Partially resolves #16.

@alexander-camuto alexander-camuto marked this pull request as ready for review October 4, 2022 13:44
@alexander-camuto alexander-camuto self-requested a review October 4, 2022 13:44
@alexander-camuto alexander-camuto merged commit 29df95f into main Oct 4, 2022
@alexander-camuto alexander-camuto deleted the feat/onnx_file branch October 4, 2022 15:15