Remove compile wrapper to simplify access to model attributes #5581

Merged: 10 commits, Jun 17, 2024
84 changes: 0 additions & 84 deletions deepspeed/runtime/compiler.py
@@ -81,87 +81,3 @@ def validate_enabled(cls, field_value, values):
        if field_value and not is_compile_supported():
            raise ValueError("torch.compile is not supported on this version of PyTorch.")
        return field_value


def CompiledModuleWrapper(mod, compile_config: Union[CompileConfig, None] = None):

    class wrapper(mod.__class__):

        def __init__(self, module, compile_config: Union[CompileConfig, None] = None):
            self.__dict__ = {k: module.__dict__[k] for k in module.__dict__ if not k in self.__class__.__dict__}

            assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch."

            self.__dict__['wrapped'] = module
            self._is_compiled = False
            self._backend = get_backend_fn(compile_config.backend)
            self._compile_kwargs = compile_config.kwargs
            self._compiler_fn = None

        def set_backend(self, backend: Union[str, Callable]):
            """Set the backend for torch.compile.

            Args:
                backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module.
                You can directly pass a function that works as a backend.
                See also `backend` field in `CompileConfig` for more details.
            """
            self._backend = get_backend_fn(backend)

        def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None:
            """Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten.
            You can also pass a backend name with "backend" key to change the backend.

            Args:
                kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile.
            """

            if "backend" in kwargs:
                raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.")
            self._compile_kwargs.update(kwargs)

        def set_compiler_fn(self, compiler_fn: Callable) -> None:
            """Set a function to be used for compiling the module.
            This function should take a torch.nn.Module as input and return a compiled module.
            Note that other compile options are ignored when a compiler_fn is set.

            Example:
            ```python
            def my_compiler_fn(module: torch.nn.Module):
                ...
                return torch.compile(module, ...)

            engine.set_compiler_fn(my_compiler_fn)
            ```
            """
            self._compiler_fn = compiler_fn

        def forward(self, *args, **kwargs) -> Any:
            if not self.is_compiled:
                if self._compiler_fn is None:
                    self.__dict__['wrapped'] = torch.compile(self.wrapped,
                                                             backend=self._backend,
                                                             **self._compile_kwargs)
                else:
                    self.__dict__['wrapped'] = self._compiler_fn(self.wrapped)
                self._is_compiled = True

            return self.__dict__['wrapped'](*args, **kwargs)

        @property
        def is_compiled(self) -> bool:
            return self._is_compiled

        @property
        def backend(self) -> Union[str, Callable]:
            return self._backend

        @property
        def torch_compile_kwargs(self) -> Dict[str, Any]:
            return self._compile_kwargs

        @property
        def compiler_fn(self) -> Union[Callable, None]:
            return self._compiler_fn

    return wrapper(mod, compile_config)
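For illustration, a stripped-down sketch of the same dynamic-subclass wrapping pattern (not from the PR, with the compile logic removed) showing why attribute access through the wrapper is awkward: the wrapped object's type is the generated wrapper class, and attribute writes land on the wrapper rather than on the original module.

```python
import torch


class ToyModel(torch.nn.Module):  # hypothetical client model
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


def wrap(mod):
    # Same dynamic-subclass trick as CompiledModuleWrapper above, minus torch.compile.
    class wrapper(mod.__class__):

        def __init__(self, module):
            self.__dict__ = {k: v for k, v in module.__dict__.items() if k not in self.__class__.__dict__}
            self.__dict__['wrapped'] = module

        def forward(self, *args, **kwargs):
            return self.__dict__['wrapped'](*args, **kwargs)

    return wrapper(mod)


model = ToyModel()
wrapped = wrap(model)

wrapped.new_flag = True                # stored on the wrapper's __dict__ ...
print(hasattr(model, "new_flag"))      # ... not on the original module: False
print(type(wrapped).__name__)          # 'wrapper', not 'ToyModel'
```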
85 changes: 81 additions & 4 deletions deepspeed/runtime/engine.py
@@ -18,7 +18,7 @@
from torch.optim.lr_scheduler import _LRScheduler
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from typing import Callable, Dict, Union, Iterable
from typing import Callable, Dict, Union, Iterable, Any

import deepspeed

@@ -90,7 +90,7 @@

from .pipe.module import PipelineModule
from .utils import get_ma_status
from .compiler import CompiledModuleWrapper
from .compiler import get_backend_fn
Collaborator:
Suggested change:
-from .compiler import get_backend_fn
+from .compiler import get_backend_fn, is_compile_supported

from ..ops.adam import FusedAdam
from ..moe.sharded_moe import TopKGate, MOELayer
from ..moe.layer import MoE
@@ -361,8 +361,10 @@ def __init__(self,
        self.flatten = _flatten_dense_tensors
        self.unflatten = _unflatten_dense_tensors

        if self._config.compile_config.enabled:
            self._set_client_model(CompiledModuleWrapper(self.module, self._config.compile_config))
        self._is_compiled = False
        self._compiler_backend = None
        self._compile_kwargs = self._config.compile_config.kwargs
        self._compiler_fn = None

    def destroy(self):
        if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
@@ -1790,6 +1792,20 @@ def forward(self, *inputs, **kwargs):
            **kwargs: variable length keyword arguments
        """

        if self._config.compile_config.enabled and not self._is_compiled:
            if self._compiler_backend is None:
                self._compiler_backend = get_backend_fn(self._config.compile_config.backend)

            if self._compiler_fn is None:
                compiled_model = torch.compile(self.module,
                                               backend=self._compiler_backend,
                                               **self._config.compile_config.kwargs)
            else:
                compiled_model = self._compiler_fn(self.module)

            self._set_client_model(compiled_model)
            self._is_compiled = True

Collaborator @BacharL, May 29, 2024:
Suggested change (replace the compile-on-forward block above with an in-place compile):

        if self._config.compile_config.enabled and not self._is_compiled and is_compile_supported():
            backend = self._compiler_fn if self._compiler_fn is not None else self._compiler_backend
            self.module.compile(backend=backend, **self._compile_kwargs)
            self._is_compiled = True

Reasons for the above:

  1. Allows passing user-defined kwargs even when the user provides a custom compiler_fn.
  2. The type of self.module does not change, since torch.nn.Module.compile is called instead of torch.compile; the module is compiled in-place.
     This line used to fail for me. It should pass now, since we only compile on forward and not in init; nevertheless, we should consider keeping the type of self.module.

Contributor Author:
Using torch.nn.Module.compile is a great idea. Thank you for this suggestion.

compiler_fn is not a backend. The intention of compiler_fn is to enable something that can't be done just by setting backend and kwargs for torch.compile (e.g. compiling part of the model). We need to fix that part carefully to make it consistent with in-place compile.

Collaborator:
I think it is possible to achieve this with compiler_fn; torch.nn.Module.compile just calls this function for compilation.
For example, in a custom function one can iterate over the module and call compile() in place for every part that should be compiled:

    def custom_compiler_fn(module: torch.nn.Module, example_inputs):
        global custom_compiler_fn_called
        custom_compiler_fn_called = True
        module.l1.compile(backend=get_accelerator().get_compile_backend())

Contributor:
> Reasons for the above:
>   1. Allows passing user-defined kwargs even when the user provides a custom compiler_fn.
>   2. The type of self.module does not change, since torch.nn.Module.compile is called instead of torch.compile; the module is compiled in-place.

On the other hand:

  1. The model returned by torch.compile(user_model) is actually a smart wrapper around user_model. Torch keeps user_model inside, so the user can still add new attributes to, or change existing ones on, their original model.
  2. The user may not want their module to be compiled by DeepSpeed.
     compiled_model = torch.compile(self.module, ...) solves that.
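To make the difference concrete, a minimal sketch (not from the PR; assumes a recent PyTorch where both torch.compile and torch.nn.Module.compile are available):

```python
import torch

model = torch.nn.Linear(4, 4)

# torch.compile returns a new wrapper object (OptimizedModule) that keeps the
# original module inside as `_orig_mod`; the object you hold changes type.
compiled = torch.compile(model)
print(type(compiled).__name__)       # OptimizedModule
print(compiled._orig_mod is model)   # True

# nn.Module.compile compiles in place: the object keeps its original type,
# so existing references and attribute access keep working as before.
model2 = torch.nn.Linear(4, 4)
model2.compile()
print(type(model2).__name__)         # Linear
```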

Contributor Author:

@BacharL I think I get the points in your last comment. Can you clarify the intention of this part? It seems that _compiler_fn would be no different from _compiler_backend here:

            backend = self._compiler_fn if self._compiler_fn is not None else self._compiler_backend
            self.module.compile(backend=backend, **self._compile_kwargs)

@deepcharm You can pass a compiled module to DeepSpeed's init, but it makes a slight difference: DeepSpeed sets some hooks, and I thought it cannot do so once the model is compiled. I don't think it is harmful to keep compiler_fn, even though such use cases might not be very popular.
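For clarity, a minimal sketch (not from the PR) of the two call signatures being contrasted here: a torch.compile backend is called with a captured FX graph plus example inputs, while the compiler_fn in this PR receives the whole torch.nn.Module and can decide what, and how much, to compile.

```python
import torch

# A torch.compile *backend*: Dynamo calls it with a captured FX graph and
# example inputs, and it must return a callable that executes the graph.
def passthrough_backend(gm: torch.fx.GraphModule, example_inputs):
    return gm.forward  # run the captured graph eagerly

# A *compiler_fn* in the sense of this PR: takes the whole nn.Module and
# returns a compiled module, so it could also compile only parts of the model.
def my_compiler_fn(module: torch.nn.Module) -> torch.nn.Module:
    return torch.compile(module, backend=passthrough_backend)
```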

Contributor Author:

@BacharL @deepcharm, I had a discussion with @tjruwase about the design.

He suggested adding a compile API to DeepSpeed instead of running compilation at the first forward pass; the behavior would be easier for users to understand. In this design, we no longer need the ds config entry for compile. At least for now, it would be a simple wrapper around engine.module.compile().
Do you have any thoughts on this? I think I can briefly test the feasibility.
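As a rough illustration of that direction, a minimal sketch of what an explicit compile API could look like (hypothetical class and method shape, not the final DeepSpeed API; it assumes a PyTorch version that provides torch.nn.Module.compile):

```python
from typing import Any, Callable, Dict, Optional, Union

import torch


class EngineCompileSketch(torch.nn.Module):
    """Stripped-down stand-in for the engine, showing only the explicit-compile idea."""

    def __init__(self, module: torch.nn.Module):
        super().__init__()
        self.module = module
        self._is_compiled = False

    def compile(self, backend: Union[str, Callable] = "inductor",
                compile_kwargs: Optional[Dict[str, Any]] = None) -> None:
        # Thin wrapper around torch.nn.Module.compile; compiles self.module in place.
        if self._is_compiled:
            return
        self.module.compile(backend=backend, **(compile_kwargs or {}))
        self._is_compiled = True

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    @property
    def is_compiled(self) -> bool:
        return self._is_compiled
```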

Collaborator:

This suggestion seems good.
Thanks.

Contributor Author:

@BacharL I implemented the approach. Would it work for you?

Collaborator:

Great, thank you!

Contributor:

@tohtana This approach is simple and explicit, thanks.

        if self.autotuning_profile_model_info():
            ma = get_ma_status()
        else:
@@ -3600,3 +3616,64 @@ def empty_partition_cache(self):
            self.optimizer.empty_partition_cache()
        gc.collect()
        get_accelerator().empty_cache()

    def set_backend(self, backend: Union[str, Callable]):
        """Set the backend for torch.compile.

        Args:
            backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module.
            You can directly pass a function that works as a backend.
            See also `backend` field in `CompileConfig` for more details.
        """
        if self.is_compiled:
            raise ValueError("Cannot change backend after compiling the module.")
        self._compiler_backend = get_backend_fn(backend)

    def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None:
        """Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten.
        You can also pass a backend name with "backend" key to change the backend.

        Args:
            kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile.
        """
        if self.is_compiled:
            raise ValueError("Cannot change compile kwargs after compiling the module.")

        if "backend" in kwargs:
            raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.")

        self._compile_kwargs.update(kwargs)

    def set_compiler_fn(self, compiler_fn: Callable) -> None:
        """Set a function to be used for compiling the module.
        This function should take a torch.nn.Module as input and return a compiled module.
        Note that other compile options are ignored when a compiler_fn is set.

        Example:
        ```python
        def my_compiler_fn(module: torch.nn.Module):
            ...
            return torch.compile(module, ...)

        engine.set_compiler_fn(my_compiler_fn)
        ```
        """
        if self.is_compiled:
            raise ValueError("Cannot change compiler_fn after compiling the module.")
        self._compiler_fn = compiler_fn

    @property
    def is_compiled(self) -> bool:
        return self._is_compiled

    @property
    def backend(self) -> Union[str, Callable]:
        return self._compiler_backend

    @property
    def torch_compile_kwargs(self) -> Dict[str, Any]:
        return self._compile_kwargs

    @property
    def compiler_fn(self) -> Union[Callable, None]:
        return self._compiler_fn
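For context, a rough usage sketch of these engine-level controls (not part of the diff; the toy model and the "compile" config section are assumptions that mirror the CompileConfig fields referenced in this PR, and the script is assumed to run in a normal DeepSpeed launch environment):

```python
import torch
import deepspeed

model = torch.nn.Linear(8, 8)  # hypothetical toy model
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "compile": {"enabled": True, "backend": "inductor"},
}

engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)

# Adjust torch.compile options before the first forward pass.
engine.set_torch_compile_kwargs({"dynamic": False})

out = engine(torch.randn(1, 8).to(engine.device))  # first forward triggers compilation
assert engine.is_compiled
```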
3 changes: 3 additions & 0 deletions tests/unit/common.py
@@ -203,10 +203,13 @@ def _launch_non_daemonic_procs(self, num_procs):
        master_port = get_master_port()
        skip_msg = mp.Queue()  # Allows forked processes to share pytest.skip reason
        processes = []
        prev_start_method = mp.get_start_method()
        mp.set_start_method('spawn', force=True)
        for local_rank in range(num_procs):
            p = mp.Process(target=self._dist_run, args=(local_rank, num_procs, master_port, skip_msg))
            p.start()
            processes.append(p)
        mp.set_start_method(prev_start_method, force=True)

        # Now loop and wait for a test to complete. The spin-wait here isn't a big
        # deal because the number of processes will be O(#GPUs) << O(#CPUs).