
Samyamr/full precision for ZeRO Stage2 and Stage3 #1004

Merged
21 commits merged into master on Apr 29, 2021

Conversation

samyam (Contributor) commented Apr 23, 2021

No description provided.

stas00 (Collaborator) commented Apr 29, 2021

And just to indicate the priority of this PR: we have all those bfloat16 models that won't train under fp16/mixed precision, and users want to use DeepSpeed to overcome GPU memory limitations, so they badly need this. Thank you!

samyam changed the title from Samyamr/full precision for stage3 to Samyamr/full precision for ZeRO Stage2 and Stage3 on Apr 29, 2021
stas00 (Collaborator) commented Apr 29, 2021

When you feel this looks good enough to test, please let me know and I will start testing this branch on the transformers side. Thank you.

samyam and others added 3 commits April 29, 2021 13:24
The assert that checks whether param.dtype is torch.half for ZeRO3 should only fire if the model was initialized in the ZeRO3 context.
jeffra merged commit dad2642 into master on Apr 29, 2021
jeffra deleted the samyamr/full-precision-for-stage3 branch on April 29, 2021 22:06
stas00 (Collaborator) commented Apr 30, 2021

This is awesome - thank you!

I encountered only one issue:

As I am writing HF transformers tests for fp32, I found that zero.Init doesn't get the dtype from the config file; I have to do it explicitly:

```
ds_config = deepspeed_config()
# XXX: Fixme - we shouldn't need to figure dtype out, it should be in the config file
dtype = torch.float16 if ds_config.get("fp16", {}).get("enabled", True) else torch.float
with deepspeed.zero.Init(dtype=dtype, config=ds_config):
    model = cls(config, *model_args, **model_kwargs)
```

I thought the whole point of passing config to zero.Init was so that we don't need to manually parse the file in multiple places; weren't we discussing making this work:

```
ds_config = deepspeed_config()
with deepspeed.zero.Init(config=ds_config):
    model = cls(config, *model_args, **model_kwargs)
```
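Until zero.Init picks the dtype up from the config itself, a small caller-side helper could at least centralize the lookup. A minimal sketch, assuming the config is the usual dict with an "fp16" section; `dtype_from_ds_config` is a hypothetical name, not part of the DeepSpeed or transformers API:

```python
import torch

def dtype_from_ds_config(ds_config: dict) -> torch.dtype:
    """Hypothetical helper: derive the training dtype from a DeepSpeed config dict.

    The default of True for fp16.enabled mirrors the workaround above and is
    an assumption, not documented DeepSpeed behavior.
    """
    fp16_enabled = ds_config.get("fp16", {}).get("enabled", True)
    return torch.float16 if fp16_enabled else torch.float32

# Usage, mirroring the workaround above:
# with deepspeed.zero.Init(dtype=dtype_from_ds_config(ds_config), config=ds_config):
#     model = cls(config, *model_args, **model_kwargs)
```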

stas00 added a commit to stas00/DeepSpeed that referenced this pull request Apr 30, 2021
I'm not sure if this is the best approach but with microsoft#1004 I still have to pass `zero.Init(dtype)` because this branch never gets executed:
```python
    def _set_dtype(self, ds_config, dtype):
        if ds_config is not None and dtype is None:
            _ds_config = DeepSpeedConfig(ds_config)
            self.dtype = torch.half if _ds_config.fp16_enabled else torch.float
```
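For comparison, a minimal sketch of how that fallback might be structured so an explicitly passed dtype, a config-derived dtype, and a default are all handled. This is only an illustration, assuming DeepSpeedConfig accepts the dict form of the config and that torch.half is the library default; it is not necessarily how the fix landed:

```python
def _set_dtype(self, ds_config, dtype):
    # Sketch only: prefer an explicitly passed dtype, otherwise fall back
    # to the fp16 setting from the DeepSpeed config, otherwise keep the
    # assumed default of torch.half.
    if dtype is not None:
        self.dtype = dtype
    elif ds_config is not None:
        _ds_config = DeepSpeedConfig(ds_config)
        self.dtype = torch.half if _ds_config.fp16_enabled else torch.float
    else:
        self.dtype = torch.half
```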
stas00 mentioned this pull request on Apr 30, 2021