-
Notifications
You must be signed in to change notification settings - Fork 352
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Add ApexOptimWrapper #742
Conversation
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #742 +/- ##
=======================================
Coverage ? 77.90%
=======================================
Files ? 133
Lines ? 10086
Branches ? 2010
=======================================
Hits ? 7857
Misses ? 1888
Partials ? 341
Flags with carried forward coverage won't be shown. Click here to find out more. Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi! Thanks for your contribution. It seems the current implementation does not use apex.amp.initialize
to prepare the model and optimizer, I know there exits some limitations that make it hard to be implemented In MMEngine, and I want to discuss how we could support ApexOptimWrapper
better.
1.If we don't call initialize, will mix-precision training work?
2.As the official document of apex describes:
We need to call amp.initialize() before wrap the model with ddp wrapper and optimizer. However, there is a confliction, PyTorch want suggest we given the parameters of ddp model to optimizer in tutorial
Actually, it is reasonable, since distributed wrapper like FSDP will overwrite the original parameters and we have to pass its parameters to optimizer after wrapping the model. Alright, it means that when we got the optimizer, the model must have been wrapped by DDP, which conflicts with the principle of apex.
In MMEngine, OptimWrapper.optim_context
can get the ddp-model and optimizer, I'm not sure if we can use amp.initialize
there (Maybe initialize the model in ddp wrapper in place?).
3.We need to consider how to resume the optimizer.
my_model = Model()
my_opt = SGD(...)
my_model, my_opt = apex.amp.initialize(my_model, my_opt, opt_level='O1', loss_scale=...)
ddp_model = DDP(my_model)
|
Ahhh, If we want to use ApexOptimWrapper independent of The key problem is that where should we call The second problem is the compatibility of ddp-training, as mentioned above:
In the apex tutorial apex.amp has their own way t save or load checkpoint, we should take it into consideration. |
Hi, I have an interest in this feature. One of the crucial changes to support As mentioned by @HAOCHENYE , these So we had to write the new Can we discuss this? (in this PR, the new Issue or discussion board) |
@HAOCHENYE |
@nijkah @xcnick only patch the # suppose in ddp training
amp.initialize(model.module, self.optimizer) If it works, I think it could be a temporary solution to support |
Hi, @nijkah ! I've created a discussion thread. You can paste these comments there and we can discuss on it! And @xcnick since this PR is related, your opinions and ideas are also welcome! |
def __init__(self, opt_level='O1', loss_scale='dynamic', **kwargs): | ||
super().__init__(**kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def __init__(self, opt_level='O1', loss_scale='dynamic', **kwargs): | |
super().__init__(**kwargs) | |
def __init__(self, opt_level='O1', loss_scale='dynamic', **kwargs): | |
assert apex_amp is not None, \ | |
'Apex is not installed. Please check https://github.com/NVIDIA/apex#linux.', | |
super().__init__(**kwargs) |
The assertion logic could be added here.
except ImportError: | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
except ImportError: | |
pass | |
except ImportError: | |
apex_amp = None |
@nijkah Thanks for your suggestion. |
if isinstance(model, torch.nn.parallel.DistributedDataParallel): | ||
model = model.module | ||
with super().optim_context(model): | ||
model, self.optimizer = apex_amp.initialize( | ||
model, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just a question. When I checked the Apex documentation, I couldn't find a description to handle it like this.
Isn't there any issue in training when handling it like this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is actually what I concerned, I took a cursory look at some of implementation for apex.initialize
:
It seems initialize
only patches the forward and registers some custom hooks into forward. I'm not sure about doing this on model.module
will work or not, we need to check the optimization for saving memory and accelerating training when ApexOptimWrapper
is enabled.
@xcnick could you provide the comparison of nvidia-smi
when ApexOptimWrapper
is enabled or not?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, there is something wrong with current implementation.
In examples\examples/distributed_training.py
, compare the following two options:
batch_size=1024,
...
optim_wrapper=dict(
type='ApexOptimWrapper', opt_level='O1', loss_scale=10,
optimizer=dict(type=Adam, lr=0.001)),
...
and
batch_size=1024,
...
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
result:
optim_wrapper | nvidia-smi | memory in log file |
---|---|---|
ApexOptimWrapper | 11GB | 5574 |
original OptimWrapper | 5.2GB | 2249 |
It seems that calling amp.initialize
in optim_context
may not be the correct way.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You may be able to participate in the related discussion if you are interested.
#749 (comment)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Emmm, the conclusion makes me confused. I've tried to train mmdet with retinanet-r50 based on this PR for figuring out why it does not work. However, the result of nvidia-smi is:
It seems that the ApexOptimWrapper
has worked.... I also test the examples/examples/distributed_training.py
, the result is the same.
So... what happened 🤣
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure whether the this implementation can work correctly, so good luck ^^
In addition, the to
method of BaseModel
in base_model.py
needs to be modified for compatibility in O0
and O3
modes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In addition, the
to
method ofBaseModel
inbase_model.py
needs to be modified for compatibility inO0
andO3
modes.
Do you mean the modifications in #783 ? If it is not enough, feel free to leave a comment!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great job!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for my late reply. To give a conclusion first: the current implementation seems OK. I used ApexOptimWrapper
to train the atss
in MMDet, and the loss was able to converge normally, and it will take some time to verify the accuracy.
The reason for the late reply is that I spent some time trying to train RetinaNet (O1) based on ApexOptimWrapper
. However, it turned out that neither AmpOptimWrapper
nor ApexOptimWrapper
can train RetinaNet (loss is nan
) normally, which is determined by the model's own characteristics and has nothing to do with the implementation of ApexOptimWrapper
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi~ is there any progress? Besides, CLA should be signed again 🤣 .
if hasattr(self.optimizer, '_amp_stash'): | ||
yield |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if hasattr(self.optimizer, '_amp_stash'): | |
yield | |
if hasattr(self.optimizer, '_amp_stash'): | |
with super().optim_context(model): | |
yield |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it mean that the initialized optimizer will have the _amp_stash
attribute? Maybe we need to add some comments here.
``ApexOptimWrapper`` requires | ||
[nvidia apex](https://github.com/NVIDIA/apex). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
``ApexOptimWrapper`` requires | |
[nvidia apex](https://github.com/NVIDIA/apex). | |
``ApexOptimWrapper`` requires `nvidia apex <https://github.com/NVIDIA/apex>`_ |
Args: | ||
|
||
**kwargs: Keyword arguments passed to OptimWrapper. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing arguments description here
``accumulative_counts``. | ||
""" | ||
|
||
def __init__(self, opt_level='O1', loss_scale='dynamic', **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
missing type hint
self.opt_level = opt_level | ||
self.loss_scale = loss_scale | ||
|
||
def backward(self, loss: torch.Tensor, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def backward(self, loss: torch.Tensor, **kwargs): | |
def backward(self, loss: torch.Tensor, **kwargs) -> None: |
state_dict['apex_amp'] = apex_amp.state_dict() | ||
return state_dict | ||
|
||
def load_state_dict(self, state_dict: dict): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def load_state_dict(self, state_dict: dict): | |
def load_state_dict(self, state_dict: dict) -> None: |
else: | ||
if isinstance(model, torch.nn.parallel.DistributedDataParallel): | ||
model = model.module | ||
with super().optim_context(model): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
super().optim_context
should be called each iteration for avoiding necessary gradient accumulation during training,
Sorry for the late response, I updated the code. |
Please add |
# when a given optimizer be passed through apex_amp.initialize, | ||
# the "_amp_stash" property will be added | ||
if hasattr(self.optimizer, '_amp_stash'): | ||
yield |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that here misses a return
, otherwise apex_amp.initialize
will be called each iteration.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's weird that yield
is followed by return
, but it make sense in contextmanager.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can simplify the logic.
if not hasattr(self.optimizer, '_amp_stash'):
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
model = model.module
model, self.optimizer = apex_amp.initialize(xxx)
yield
apex_optim_wrapper = ApexOptimWrapper( | ||
optimizer=optimizer, opt_level='O1', loss_scale=1) | ||
with apex_optim_wrapper.optim_context(self.model): | ||
apex_optim_wrapper.optimizer.param_groups = MagicMock() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sometimes we use MagicMock to assert some function or method has been called, but why do we need mock the param_groups
here?
Hi @xcnick , thanks for your contribution. This PR can be merged after resolving the above final comments and validating it in your local machine. |
Hi! as the official document says: The current implementation of |
Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.
Motivation
Please describe the motivation of this PR and the goal you want to achieve through this PR.
Modification
Add ApexOptimWrapper for mmengine.
BC-breaking (Optional)
Does the modification introduce changes that break the backward-compatibility of the downstream repos?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.
Use cases (Optional)
If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.
Checklist