File tree Expand file tree Collapse file tree 3 files changed +14
-3
lines changed
src/modeci_mdf/interfaces/pytorch Expand file tree Collapse file tree 3 files changed +14
-3
lines changed Original file line number Diff line number Diff line change @@ -86,7 +86,7 @@ def forward(self, input: torch.Tensor):
8686
8787# Export the model
8888fn = "ABCD_from_torch.onnx"
89- torch_out = torch .onnx ._export (
89+ torch .onnx .export (
9090 m_abcd , # model being run
9191 input , # model input (or a tuple for multiple inputs)
9292 fn , # where to save the model (can be a file or file-like object)
Original file line number Diff line number Diff line change @@ -89,7 +89,7 @@ optional =
8989 Jinja2<3.1
9090 torchviz
9191 netron
92- torch<2.2.0, >=1.11.0
92+ torch>=1.11.0
9393 torchvision
9494 h5py
9595
Original file line number Diff line number Diff line change 55
66 Mike He, "From Models to Computation Graphs (Part I)", https://ad1024.space/articles/22
77"""
8+
89import inspect
910import logging
1011
1314
1415import torch
1516
17+
1618# We need to monkey patch the torch._C.Node class to add a __getitem__ method
1719# This is for torch 2.0
1820# From https://github.com/openai/CLIP/issues/79#issuecomment-1624202950
@@ -523,7 +525,16 @@ def pytorch_to_mdf(
523525 # Call out to a part of the ONNX exporter that simiplifies the graph before ONNX export.
524526 from torch .onnx .utils import _model_to_graph
525527 from torch .onnx import TrainingMode
526- from torch .onnx .symbolic_helper import _set_opset_version
528+
529+ # Seems they got rid of _set_opset_version in 2.2 or something, can't find a better way to do this
530+ try :
531+ from torch .onnx .symbolic_helper import _set_opset_version
532+ except ImportError :
533+
534+ def _set_opset_version (version ):
535+ from torch .onnx ._globals import GLOBALS
536+
537+ GLOBALS .export_onnx_opset_version = version
527538
528539 try :
529540 from torch .onnx .symbolic_helper import _export_onnx_opset_version
You can’t perform that action at this time.
0 commit comments