
[ZeRO Infinity] Allow Init to take a dict for the deepspeed config #983

Merged (6 commits) on Apr 20, 2021

Conversation

SeanNaren (Contributor)

Hey guys!

I needed this fix in order to run ZeRO Infinity with Lightning, as the deepspeed config is already loaded into memory as a dictionary. When calling Init with remote_device=nvme, passing the config as a dictionary would crash without this fix!
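For context, a minimal sketch of the call pattern this PR enables. The Init argument name and the NVMe config keys below are assumptions for illustration, not taken from this PR's diff:

import deepspeed
import torch

# Config is already in memory as a dict (as Lightning holds it), not a JSON file path.
ds_config = {
    "train_batch_size": 8,
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "nvme", "nvme_path": "/local_nvme"},  # illustrative keys
    },
}

# Before this fix, constructing parameters under Init with remote_device="nvme"
# crashed when the config was a dict instead of a file path.
with deepspeed.zero.Init(remote_device="nvme", config=ds_config):  # arg name assumed
    model = torch.nn.Linear(4096, 4096)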

@@ -521,8 +521,10 @@ def write_config(self, filename):
class DeepSpeedConfig(object):
    def __init__(self, json_file, mpu=None, param_dict=None):
        super(DeepSpeedConfig, self).__init__()

        if param_dict is None:
            if isinstance(json_file, dict):
tjruwase (Contributor) commented Apr 20, 2021

I think typing json_file as a dict is confusing; rather, we should pass the dict as param_dict. So ideally this file should be unchanged for this PR. Instead, the following call sites should be modified to optionally pass param_dict as appropriate (see the sketch after these links):
Call 1
Call 2
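A rough sketch of the suggested direction, assuming a hypothetical call-site helper; only the DeepSpeedConfig signature is taken from the diff above:

from deepspeed.runtime.config import DeepSpeedConfig  # import path assumed

def _build_ds_config(ds_config, mpu=None):
    # Hypothetical call-site helper: keep json_file for file paths and route
    # an in-memory dict through param_dict, leaving DeepSpeedConfig itself unchanged.
    if isinstance(ds_config, dict):
        return DeepSpeedConfig(json_file=None, mpu=mpu, param_dict=ds_config)
    return DeepSpeedConfig(json_file=ds_config, mpu=mpu)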

SeanNaren (Contributor, Author)

Much cleaner, just made the change!

@tjruwase tjruwase merged commit 3525102 into microsoft:master Apr 20, 2021
@SeanNaren SeanNaren deleted the fix/config branch April 20, 2021 22:16
sdtblck added a commit to EleutherAI/DeeperSpeed that referenced this pull request Apr 22, 2021
* test sparse self_attn fix

* [WarmupDecayLR] fix log(0) & 1/log(1) bugs (microsoft#772)

* fix log(0) & 1/log(1) bugs

* simplify

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
Co-authored-by: Cheng Li <pistasable@gmail.com>

* bump to v0.3.12

* Bug fix: Remove client optimizer param_group list item that does not have 'params' (microsoft#827)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* [doc] pipeline doc typos/improvements (microsoft#659)

Admin merging for pure-doc PR that does not trigger build.

* Samyamr/inference hook fix (microsoft#851)

* Fix mis-aligned-grad

When a parameter is not divisible by world size, the partitioned gradients are mis-aligned due to incorrect padding handling. This PR should fix that.

* Formatting fix

* Adding static_scale test back for Z3, and also changing hidden size to be not divisible by world_size

* also removing alignment from flat fp16 buffers

* Testing for hidden dim alignment

* inference hook fix

* Update stage3.py

* formatting

* [bug-fix] move params to gpu if offload params is turned off

Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* ZeRO Stage 2: Clear reduced gradients (microsoft#856)

* Ensure gradients of other partitions are cleared after reduction

* Remove redundant code

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* [runner/launch] propagate the error (microsoft#854)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* docs: minor spelling tweaks (microsoft#858)

* Allow args to be optional in deepspeed.initialize (microsoft#825)

* Fix ZeRO3 save_checkpoint (microsoft#857)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* Make config objects json serializable (microsoft#862)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* bump version 0.3.13

* 1-bit Adam v2 (microsoft#817)

Authors: @awan-10 @conglongli @samyam @jeffra

What's new:

NCCL-based implementation which provides better performance and usability compared to the MPI-based implementation.
Add support to momentum masks for those parameters with constant zero gradients during training.
Bug fixes (e.g., microsoft#813).

* NCCL-based 1-bit Adam + Code Refactor for Comm. Backends (microsoft#594)

* NCCL based 1-bit Implementation + Refactor to add communication backends (microsoft#593)

* add nccl 1-bit optim.

* temporary commit to save stuff.

* Use dist collectives instead of mpi routines.

* remove old code for comm.

* Fix bugs. still does not work.

* modify to test the nccl side code path

* Initial gather impl. Works intra-node.

* Updates to comm. phase 2. nccl comm. passed the tests.

* refactor code to introduce nccl/mpi as backends for onebit adam.

* Refactor updates to test/engine.

* Fix compile/runtime errors.

* simplify support for nccl/mpi backends.

* Add missing file

* Add compression backend in constructor. Revert later.

* modify test with some perf counting.

* Implement a true non-blocking gather for nccl side.

* Revert "Add compression backend in constructor. Revert later."

This reverts commit df8c40d.

* improve the 1-bit adam test.

* Refactor comm. and compression backend in 1-bit adam.

* Fix the test.

* Fix runtime errors and typos in nccl backend

* fix mpi backend. modify tests.

* modify nccl perf test.

* fix mpi side errors.

* Add an mpi perf test

* Sync DSE.

* Remove old collectives file.

* Undo a typo.

* Graceful failure for torch versions that don't support nccl pt2pt.

* Revert "Merge branch 'master' into staging-1bit-nccl-v2"

This reverts commit 7840085, reversing
changes made to a6dba72.

* Revert "Revert "Merge branch 'master' into staging-1bit-nccl-v2""

This reverts commit 6dbdd98.

* comm optimization + 1-bit lamb

* Saving/debugging commit.

* finalizing 1-bit lamb

* finalizing 1-bit lamb

* add momentum mask and chkpt handling for 1-bit adam

* Cleanup and modify nccl test to be runnable with deepspeed launcher.

* Fix format.

* fix formatting again.

* make test runnable without mpi4py

* Add dist.alltoall and dist.allgather instead of custom functions.

* remove debug prints.

* formatting and renaming

* renaming

* renaming

* add unit test, fix existing tests

* skip unit test when torch < 1.8

* revert 1-bit lamb

* flatten momentum when dimension is more than 1

* add warning message for 1-bit adam under fp32

* improve version check

* add fp32 test

* 1-bit adam doc

* fix file name

* doc fix

* torch 1.8 is released

* doc fix

* fix tests

* update news

* add doc for momentum mask

* fix checkpoint handling, add unit test

* checkpoint handling doc

* doc final cleanup

* bump dates

* update tests

* url change

* doc fix

* fix test

* doc update

Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* consistent checkpoint filenaming (microsoft#865)

* consistent checkpoint filenaming

* backward compatible rename

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>

* [doc] launcher (microsoft#868)

As discussed in microsoft#662 this PR modifies the doc:
* explains what to use instead of CUDA_VISIBLE_DEVICES
* puts the `--hostfile` cl arg in the correct place in the invocation script

Fixes: microsoft#662

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* [doc] pipeline (microsoft#888)

* [doc] pipeline

As @g-karthik flagged in microsoft#659 (comment) my previous correction PR had one sentence that said the wrong thing. So this PR attempts to rectify that. 

Thank you!

* tweak

* [debug utils] see_memory_usage fixes (microsoft#890)

* see_memory_usage fixes

* didn't expect pt-1.2

* fix the order of things

* fix the order of things

* full fp32 weights reconstruction for zero 2+3 (microsoft#892)

* save_fp16_model consolidated for zero3 (microsoft#893)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>

* Fix zero stage2 cpu_offload when some model trainable parameters skipped in training (microsoft#861)

* Fix zero stage2 cpu_offload when some model trainable parameters skipped in training, as in microsoft#707

As some model trainable parameters skipped in training,
their backward hooks in self.create_reduce_and_remove_grad_hooks() will not run, 
so they have no norm_for_param_grads

* Trim space

* Trim space

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>

* mlperf attn initial commit

* update kramdown (microsoft#901)

security alert related to older kramdown version

* update backward api doc (microsoft#903)

* Bump kramdown from 2.3.0 to 2.3.1 in /docs (microsoft#905)

Bumps [kramdown](https://github.com/gettalong/kramdown) from 2.3.0 to 2.3.1.
- [Release notes](https://github.com/gettalong/kramdown/releases)
- [Changelog](https://github.com/gettalong/kramdown/blob/master/doc/news.page)
- [Commits](https://github.com/gettalong/kramdown/commits)

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* We're hiring! + integration posts

* [website] We're hiring! + integration posts

* [website] we're hiring!

* zero.Init() clarification (microsoft#880)

* zero.Init() clarification

clarify that if `model.half()` can't fit into gpu memory `zero.Init()` is a must.

this proposal is via @samyam's clarification shared elsewhere.

Thank you.

* style

* add clarity

* style

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>

* disable pipe test (microsoft#915)

This test has been giving us trouble for a bit, seeing nondeterministic failures, skipping for now to not break our CI. Need to revisit soon though.

* Add link to AML examples. (microsoft#916)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* add inference_batch fn

* Add space in help string (microsoft#926)

* Fix for fragmented linear inputs in ZeRO 3 Linear layers where reshap… (microsoft#881)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* [zero3] GatheredParameters can now handle a list of params (microsoft#884)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* fix cpu_adam memory leak on deepspeed re-use in the same process (microsoft#896)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Reza Yazdani <reyazda@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* [benchmarks] flatten/unflatten benchmarks (microsoft#919)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* improved readability + typos (microsoft#895)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* [zero doc] fix misspelled param (microsoft#878)

We really really really need those params to be validated...

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* Samyamr/stage 3 skip modules without parameters (microsoft#867)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* docs (microsoft#909)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* Supporting different hidden dimensions for transformer kernels-v2 (microsoft#934)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* Pull changes from DeepSpeed

* Pull changes from DeepSpeed

* Pull changes from DeepSpeed

* Pull changes from DeepSpeed

* Pull changes from DeepSpeed

* Pull changes from DeepSpeed

* cleanup, reinstantiate sending of logits / layer_past

* cleanup, reinstantiate sending of logits / layer_past

* bump to 0.3.14

* add pypi badge

* Delete check of pdsh (microsoft#941)

* fix double linear override; spelling (microsoft#954)

* [config] turn exponential notation back on for config dump (microsoft#955)

* e-notation for large floats

* handle ints too

* readability

* handle bool

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>

* document how to override ~/.cache/torch_extensions (microsoft#959)

* [zero] faster flatten/unflatten (cpp version)  (microsoft#910)

* faster flatten/unflatten with apex

* switch to cpp flatten/unflatten

* style

* better comment

* missing import

* switch to build ops at run time

* fixes

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>

* update lr scheduler doc for doing per step or epoch update (microsoft#913)

* update lr scheduler doc for doing per step or epoch update

* work

* trigger build

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>

* Fix ZeRO-3 UnboundLocalError (microsoft#968)

* Fix UnboundLocalError

* Get full partition size

* ZeRO-Infinity (microsoft#976)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>

* revert zero-inf change to launcher

* [docs] zero-inf updates

* bump to 0.3.15

* ZeRO-Infinity tutorial additions (microsoft#978)

* zinf tutorial

* more megatron integration docs

* [docs] add ZeRO-Inf news items

* refactor

* ZeRO-Infinity docs (microsoft#979)

* zinf tutorial

* more megatron integration docs

* ZInf + tiling docs

* [docs] zero-inf updates

* assert no Z2/Z3 with pipeline and fix some docs links (microsoft#980)

* add option to force multi-node launcher mode (microsoft#977)

* [ZeRO Infinity] Allow Init to take a dict for the deepspeed config  (microsoft#983)

* Add check to see if json file is already loaded

* Update doc

* Address review

* Remove doc comment

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>

* make bold+italic work without escaping _ (microsoft#775)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>

* remove debug prints: (microsoft#986)

* 1-bit LAMB optimizer (microsoft#970)

1-bit LAMB: Communication Efficient Large-Scale Large-Batch Training with LAMB's Convergence Speed.
Author: @conglongli, @awan-10, @samyam, Hanlin Tang, Yuxiong He
Paper: https://arxiv.org/abs/2104.06069

Co-authored-by: sdtblck <46172032+sdtblck@users.noreply.github.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* Use odd shape tensor to represent parameter data in partitioned state (microsoft#981)

* use weird-shaped tensor to avoid silent failures when not registering external params

* fix typo

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>

* Make reduce scatter optional for ZeRO-1 as workaround (microsoft#971)

* Make reduce scatter optional for ZeRO-1 as workaround

* Make allreduce default for ZeRO 1

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* Fix all Pipeline Module Parameters being sent to cuda:0 (microsoft#687)

* remove communicate overflow (already in utils.CheckOverflow)

Co-authored-by: sid <sidney.black@aleph-alpha.de>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
Co-authored-by: Cheng Li <pistasable@gmail.com>
Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: brett koonce <koonce@gmail.com>
Co-authored-by: Conglong Li <conglong.li@gmail.com>
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: hamlet <gvvvv@163.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Takuya Makino <takuyamakino15@gmail.com>
Co-authored-by: Reza Yazdani <reyazda@microsoft.com>
Co-authored-by: Sean Naren <sean@grid.ai>
stas00 (Collaborator) commented Apr 22, 2021

@SeanNaren, this is great - we need it too! thank you!

Just one suggestion: deepspeed.initialize names the arg for this same functionality differently. Would you be open to renaming s/param_dict/config_params/ to match deepspeed.initialize? I guess it's a bit late to rename deepspeed.initialize's existing args.

see it here:

config_params: Optional: Instead of requiring args.deepspeed_config you can pass your deepspeed config

I am fine if it remains as you proposed, as yours is more intuitive naming than config_params; the latter says the same thing twice and doesn't say it's a dict.

If I had a choice I'd rename both to config_dict but it's too late :)

SeanNaren (Contributor, Author)

> @SeanNaren, this is great - we need it too! thank you!
>
> Just one suggestion: deepspeed.initialize names the arg for this same functionality differently. Would you be open to renaming s/param_dict/config_params/ to match deepspeed.initialize? I guess it's a bit late to rename deepspeed.initialize's existing args.
>
> see it here:
>
> config_params: Optional: Instead of requiring args.deepspeed_config you can pass your deepspeed config
>
> I am fine if it remains as you proposed, as yours is more intuitive naming than config_params; the latter says the same thing twice and doesn't say it's a dict.
>
> If I had a choice I'd rename both to config_dict but it's too late :)

I'm glad I'm finally able to assist you, stas! Should be straightforward, I can make a PR :)

stas00 (Collaborator) commented Apr 22, 2021

Or should we just make it a single arg that can take a path or a dict? This is what we do in transformers:

def deepspeed_parse_config(ds_config):
    """
    If ``ds_config`` isn't already a dict, read it from the config file.

    If it's already a dict, return a copy of it, so that we can freely modify it.
    """
    dep_version_check("deepspeed")

    if isinstance(ds_config, dict):
        # Don't modify user's data should they want to reuse it (e.g. in tests), because once we
        # modified it, it will not be accepted here again, since some config params must be not set by users
        config = deepcopy(ds_config)
    elif isinstance(ds_config, str):
        with io.open(ds_config, "r", encoding="utf-8") as f:
            config = json.load(f)
    else:
        raise ValueError("expecting either a path to a config file or a pre-populated dict")

    return config

minus copy and dep_checking.

SeanNaren (Contributor, Author)

> Or should we just make it a single arg that can take a path or a dict? This is what we do in transformers:
>
> def deepspeed_parse_config(ds_config):
>     """
>     If ``ds_config`` isn't already a dict, read it from the config file.
>
>     If it's already a dict, return a copy of it, so that we can freely modify it.
>     """
>     dep_version_check("deepspeed")
>
>     if isinstance(ds_config, dict):
>         # Don't modify user's data should they want to reuse it (e.g. in tests), because once we
>         # modified it, it will not be accepted here again, since some config params must be not set by users
>         config = deepcopy(ds_config)
>     elif isinstance(ds_config, str):
>         with io.open(ds_config, "r", encoding="utf-8") as f:
>             config = json.load(f)
>     else:
>         raise ValueError("expecting either a path to a config file or a pre-populated dict")
>
>     return config
>
> minus copy and dep_checking.

This is also how we set it up in PL, which I think is natural! If @tjruwase is in agreement I can make the change here, such that deepspeed_config can be a dict object rather than a path. If we're worried about backward compatibility we can keep config_params as an argument?

tjruwase (Contributor)

@SeanNaren, I am fine with making the change. I think it is early days, so backward compatibility is less of a priority given the long-term benefits of the change.

stas00 (Collaborator) commented Apr 23, 2021

There has been no new release since your addition, so I agree with @tjruwase: it should not be a backward-compatibility issue.

Thank you!

SeanNaren (Contributor, Author)

I'm at a crossroads on how far I should go in this refactor:

  1. For DeepSpeed Init, remove param_dict, making deepspeed_config either a JSON file or a config dict. This introduces the initial logic in the PR.
  2. Make deepspeed_config an input to initialize, which can also be a JSON file or a config dict and continues to take priority over args.
  3. Unify references to config_params or params_dict as deepspeed_config.

I think 1/2 are probably the most crucial, but am curious what you guys think (a sketch of 1 and 2 follows below). 3 is more involved, as it refactors a lot of variables defined across various config functions and files. cc @stas00 @tjruwase
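A sketch of what 1 and 2 could look like, mirroring the transformers helper quoted earlier. The helper name is hypothetical and the entry-point argument names are the ones discussed in this thread, not a confirmed API:

import json

def _resolve_deepspeed_config(deepspeed_config):
    """Hypothetical normalizer: accept a path to a JSON file or an
    already-loaded dict and return a plain dict."""
    if isinstance(deepspeed_config, dict):
        return deepspeed_config
    if isinstance(deepspeed_config, str):
        with open(deepspeed_config, "r", encoding="utf-8") as f:
            return json.load(f)
    raise ValueError("expected a path to a deepspeed config file or a dict")

# 1. zero.Init would run its argument through this normalizer, so an in-memory
#    dict (as Lightning holds it) works the same as a file path.
# 2. deepspeed.initialize(deepspeed_config=...) would do the same and keep
#    taking priority over args.deepspeed_config.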

tjruwase (Contributor)

@SeanNaren, thanks for sharing these thoughts. I agree that we should do 1 & 2 for now and defer 3 for later. To my understanding, 3 has minimal, if any, user-facing impact.

stas00 (Collaborator) commented Apr 27, 2021

Thank you, @SeanNaren for this proposal.

The crucial thing for HF is that Init won't save the config and try to use it later, but will only pluck the few params it absolutely needs and then drop it. This is because during model init we don't yet have the full config ready. We need the model object to get the hidden size, to automatically fill out auto-params that rely on hidden size, and we don't have the number of training steps figured out until much later. So our config happens in 2 stages: one immediately upon argparse, when we are ready to feed Init, and the 2nd stage at the start of train or eval.

That said, I'm concerned that if we unify and it's hard to tell the two apart, there is a chance someone will try to save the config at the Init stage and disregard the updated version during initialize, leading to errors.

So I think that Init should immediately extract only the 3-4 params it needs and drop the passed config, in which case it doesn't matter what you call it internally.

Therefore: 1 for sure; on 2 I'm impartial, but it is definitely a good idea; and for 3, instead of unifying I'd actually rename config_params to init_params so it's clear that it is only used for Init-specific params and not as a normal config. In fact, I propose that a user should be able to pass Init a trimmed version of the config that has only the 4 params (or however many) Init might ever need, as sketched below.

I hope this anti-unification logic is sensible.

If we do 2, then perhaps args should be deprecated and in time removed.
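A sketch of the extraction idea, assuming (hypothetically) that Init only needs a handful of offload/NVMe-related settings; the key names are illustrative, not a confirmed list:

def _extract_init_params(ds_config):
    # Pluck only what Init needs and drop the rest, so a partially-built
    # config (as in HF's two-stage setup) is enough at model-init time.
    zero_cfg = ds_config.get("zero_optimization", {})
    return {
        "stage": zero_cfg.get("stage"),                    # illustrative keys
        "offload_param": zero_cfg.get("offload_param"),
        "offload_optimizer": zero_cfg.get("offload_optimizer"),
        "aio": ds_config.get("aio"),
    }
    # The full, finalized config is still what deepspeed.initialize consumes later.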

SeanNaren (Contributor, Author)

Thanks Stas! We can iterate on this. I agree with being able to pass just the parameters required for NVMe, but it would also be nice to keep the option of passing the entire config, for future extensibility and user simplicity.

stas00 (Collaborator) commented Apr 27, 2021

Oh, I'm not against passing the entire config; future-proofing is exactly that. I just want to make sure it's not saved and used later in deepspeed's normal initialize.
