Skip to content

Commit

Permalink
[INF] Enable torch compile for inference (#5612)
Browse files Browse the repository at this point in the history
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
  • Loading branch information
4 people authored Jul 9, 2024
1 parent 2105976 commit 7b1ea22
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions deepspeed/inference/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import torch
import time
import os
import deepspeed
from deepspeed import comm as dist
from deepspeed.utils.logging import log_dist

from torch.nn.modules import Module
from packaging import version as pkg_version
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.runtime.compiler import is_compile_supported

from ..runtime.state_dict_factory import SDLoaderFactory
from ..runtime.weight_quantizer import WeightQuantization
Expand Down Expand Up @@ -185,6 +187,7 @@ def __init__(self, model, config):

# Check if local CUDA graphs can be created in replacement modules
self.local_cuda_graph = self._local_cuda_graph_used(self.module)
self._is_compiled = False

def destroy(self):
# Have to import here because inference_module is a global, but python
Expand Down Expand Up @@ -634,3 +637,22 @@ def _generate(self, *inputs, **kwargs):
)

return self.module.generate(*inputs, **kwargs)

def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}) -> None:
"""
Compile the module using the specified backend and kwargs.
"""
if not is_compile_supported():
raise RuntimeError("compile is not supported in your version of PyTorch.")

if self._is_compiled:
return

# Avoid graph breaks
deepspeed.utils.nvtx.enable_nvtx = False
self.module.compile(backend=backend, **compile_kwargs)
self._is_compiled = True

@property
def is_compiled(self) -> bool:
return self._is_compiled

0 comments on commit 7b1ea22

Please sign in to comment.