Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add ApexOptimWrapper #742

Merged
merged 15 commits into from
Feb 6, 2023
Merged

Conversation

xcnick
Copy link
Contributor

@xcnick xcnick commented Nov 18, 2022

Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.

Motivation

Please describe the motivation of this PR and the goal you want to achieve through this PR.

Modification

Add ApexOptimWrapper for mmengine.

BC-breaking (Optional)

Does the modification introduce changes that break the backward-compatibility of the downstream repos?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.

Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit test to ensure the correctness.
  3. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMCls.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

@codecov
Copy link

codecov bot commented Nov 18, 2022

Codecov Report

❗ No coverage uploaded for pull request base (main@6dc1d70). Click here to learn what that means.
Patch has no changes to coverable lines.

❗ Current head 1b5882c differs from pull request most recent head 7a825d0. Consider uploading reports for the commit 7a825d0 to get more accurate results

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #742   +/-   ##
=======================================
  Coverage        ?   77.90%           
=======================================
  Files           ?      133           
  Lines           ?    10086           
  Branches        ?     2010           
=======================================
  Hits            ?     7857           
  Misses          ?     1888           
  Partials        ?      341           
Flag Coverage Δ
unittests 77.90% <0.00%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

☔ View full report at Codecov.
📢 Do you have feedback about the report comment? Let us know in this issue.

Copy link
Collaborator

@HAOCHENYE HAOCHENYE left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi! Thanks for your contribution. It seems the current implementation does not use apex.amp.initialize to prepare the model and optimizer, I know there exits some limitations that make it hard to be implemented In MMEngine, and I want to discuss how we could support ApexOptimWrapper better.

1.If we don't call initialize, will mix-precision training work?

2.As the official document of apex describes:

image

We need to call amp.initialize() before wrap the model with ddp wrapper and optimizer. However, there is a confliction, PyTorch want suggest we given the parameters of ddp model to optimizer in tutorial

image

Actually, it is reasonable, since distributed wrapper like FSDP will overwrite the original parameters and we have to pass its parameters to optimizer after wrapping the model. Alright, it means that when we got the optimizer, the model must have been wrapped by DDP, which conflicts with the principle of apex.

In MMEngine, OptimWrapper.optim_context can get the ddp-model and optimizer, I'm not sure if we can use amp.initialize there (Maybe initialize the model in ddp wrapper in place?).

3.We need to consider how to resume the optimizer.

@xcnick
Copy link
Contributor Author

xcnick commented Nov 18, 2022

  1. apex.amp.initialize must be called to use mix-precision training.
  2. I think (not sure) apex.amp.initialize is independent of DDP, so the pipeline is:
my_model = Model()
my_opt = SGD(...)
my_model, my_opt =  apex.amp.initialize(my_model, my_opt, opt_level='O1', loss_scale=...)
ddp_model = DDP(my_model)
  1. ApexOptimWrapper use the same load_state_dict() and state_dict() method as OptimWrapper.

@HAOCHENYE
Copy link
Collaborator

  1. apex.amp.initialize must be called to use mix-precision training.
  2. I think (not sure) apex.amp.initialize is independent of DDP, so the pipeline is:
my_model = Model()
my_opt = SGD(...)
my_model, my_opt =  apex.amp.initialize(my_model, my_opt, opt_level='O1', loss_scale=...)
ddp_model = DDP(my_model)
  1. ApexOptimWrapper use the same load_state_dict() and state_dict() method as OptimWrapper.

Ahhh, If we want to use ApexOptimWrapper independent of Runner, I think the current implementation is almost enough. But if we want to use ApexOptimWrapper in Runner, what should I do?

The key problem is that where should we call apex.amp.initialize to enable mix-precision training. It seems ApexOptimWrapper now does not call apex.amp.initialize. Does it mean that if we simply replace the OptimWrapper with ApexOptimWrapper, the mixed-precision training will not be enabled?

The second problem is the compatibility of ddp-training, as mentioned above:

image

amp.initialize require that the model should not be wrapper with ddp-wrapper. Therefore, how could we do this in Runner ?

  1. load or resume

In the apex tutorial

image

apex.amp has their own way t save or load checkpoint, we should take it into consideration.

@HAOCHENYE HAOCHENYE added this to the 0.5.0 milestone Nov 20, 2022
@nijkah
Copy link
Contributor

nijkah commented Nov 21, 2022

Hi, I have an interest in this feature.
Related to #627, we have our own implementation to support DeepSpeed in MMEngine and willingness to post the PR.

One of the crucial changes to support DeepSpeed is that deepspeed inlcude deepspeed.initialize(model=model, optimizer=optimizer) which is a very similar interface as apex.amp.initialize.

As mentioned by @HAOCHENYE , these initialize interfaces should be called before Runner.wrap_model and it is not possible with OptimWrapper.optim_context.

So we had to write the new DeepSpeedRunner to support it.
I think it is a quite key feature to several related frameworks such as Colossal-AI.

Can we discuss this? (in this PR, the new Issue or discussion board)

@xcnick
Copy link
Contributor Author

xcnick commented Nov 21, 2022

@HAOCHENYE
As you metioned, the key problem is that where should we call apex.amp.initialize to enable mix-precision training.
3.load and resume, also depends on where apex.amp.initialize is called. If ApexOptimWrapper does not call apex.amp.initialize, just keep the existing implementation, inherit from OptimWrapper, amp.state_dict() will be handled somewhere outside, such as ApexRunner.

@HAOCHENYE
Copy link
Collaborator

@nijkah
Good suggestions!! Recently we are (mainly explored by @C1rN09) also committed to figuring out a proper design to support training with DeepSpeed or Colossal AI. I think it could be a good idea to create a new discussion to discuss this topic Specially
😆 !

@xcnick
Completely agree! But could we call apex.amp.initialize in some tricky way? Actually, _initialize in apex

https://github.com/NVIDIA/apex/blob/082f999a6e18a3d02306e27482cc7486dab71a50/apex/amp/_initialize.py#L145

only patch the forward and register some custom hooks into forward, I'm not sure it will work when we call amp.initialize in optim_context like:

# suppose in ddp training
amp.initialize(model.module, self.optimizer)

If it works, I think it could be a temporary solution to support ApexOptimWrapper. When we find a more elegant way to support Colossal AI or DeepSpeed, which also applies to ApexOptimWrapper, we could switch to the new design.

@C1rN09
Copy link
Collaborator

C1rN09 commented Nov 22, 2022

Can we discuss this? (in this PR, the new Issue or discussion board)

Hi, @nijkah ! I've created a discussion thread. You can paste these comments there and we can discuss on it!

And @xcnick since this PR is related, your opinions and ideas are also welcome!

Comment on lines 39 to 90
def __init__(self, opt_level='O1', loss_scale='dynamic', **kwargs):
super().__init__(**kwargs)
Copy link
Contributor

@nijkah nijkah Nov 23, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def __init__(self, opt_level='O1', loss_scale='dynamic', **kwargs):
super().__init__(**kwargs)
def __init__(self, opt_level='O1', loss_scale='dynamic', **kwargs):
assert apex_amp is not None, \
'Apex is not installed. Please check https://github.com/NVIDIA/apex#linux.',
super().__init__(**kwargs)

The assertion logic could be added here.

Comment on lines 12 to 13
except ImportError:
pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
except ImportError:
pass
except ImportError:
apex_amp = None

@xcnick
Copy link
Contributor Author

xcnick commented Nov 23, 2022

@nijkah Thanks for your suggestion.

Comment on lines 101 to 165
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
model = model.module
with super().optim_context(model):
model, self.optimizer = apex_amp.initialize(
model,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a question. When I checked the Apex documentation, I couldn't find a description to handle it like this.
Isn't there any issue in training when handling it like this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually what I concerned, I took a cursory look at some of implementation for apex.initialize:

https://github.com/NVIDIA/apex/blob/082f999a6e18a3d02306e27482cc7486dab71a50/apex/amp/_initialize.py#L145

It seems initialize only patches the forward and registers some custom hooks into forward. I'm not sure about doing this on model.module will work or not, we need to check the optimization for saving memory and accelerating training when ApexOptimWrapper is enabled.

@xcnick could you provide the comparison of nvidia-smi when ApexOptimWrapper is enabled or not?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, there is something wrong with current implementation.

In examples\examples/distributed_training.py, compare the following two options:

batch_size=1024,
...
optim_wrapper=dict(
        type='ApexOptimWrapper', opt_level='O1', loss_scale=10,
        optimizer=dict(type=Adam, lr=0.001)),
...

and

batch_size=1024,
...
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),

result:

optim_wrapper nvidia-smi memory in log file
ApexOptimWrapper 11GB 5574
original OptimWrapper 5.2GB 2249

It seems that calling amp.initialize in optim_context may not be the correct way.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may be able to participate in the related discussion if you are interested.
#749 (comment)

Copy link
Collaborator

@HAOCHENYE HAOCHENYE Nov 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Emmm, the conclusion makes me confused. I've tried to train mmdet with retinanet-r50 based on this PR for figuring out why it does not work. However, the result of nvidia-smi is:

image

image

It seems that the ApexOptimWrapper has worked.... I also test the examples/examples/distributed_training.py, the result is the same.

So... what happened 🤣

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll use ApexOptimWrapper based on this PR to train MMDet these days, and check the training speed and accuracy. Do you have any suggestions or comments about the current implementations @xcnick @nijkah?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure whether the this implementation can work correctly, so good luck ^^

In addition, the to method of BaseModel in base_model.py needs to be modified for compatibility in O0 and O3 modes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In addition, the to method of BaseModel in base_model.py needs to be modified for compatibility in O0 and O3 modes.

Do you mean the modifications in #783 ? If it is not enough, feel free to leave a comment!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great job!

Copy link
Collaborator

@HAOCHENYE HAOCHENYE Dec 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for my late reply. To give a conclusion first: the current implementation seems OK. I used ApexOptimWrapper to train the atss in MMDet, and the loss was able to converge normally, and it will take some time to verify the accuracy.

The reason for the late reply is that I spent some time trying to train RetinaNet (O1) based on ApexOptimWrapper. However, it turned out that neither AmpOptimWrapper nor ApexOptimWrapper can train RetinaNet (loss is nan) normally, which is determined by the model's own characteristics and has nothing to do with the implementation of ApexOptimWrapper

@CLAassistant
Copy link

CLAassistant commented Dec 14, 2022

CLA assistant check
All committers have signed the CLA.

Copy link
Collaborator

@HAOCHENYE HAOCHENYE left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi~ is there any progress? Besides, CLA should be signed again 🤣 .

Comment on lines 98 to 99
if hasattr(self.optimizer, '_amp_stash'):
yield
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if hasattr(self.optimizer, '_amp_stash'):
yield
if hasattr(self.optimizer, '_amp_stash'):
with super().optim_context(model):
yield

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it mean that the initialized optimizer will have the _amp_stash attribute? Maybe we need to add some comments here.

Comment on lines 26 to 27
``ApexOptimWrapper`` requires
[nvidia apex](https://github.com/NVIDIA/apex).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
``ApexOptimWrapper`` requires
[nvidia apex](https://github.com/NVIDIA/apex).
``ApexOptimWrapper`` requires `nvidia apex <https://github.com/NVIDIA/apex>`_

Comment on lines 29 to 65
Args:

**kwargs: Keyword arguments passed to OptimWrapper.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing arguments description here

``accumulative_counts``.
"""

def __init__(self, opt_level='O1', loss_scale='dynamic', **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing type hint

self.opt_level = opt_level
self.loss_scale = loss_scale

def backward(self, loss: torch.Tensor, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def backward(self, loss: torch.Tensor, **kwargs):
def backward(self, loss: torch.Tensor, **kwargs) -> None:

state_dict['apex_amp'] = apex_amp.state_dict()
return state_dict

def load_state_dict(self, state_dict: dict):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def load_state_dict(self, state_dict: dict):
def load_state_dict(self, state_dict: dict) -> None:

else:
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
model = model.module
with super().optim_context(model):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

super().optim_context should be called each iteration for avoiding necessary gradient accumulation during training,

@xcnick
Copy link
Contributor Author

xcnick commented Jan 9, 2023

Hi~ is there any progress? Besides, CLA should be signed again 🤣 .

Sorry for the late response, I updated the code.

@zhouzaida
Copy link
Collaborator

Please add ApexOptimWrapper in https://github.com/open-mmlab/mmengine/blob/main/docs/en/api/optim.rst#optimizer and https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/api/optim.rst#optimizer

# when a given optimizer be passed through apex_amp.initialize,
# the "_amp_stash" property will be added
if hasattr(self.optimizer, '_amp_stash'):
yield
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that here misses a return, otherwise apex_amp.initialize will be called each iteration.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's weird that yield is followed by return, but it make sense in contextmanager.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can simplify the logic.

if not hasattr(self.optimizer, '_amp_stash'):
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module
    model, self.optimizer = apex_amp.initialize(xxx)
yield

apex_optim_wrapper = ApexOptimWrapper(
optimizer=optimizer, opt_level='O1', loss_scale=1)
with apex_optim_wrapper.optim_context(self.model):
apex_optim_wrapper.optimizer.param_groups = MagicMock()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sometimes we use MagicMock to assert some function or method has been called, but why do we need mock the param_groups here?

mmengine/optim/optimizer/apex_optimizer_wrapper.py Outdated Show resolved Hide resolved
mmengine/optim/optimizer/apex_optimizer_wrapper.py Outdated Show resolved Hide resolved
@zhouzaida
Copy link
Collaborator

Hi @xcnick , thanks for your contribution. This PR can be merged after resolving the above final comments and validating it in your local machine.

zhouzaida
zhouzaida previously approved these changes Feb 3, 2023
@HAOCHENYE
Copy link
Collaborator

Hi! as the official document says:

image

The current implementation of ApexOptimWrapper.load_state_dict will raise an error for the lack of initialization of apex. A possibly workaround could be not loading state_dict in ApexOptimWrapper.load_state_dict, but only saving the state_dict as an attribute, then calling apex.amp.load_state_dict in optim_context

zhouzaida
zhouzaida previously approved these changes Feb 5, 2023
HAOCHENYE
HAOCHENYE previously approved these changes Feb 6, 2023
@zhouzaida zhouzaida dismissed stale reviews from HAOCHENYE and themself via 7a825d0 February 6, 2023 07:11
@zhouzaida zhouzaida merged commit e35ed5f into open-mmlab:main Feb 6, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants