-
Notifications
You must be signed in to change notification settings - Fork 4.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove compile wrapper to simplify access to model attributes #5581
Changes from 3 commits
0dcf372
b9ce692
baf3319
36c90d6
72ecc85
e4b2086
2f2b855
810e65e
881b15d
9ae3d4d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -18,7 +18,7 @@ | |||||||||||||||||||||||||||||||||||
from torch.optim.lr_scheduler import _LRScheduler | ||||||||||||||||||||||||||||||||||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
from typing import Callable, Dict, Union, Iterable | ||||||||||||||||||||||||||||||||||||
from typing import Callable, Dict, Union, Iterable, Any | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
import deepspeed | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
@@ -90,7 +90,7 @@ | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
from .pipe.module import PipelineModule | ||||||||||||||||||||||||||||||||||||
from .utils import get_ma_status | ||||||||||||||||||||||||||||||||||||
from .compiler import CompiledModuleWrapper | ||||||||||||||||||||||||||||||||||||
from .compiler import get_backend_fn | ||||||||||||||||||||||||||||||||||||
from ..ops.adam import FusedAdam | ||||||||||||||||||||||||||||||||||||
from ..moe.sharded_moe import TopKGate, MOELayer | ||||||||||||||||||||||||||||||||||||
from ..moe.layer import MoE | ||||||||||||||||||||||||||||||||||||
|
@@ -361,8 +361,10 @@ def __init__(self, | |||||||||||||||||||||||||||||||||||
self.flatten = _flatten_dense_tensors | ||||||||||||||||||||||||||||||||||||
self.unflatten = _unflatten_dense_tensors | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
if self._config.compile_config.enabled: | ||||||||||||||||||||||||||||||||||||
self._set_client_model(CompiledModuleWrapper(self.module, self._config.compile_config)) | ||||||||||||||||||||||||||||||||||||
self._is_compiled = False | ||||||||||||||||||||||||||||||||||||
self._compiler_backend = None | ||||||||||||||||||||||||||||||||||||
tohtana marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
self._compile_kwargs = self._config.compile_config.kwargs | ||||||||||||||||||||||||||||||||||||
self._compiler_fn = None | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def destroy(self): | ||||||||||||||||||||||||||||||||||||
if self.optimizer is not None and hasattr(self.optimizer, 'destroy'): | ||||||||||||||||||||||||||||||||||||
|
@@ -1790,6 +1792,20 @@ def forward(self, *inputs, **kwargs): | |||||||||||||||||||||||||||||||||||
**kwargs: variable length keyword arguments | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
if self._config.compile_config.enabled and not self._is_compiled: | ||||||||||||||||||||||||||||||||||||
if self._compiler_backend is None: | ||||||||||||||||||||||||||||||||||||
self._compiler_backend = get_backend_fn(self._config.compile_config.backend) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
if self._compiler_fn is None: | ||||||||||||||||||||||||||||||||||||
compiled_model = torch.compile(self.module, | ||||||||||||||||||||||||||||||||||||
backend=self._compiler_backend, | ||||||||||||||||||||||||||||||||||||
**self._config.compile_config.kwargs) | ||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||
compiled_model = self._compiler_fn(self.module) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
self._set_client_model(compiled_model) | ||||||||||||||||||||||||||||||||||||
self._is_compiled = True | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Reasons for the above:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it is possible to achieve this with
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
On the other hand:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @BacharL I think I get your points in your last comment. Can you clarify the intention of this? This seems that backend = self._compiler_fn if self._compiler_fn is not None else self._compiler_backend
self.module.compile(backend = backend, **self._compile_kwargs) @deepcharm You can pass a compiled module to deepspeed's init but it will make a slight difference. DeepSpeed sets some hooks but I thought it can't once the model is compiled. I don't think it is harmful to keep There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @BacharL @deepcharm, I had a discussion with @tjruwase about the design. He suggested having There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This suggestion seems good. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @BacharL I implemented the approach. Would it work for you? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great!, Thank you. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @tohtana This approach is simple and explicit, thanks. |
||||||||||||||||||||||||||||||||||||
if self.autotuning_profile_model_info(): | ||||||||||||||||||||||||||||||||||||
ma = get_ma_status() | ||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||
|
@@ -3600,3 +3616,64 @@ def empty_partition_cache(self): | |||||||||||||||||||||||||||||||||||
self.optimizer.empty_partition_cache() | ||||||||||||||||||||||||||||||||||||
gc.collect() | ||||||||||||||||||||||||||||||||||||
get_accelerator().empty_cache() | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def set_backend(self, backend: Union[str, Callable]): | ||||||||||||||||||||||||||||||||||||
"""Set the backend for torch.compile. | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||
backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module. | ||||||||||||||||||||||||||||||||||||
You can directly pass a function that works as a backend. | ||||||||||||||||||||||||||||||||||||
See also `backend` field in `CompileConfig` for more details. | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
if self.is_compiled: | ||||||||||||||||||||||||||||||||||||
raise ValueError("Cannot change backend after compiling the module.") | ||||||||||||||||||||||||||||||||||||
self._compiler_backend = get_backend_fn(backend) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None: | ||||||||||||||||||||||||||||||||||||
"""Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten. | ||||||||||||||||||||||||||||||||||||
You can also pass a backend name with "backend" key to change the backend. | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||
kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile. | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
if self.is_compiled: | ||||||||||||||||||||||||||||||||||||
raise ValueError("Cannot change compile kwargs after compiling the module.") | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
if "backend" in kwargs: | ||||||||||||||||||||||||||||||||||||
raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.") | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
self._compile_kwargs.update(kwargs) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def set_compiler_fn(self, compiler_fn: Callable) -> None: | ||||||||||||||||||||||||||||||||||||
"""Set a function to be used for compiling the module. | ||||||||||||||||||||||||||||||||||||
This function should take a torch.nn.Module as input and return a compiled module. | ||||||||||||||||||||||||||||||||||||
Note that other compile options are ignored when a compiler_fn is set. | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Example: | ||||||||||||||||||||||||||||||||||||
```python | ||||||||||||||||||||||||||||||||||||
def my_compiler_fn(module: torch.nn.Module): | ||||||||||||||||||||||||||||||||||||
... | ||||||||||||||||||||||||||||||||||||
return torch.compile(module, ...) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
engine.set_compiler_fn(my_compiler_fn) | ||||||||||||||||||||||||||||||||||||
``` | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
if self.is_compiled: | ||||||||||||||||||||||||||||||||||||
raise ValueError("Cannot change compiler_fn after compiling the module.") | ||||||||||||||||||||||||||||||||||||
self._compiler_fn = compiler_fn | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
@property | ||||||||||||||||||||||||||||||||||||
def is_compiled(self) -> bool: | ||||||||||||||||||||||||||||||||||||
return self._is_compiled | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
@property | ||||||||||||||||||||||||||||||||||||
def backend(self) -> Union[str, Callable]: | ||||||||||||||||||||||||||||||||||||
return self._backend | ||||||||||||||||||||||||||||||||||||
tohtana marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
@property | ||||||||||||||||||||||||||||||||||||
def torch_compile_kwargs(self) -> Dict[str, Any]: | ||||||||||||||||||||||||||||||||||||
return self._compile_kwargs | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
@property | ||||||||||||||||||||||||||||||||||||
def compiler_fn(self) -> Union[Callable, None]: | ||||||||||||||||||||||||||||||||||||
return self._compiler_fn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.