Skip to content

Commit

Permalink
fix: protobuf 2gb limit when checking onnx (#811)
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrery authored Jul 26, 2024
1 parent a1bd9b8 commit c8908fa
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/concrete/ml/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
import onnx
import onnxoptimizer
import torch
from onnx import checker, helper
from onnx import helper

from ..common.debugging import assert_true
from .onnx_utils import (
IMPLEMENTED_ONNX_OPS,
check_onnx_model,
execute_onnx_with_numpy,
execute_onnx_with_numpy_trees,
get_op_type,
Expand Down Expand Up @@ -204,7 +205,7 @@ def preprocess_onnx_model(onnx_model: onnx.ModelProto, check_model: bool) -> onn
stacklevel=2,
)

checker.check_model(onnx_model)
check_onnx_model(onnx_model)

# Optimize ONNX graph
# List of all currently supported onnx optimizer passes
Expand All @@ -217,13 +218,13 @@ def preprocess_onnx_model(onnx_model: onnx.ModelProto, check_model: bool) -> onn
"eliminate_unused_initializer",
]
equivalent_onnx_model = onnxoptimizer.optimize(onnx_model, onnx_passes)
checker.check_model(equivalent_onnx_model)
check_onnx_model(equivalent_onnx_model)

# Custom optimization
# ONNX optimizer does not optimize Mat-Mult + Bias pattern into GEMM if the input isn't a matrix
# We manually do the optimization for this case
equivalent_onnx_model = fuse_matmul_bias_to_gemm(equivalent_onnx_model)
checker.check_model(equivalent_onnx_model)
check_onnx_model(equivalent_onnx_model)

# Check supported operators
required_onnx_operators = set(get_op_type(node) for node in equivalent_onnx_model.graph.node)
Expand Down
52 changes: 52 additions & 0 deletions src/concrete/ml/onnx/onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,13 +212,17 @@
# - mypy typing hints
# - existing and new ops implementation in separate file

import tempfile
from pathlib import Path

# Original file:
# https://github.com/google/jax/blob/f6d329b2d9b5f83c6a59e5739aa1ca8d4d1ffa1c/examples/onnx2xla.py
from typing import Any, Callable, Dict, Optional, Tuple

import numpy
import onnx
from onnx import numpy_helper
from onnx.external_data_helper import convert_model_to_external_data

from .ops_impl import (
numpy_abs,
Expand Down Expand Up @@ -566,3 +570,51 @@ def remove_initializer_from_input(model: onnx.ModelProto): # pragma: no cover
inputs.remove(name_to_input[initializer.name])

return model


def check_onnx_model(onnx_model: onnx.ModelProto) -> None:
"""Check an ONNX model, handling large models (>2GB) by using external data.
Args:
onnx_model (onnx.ModelProto): The ONNX model to check.
Raises:
ValueError: If the model is too large (>2GB) or if there's another ValueError.
"""
# Create a copy of the input model
onnx_model_copy = onnx.ModelProto()
onnx_model_copy.CopyFrom(onnx_model)

try:
# Try to check the model copy directly
onnx.checker.check_model(onnx_model_copy)
except ValueError as e:
error_message = str(e)
if (
"Message onnx.ModelProto exceeds maximum protobuf size of 2GB:" in error_message
or "This protobuf of onnx model is too large (>2GB)" in error_message
):

# If the model is too large, use external data approach
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
model_path = temp_dir_path / "model.onnx"
external_data_path = temp_dir_path / "model_data.bin"

# Save the model copy with external data
convert_model_to_external_data(
onnx_model_copy, all_tensors_to_one_file=True, location=external_data_path.name
)
onnx.save_model(
onnx_model_copy,
str(model_path),
save_as_external_data=True,
all_tensors_to_one_file=True,
location=external_data_path.name,
)

# Check the model using the file path
onnx.checker.check_model(str(model_path))
else: # pragma: no cover
# If it is a different error, re-raise it
raise
42 changes: 42 additions & 0 deletions tests/onnx/test_onnx_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Test ONNX utils."""

import numpy as np
import onnx
import pytest

from concrete.ml.onnx.convert import OPSET_VERSION_FOR_ONNX_EXPORT
from concrete.ml.onnx.onnx_utils import check_onnx_model


def test_check_onnx_model_large():
"""Test that check_onnx_model can handle models larger than 2GB."""

model = onnx.ModelProto()
graph = onnx.GraphProto()
graph.name = "LargeModel"

# Create a large tensor (slightly over the 2GB limit)
large_tensor = np.random.rand(1000, 1000, 550).astype(np.float32)
tensor_proto = onnx.numpy_helper.from_array(large_tensor, name="large_tensor")

graph.initializer.append(tensor_proto)
model.graph.CopyFrom(graph)

# Set ir_version
model.ir_version = onnx.IR_VERSION

# Add opset_import
opset = model.opset_import.add()
opset.version = OPSET_VERSION_FOR_ONNX_EXPORT

# Test that onnx.checker.check_model raises an exception
with pytest.raises(
ValueError, match="Message onnx.ModelProto exceeds maximum protobuf size of 2GB:"
):
onnx.checker.check_model(model)

# Our custom check_onnx_model should work fine
check_onnx_model(model)

# Call check_onnx_model a second time to ensure the original model wasn't modified
check_onnx_model(model)

0 comments on commit c8908fa

Please sign in to comment.