Skip to content

Commit cc9a796

Browse files
authored
Merge pull request #542 from ModECI/fix/pytorch_pin
Remove pytorch upper pin
2 parents a21e314 + 7c3129c commit cc9a796

File tree

3 files changed

+14
-3
lines changed

3 files changed

+14
-3
lines changed

examples/MDF/abcd_torch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def forward(self, input: torch.Tensor):
8686

8787
# Export the model
8888
fn = "ABCD_from_torch.onnx"
89-
torch_out = torch.onnx._export(
89+
torch.onnx.export(
9090
m_abcd, # model being run
9191
input, # model input (or a tuple for multiple inputs)
9292
fn, # where to save the model (can be a file or file-like object)

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ optional =
8989
Jinja2<3.1
9090
torchviz
9191
netron
92-
torch<2.2.0,>=1.11.0
92+
torch>=1.11.0
9393
torchvision
9494
h5py
9595

src/modeci_mdf/interfaces/pytorch/importer.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
66
Mike He, "From Models to Computation Graphs (Part I)", https://ad1024.space/articles/22
77
"""
8+
89
import inspect
910
import logging
1011

@@ -13,6 +14,7 @@
1314

1415
import torch
1516

17+
1618
# We need to monkey patch the torch._C.Node class to add a __getitem__ method
1719
# This is for torch 2.0
1820
# From https://github.com/openai/CLIP/issues/79#issuecomment-1624202950
@@ -523,7 +525,16 @@ def pytorch_to_mdf(
523525
# Call out to a part of the ONNX exporter that simiplifies the graph before ONNX export.
524526
from torch.onnx.utils import _model_to_graph
525527
from torch.onnx import TrainingMode
526-
from torch.onnx.symbolic_helper import _set_opset_version
528+
529+
# Seems they got rid of _set_opset_version in 2.2 or something, can't find a better way to do this
530+
try:
531+
from torch.onnx.symbolic_helper import _set_opset_version
532+
except ImportError:
533+
534+
def _set_opset_version(version):
535+
from torch.onnx._globals import GLOBALS
536+
537+
GLOBALS.export_onnx_opset_version = version
527538

528539
try:
529540
from torch.onnx.symbolic_helper import _export_onnx_opset_version

0 commit comments

Comments
 (0)