
ddp fixes for training #22874

Merged
merged 1 commit into huggingface:main on Apr 21, 2023
Conversation

winglian (Contributor) commented on Apr 19, 2023

What does this PR do?

While trying to train Stable LM or even Llama, I ran into a couple of issues with multi-GPU training and DDP.

I've added a check to skip wrapping the model in DistributedDataParallel when none of its parameters require a gradient, since torch doesn't support that case: see https://github.com/pytorch/pytorch/blob/main/torch/nn/parallel/distributed.py#L686-L694

  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer.py", line 1633, in train                                                                                                                                   
    return inner_training_loop(                                                                                                                                                                                                
  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer.py", line 1720, in _inner_training_loop                                                                                                                    
    model = self._wrap_model(self.model_wrapped)                                                                                                                                                                               
  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer.py", line 1545, in _wrap_model                                                                                                                             
    model = nn.parallel.DistributedDataParallel(                                                                                                                                                                               
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 571, in __init__                                                                                                                            self._log_and_throw(                                                                                                                                                                                                       
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 769, in _log_and_throw                                                                                                                      raise err_type(err_msg)                                                                                                                                                                                                    
RuntimeError: DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.               
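
For reference, a minimal sketch of the kind of guard this describes (wrap_model_for_ddp is a hypothetical helper for illustration, not the actual Trainer._wrap_model change):

    import torch.nn as nn

    def wrap_model_for_ddp(model: nn.Module, device_ids=None) -> nn.Module:
        # Hypothetical helper illustrating the guard: torch raises the
        # RuntimeError above when no parameter requires a gradient, so only
        # wrap the model when there is actually something to synchronize.
        if any(p.requires_grad for p in model.parameters()):
            return nn.parallel.DistributedDataParallel(model, device_ids=device_ids)
        # Nothing is trainable, so return the model unwrapped instead of crashing.
        return model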

I also added a check for the no_sync method, since models that aren't wrapped in DDP don't expose it:

  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer.py", line 1634, in train                                                                                                                                   
    return inner_training_loop(                                                                                
  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer.py", line 1900, in _inner_training_loop                                                                                                                        with model.no_sync():                                                                                                                                                                                                      
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__                                                                                                                          
    raise AttributeError("'{}' object has no attribute '{}'".format(                                                                                                                                                           AttributeError: 'GPTNeoXForCausalLM' object has no attribute 'no_sync'      
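
A similar sketch for this one (no_sync_context is a hypothetical helper, not the actual Trainer code): fall back to a no-op context manager when the model isn't DDP-wrapped and therefore has no no_sync method.

    import contextlib

    def no_sync_context(model):
        # DDP-wrapped models expose no_sync() to skip gradient synchronization
        # during accumulation steps; plain modules such as GPTNeoXForCausalLM
        # do not, so fall back to a no-op context manager instead.
        if hasattr(model, "no_sync"):
            return model.no_sync()
        return contextlib.nullcontext()

    # Usage during gradient accumulation steps:
    #     with no_sync_context(model):
    #         loss.backward()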

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev commented on Apr 19, 2023

The documentation is not available anymore as the PR was closed or merged.

winglian changed the title from "ddp fixes for stable lm training" to "ddp fixes for training" on Apr 21, 2023
@muellerzr (Contributor) commented:
cc @sgugger, solution is exactly what we have in Accelerate, and would be a good way to keep it working until the Accelerate integration is fully finished :)

@sgugger (Collaborator) left a comment:

Thanks for the fix!

sgugger merged commit d00997e into huggingface:main on Apr 21, 2023
novice03 pushed a commit to novice03/transformers that referenced this pull request on Jun 23, 2023: "ddp fixes for stable lm training"