Skip to content

Commit facf176

Browse files
KumoLiuericspodpre-commit-ci[bot]
authored
Add compile support in SupervisedTrainer and SupervisedEvaluator (#7375)
Fixes # . ### Description Add `compile` support in `SupervisedTrainer` and `SupervisedEvaluator`. Convert to `torch.Tensor` internally. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 80be1c3 commit facf176

File tree

2 files changed

+99
-4
lines changed

2 files changed

+99
-4
lines changed

monai/engines/evaluator.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@
1111

1212
from __future__ import annotations
1313

14+
import warnings
1415
from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
1516

1617
import torch
1718
from torch.utils.data import DataLoader
1819

1920
from monai.config import IgniteInfo, KeysCollection
21+
from monai.data import MetaTensor
2022
from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch
2123
from monai.engines.workflow import Workflow
2224
from monai.inferers import Inferer, SimpleInferer
@@ -25,7 +27,7 @@
2527
from monai.utils import ForwardMode, ensure_tuple, min_version, optional_import
2628
from monai.utils.enums import CommonKeys as Keys
2729
from monai.utils.enums import EngineStatsKeys as ESKeys
28-
from monai.utils.module import look_up_option
30+
from monai.utils.module import look_up_option, pytorch_after
2931

3032
if TYPE_CHECKING:
3133
from ignite.engine import Engine, EventEnum
@@ -213,6 +215,10 @@ class SupervisedEvaluator(Evaluator):
213215
`device`, `non_blocking`.
214216
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
215217
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
218+
compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
219+
`torch.Tensor` before forward pass, then converted back afterward with copied meta information.
220+
compile_kwargs: dict of the args for `torch.compile()` API, for more details:
221+
https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile.
216222
217223
"""
218224

@@ -238,6 +244,8 @@ def __init__(
238244
decollate: bool = True,
239245
to_kwargs: dict | None = None,
240246
amp_kwargs: dict | None = None,
247+
compile: bool = False,
248+
compile_kwargs: dict | None = None,
241249
) -> None:
242250
super().__init__(
243251
device=device,
@@ -259,8 +267,16 @@ def __init__(
259267
to_kwargs=to_kwargs,
260268
amp_kwargs=amp_kwargs,
261269
)
262-
270+
if compile:
271+
if pytorch_after(2, 1):
272+
compile_kwargs = {} if compile_kwargs is None else compile_kwargs
273+
network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
274+
else:
275+
warnings.warn(
276+
"Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done"
277+
)
263278
self.network = network
279+
self.compile = compile
264280
self.inferer = SimpleInferer() if inferer is None else inferer
265281

266282
def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Tensor]) -> dict:
@@ -288,6 +304,24 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten
288304
kwargs: dict = {}
289305
else:
290306
inputs, targets, args, kwargs = batch
307+
# FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026
308+
if self.compile:
309+
inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None
310+
if isinstance(inputs, MetaTensor):
311+
warnings.warn(
312+
"Will convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass."
313+
)
314+
inputs, inputs_meta, inputs_applied_operations = (
315+
inputs.as_tensor(),
316+
inputs.meta,
317+
inputs.applied_operations,
318+
)
319+
if isinstance(targets, MetaTensor):
320+
targets, targets_meta, targets_applied_operations = (
321+
targets.as_tensor(),
322+
targets.meta,
323+
targets.applied_operations,
324+
)
291325

292326
# put iteration outputs into engine.state
293327
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
@@ -298,6 +332,19 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten
298332
engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
299333
else:
300334
engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
335+
# copy back meta info
336+
if self.compile:
337+
if inputs_meta is not None:
338+
engine.state.output[Keys.IMAGE] = MetaTensor(
339+
inputs, meta=inputs_meta, applied_operations=inputs_applied_operations
340+
)
341+
engine.state.output[Keys.PRED] = MetaTensor(
342+
engine.state.output[Keys.PRED], meta=inputs_meta, applied_operations=inputs_applied_operations
343+
)
344+
if targets_meta is not None:
345+
engine.state.output[Keys.LABEL] = MetaTensor(
346+
targets, meta=targets_meta, applied_operations=targets_applied_operations
347+
)
301348
engine.fire_event(IterationEvents.FORWARD_COMPLETED)
302349
engine.fire_event(IterationEvents.MODEL_COMPLETED)
303350

monai/engines/trainer.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,23 @@
1111

1212
from __future__ import annotations
1313

14+
import warnings
1415
from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
1516

1617
import torch
1718
from torch.optim.optimizer import Optimizer
1819
from torch.utils.data import DataLoader
1920

2021
from monai.config import IgniteInfo
22+
from monai.data import MetaTensor
2123
from monai.engines.utils import IterationEvents, default_make_latent, default_metric_cmp_fn, default_prepare_batch
2224
from monai.engines.workflow import Workflow
2325
from monai.inferers import Inferer, SimpleInferer
2426
from monai.transforms import Transform
2527
from monai.utils import GanKeys, min_version, optional_import
2628
from monai.utils.enums import CommonKeys as Keys
2729
from monai.utils.enums import EngineStatsKeys as ESKeys
30+
from monai.utils.module import pytorch_after
2831

2932
if TYPE_CHECKING:
3033
from ignite.engine import Engine, EventEnum
@@ -125,7 +128,10 @@ class SupervisedTrainer(Trainer):
125128
`device`, `non_blocking`.
126129
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
127130
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
128-
131+
compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
132+
`torch.Tensor` before forward pass, then converted back afterward with copied meta information.
133+
compile_kwargs: dict of the args for `torch.compile()` API, for more details:
134+
https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile.
129135
"""
130136

131137
def __init__(
@@ -153,6 +159,8 @@ def __init__(
153159
optim_set_to_none: bool = False,
154160
to_kwargs: dict | None = None,
155161
amp_kwargs: dict | None = None,
162+
compile: bool = False,
163+
compile_kwargs: dict | None = None,
156164
) -> None:
157165
super().__init__(
158166
device=device,
@@ -174,8 +182,16 @@ def __init__(
174182
to_kwargs=to_kwargs,
175183
amp_kwargs=amp_kwargs,
176184
)
177-
185+
if compile:
186+
if pytorch_after(2, 1):
187+
compile_kwargs = {} if compile_kwargs is None else compile_kwargs
188+
network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
189+
else:
190+
warnings.warn(
191+
"Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done"
192+
)
178193
self.network = network
194+
self.compile = compile
179195
self.optimizer = optimizer
180196
self.loss_function = loss_function
181197
self.inferer = SimpleInferer() if inferer is None else inferer
@@ -207,6 +223,25 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tenso
207223
kwargs: dict = {}
208224
else:
209225
inputs, targets, args, kwargs = batch
226+
# FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026
227+
if self.compile:
228+
inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None
229+
if isinstance(inputs, MetaTensor):
230+
warnings.warn(
231+
"Will convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass."
232+
)
233+
inputs, inputs_meta, inputs_applied_operations = (
234+
inputs.as_tensor(),
235+
inputs.meta,
236+
inputs.applied_operations,
237+
)
238+
if isinstance(targets, MetaTensor):
239+
targets, targets_meta, targets_applied_operations = (
240+
targets.as_tensor(),
241+
targets.meta,
242+
targets.applied_operations,
243+
)
244+
210245
# put iteration outputs into engine.state
211246
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
212247

@@ -231,6 +266,19 @@ def _compute_pred_loss():
231266
engine.state.output[Keys.LOSS].backward()
232267
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
233268
engine.optimizer.step()
269+
# copy back meta info
270+
if self.compile:
271+
if inputs_meta is not None:
272+
engine.state.output[Keys.IMAGE] = MetaTensor(
273+
inputs, meta=inputs_meta, applied_operations=inputs_applied_operations
274+
)
275+
engine.state.output[Keys.PRED] = MetaTensor(
276+
engine.state.output[Keys.PRED], meta=inputs_meta, applied_operations=inputs_applied_operations
277+
)
278+
if targets_meta is not None:
279+
engine.state.output[Keys.LABEL] = MetaTensor(
280+
targets, meta=targets_meta, applied_operations=targets_applied_operations
281+
)
234282
engine.fire_event(IterationEvents.MODEL_COMPLETED)
235283

236284
return engine.state.output

0 commit comments

Comments
 (0)