|
27 | 27 | _LOGGER = logging.getLogger(__name__) |
28 | 28 |
|
29 | 29 | __all__ = [ |
| 30 | + "onnx_includes_external_data", |
30 | 31 | "save_onnx", |
31 | 32 | "validate_onnx", |
32 | 33 | "load_model", |
|
50 | 51 | ] |
51 | 52 |
|
52 | 53 |
|
| 54 | +def onnx_includes_external_data(model: ModelProto) -> bool: |
| 55 | + """ |
| 56 | + Check whether the ModelProto in memory includes the external |
| 57 | + data or not. |
| 58 | +
|
| 59 | + If the model.onnx does not contain the external data, then the |
| 60 | + initializers of the model are pointing to the external data file |
| 61 | + (they are not empty) |
| 62 | +
|
| 63 | + :param model: the ModelProto to check |
| 64 | + :return True if the model was loaded with external data, False otherwise. |
| 65 | + """ |
| 66 | + |
| 67 | + initializers = model.graph.initializer |
| 68 | + |
| 69 | + is_data_saved_to_disk = any( |
| 70 | + initializer.external_data for initializer in initializers |
| 71 | + ) |
| 72 | + is_data_included_in_model = not is_data_saved_to_disk |
| 73 | + |
| 74 | + return is_data_included_in_model |
| 75 | + |
| 76 | + |
53 | 77 | def save_onnx( |
54 | 78 | model: ModelProto, model_path: str, external_data_file: Optional[str] = None |
55 | 79 | ) -> bool: |
@@ -121,6 +145,14 @@ def validate_onnx(model: Union[str, ModelProto]): |
121 | 145 | return |
122 | 146 | onnx.checker.check_model(onnx_model) |
123 | 147 | except Exception as err: |
| 148 | + if not onnx_includes_external_data(model): |
| 149 | + _LOGGER.warning( |
| 150 | + "Attempting to validate an in-memory ONNX model " |
| 151 | + "that has been loaded without external data. " |
| 152 | + "This is currently not supported by the ONNX checker. " |
| 153 | + "The validation will be skipped." |
| 154 | + ) |
| 155 | + return |
124 | 156 | raise ValueError(f"Invalid onnx model: {err}") |
125 | 157 |
|
126 | 158 |
|
|
0 commit comments