Skip to content
Closed
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
1dfbb1c
Define save_onnx function.
Han123su Jun 10, 2024
4ce9b2c
Add saver_type.py (Saver class).
Han123su Jun 10, 2024
1f0fe90
Modify the parameter name in save_net_with_metadata function.
Han123su Jun 10, 2024
b799d14
Import new modules as needed.
Han123su Jun 10, 2024
50e9ae5
Modify export function and Add saver (mainly conversion and saving).
Han123su Jun 10, 2024
e1fbacd
Modify onnx_export and call _export() instead.
Han123su Jun 10, 2024
b390dbb
Modify ckpt_export: do specific processing before calling _export.
Han123su Jun 10, 2024
4079598
Modify trt_export: do specific processing before calling _export.
Han123su Jun 10, 2024
eb754ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2024
d4e20d1
Add license header in saver_type.py
Han123su Jun 10, 2024
3c5211c
Merge branch 'fix-issue-6375' of https://github.com/Han123su/MONAI in…
Han123su Jun 10, 2024
dd4a3d5
Autofix
Han123su Jun 10, 2024
5a9fedc
Add onnx in requirements-min.txt
Han123su Jun 11, 2024
f917432
Fix local variable 'inputs_' is assigned to but never used
Han123su Jun 11, 2024
e862405
Merge branch 'fix-issue-6375' of https://github.com/Han123su/MONAI in…
Han123su Jun 11, 2024
ac62f04
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 11, 2024
f389a4d
Delete onnx in requirements-dev.txt
Han123su Jun 11, 2024
eb2f7e1
Change onnx_save location instead import onnx
Han123su Jun 11, 2024
623c86b
Merge branch 'fix-issue-6375' of https://github.com/Han123su/MONAI in…
Han123su Jun 11, 2024
d45dbce
Back to original requirements-dev.txt
Han123su Jun 11, 2024
2e74d6e
Delete import onnx in torchscript_utils.py
Han123su Jun 11, 2024
8b4766f
Fix import
Han123su Jun 11, 2024
9c396aa
Fix cannot import name 'save_onnx' from partially initialized module …
Han123su Jun 11, 2024
c12d65c
Fix Undefined name
Han123su Jun 11, 2024
9babcbc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 11, 2024
cb1d4d2
Fix incorrectly sorted and/or formatted.
Han123su Jun 11, 2024
57e3ba2
Modify
Han123su Jun 11, 2024
2f96bcf
Fix onnx.ModelProto not defined
Han123su Jun 11, 2024
b36426d
Merge branch 'dev' into fix-issue-6375
Han123su Jun 13, 2024
51b4445
Merge remote-tracking branch 'upstream/dev' into fix-issue-6375
Han123su Jun 13, 2024
4f4cb0e
Merge branch 'fix-issue-6375' of https://github.com/Han123su/MONAI in…
Han123su Jun 13, 2024
3d178a7
Merge remote-tracking branch 'upstream/dev' into fix-issue-6375
Han123su Jul 18, 2024
6459cc9
Back to original
Han123su Jul 18, 2024
6d4b8c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2024
42e08f2
Modify to get closer to target
Han123su Jul 20, 2024
61605bb
Merge branch 'fix-issue-6375' of https://github.com/Han123su/MONAI in…
Han123su Jul 20, 2024
182e20d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 20, 2024
b34f3c7
Merge branch 'dev' into fix-issue-6375
Han123su Jul 20, 2024
7d67953
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2024
608c407
modify
Han123su Jul 20, 2024
1b3a984
Merge branch 'fix-issue-6375' of https://github.com/Han123su/MONAI in…
Han123su Jul 20, 2024
87d736d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 40 additions & 16 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow
from monai.config import IgniteInfo, PathLike
from monai.data import load_net_with_metadata, save_net_with_metadata
from functools import partial
from monai.networks import (
convert_to_onnx,
convert_to_torchscript,
Expand Down Expand Up @@ -1159,6 +1160,7 @@ def verify_net_in_out(

def _export(
converter: Callable,
saver: Callable,
parser: ConfigParser,
net_id: str,
filepath: str,
Expand All @@ -1173,6 +1175,8 @@ def _export(
Args:
converter: a callable object that takes a torch.nn.module and kwargs as input and
converts the module to another type.
saver: a callable object that takes the converted model and a filepath as input and
saves the model to the specified location.
parser: a ConfigParser of the bundle to be converted.
net_id: ID name of the network component in the parser, it must be `torch.nn.Module`.
filepath: filepath to export, if filename has no extension, it becomes `.ts`.
Expand Down Expand Up @@ -1212,14 +1216,12 @@ def _export(
# add .json extension to all extra files which are always encoded as JSON
extra_files = {k + ".json": v for k, v in extra_files.items()}

save_net_with_metadata(
jit_obj=net,
filename_prefix_or_stream=filepath,
include_config_vals=False,
append_timestamp=False,
meta_values=parser.get().pop("_meta_", None),
more_extra_files=extra_files,
saver(
jit_obj = net,
filename_prefix_or_stream = filepath,
more_extra_files = extra_files,
)

logger.info(f"exported to file: {filepath}.")


Expand Down Expand Up @@ -1318,17 +1320,28 @@ def onnx_export(
input_shape_ = _get_fake_input_shape(parser=parser)

inputs_ = [torch.rand(input_shape_)]
net = parser.get_parsed_content(net_id_)
if has_ignite:
# here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_)
else:
ckpt = torch.load(ckpt_file_)
copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_])

converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_})
onnx_model = convert_to_onnx(model=net, **converter_kwargs_)
onnx.save(onnx_model, filepath_)

def save_onnx(
jit_obj: torch.nn.Module,
filename_prefix_or_stream: str | IO[Any],
more_extra_files: None = None,
) -> None:

onnx.save(jit_obj, filename_prefix_or_stream)

_export(
convert_to_onnx,
save_onnx,
parser,
net_id=net_id_,
filepath=filepath_,
ckpt_file=ckpt_file_,
config_file=config_file_,
key_in_ckpt=key_in_ckpt_,
**converter_kwargs_,
)


def ckpt_export(
Expand Down Expand Up @@ -1449,8 +1462,14 @@ def ckpt_export(

converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_})
# Use the given converter to convert a model and save with metadata, config content

save_ts = partial(save_net_with_metadata, include_config_vals=False,
append_timestamp=False,
meta_values=parser.get().pop("_meta_", None))

_export(
convert_to_torchscript,
save_ts,
parser,
net_id=net_id_,
filepath=filepath_,
Expand Down Expand Up @@ -1620,8 +1639,13 @@ def trt_export(
}
converter_kwargs_.update(trt_api_parameters)

save_ts = partial(save_net_with_metadata, include_config_vals=False,
append_timestamp=False,
meta_values=parser.get().pop("_meta_", None))

_export(
convert_to_trt,
save_ts,
parser,
net_id=net_id_,
filepath=filepath_,
Expand Down