Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merge changes from deepspeed master #24

Merged
merged 31 commits into from
Apr 6, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
18a26f3
[WarmupDecayLR] fix log(0) & 1/log(1) bugs (#772)
stas00 Mar 12, 2021
35fd7cc
bump to v0.3.12
jeffra Mar 12, 2021
458ff02
Bug fix: Remove client optimizer param_group list item that does not …
cli99 Mar 12, 2021
73d762c
[doc] pipeline doc typos/improvements (#659)
stas00 Mar 14, 2021
4601885
Samyamr/inference hook fix (#851)
samyam Mar 15, 2021
a75d971
ZeRO Stage 2: Clear reduced gradients (#856)
tjruwase Mar 15, 2021
24335d4
[runner/launch] propagate the error (#854)
stas00 Mar 16, 2021
547d1c5
docs: minor spelling tweaks (#858)
brettkoonce Mar 16, 2021
871f304
Allow args to be optional in deepspeed.initialize (#825)
jeffra Mar 16, 2021
fa87a73
Fix ZeRO3 save_checkpoint (#857)
tjruwase Mar 16, 2021
7bcd72a
Make config objects json serializable (#862)
tjruwase Mar 16, 2021
12a53b4
bump version 0.3.13
jeffra Mar 16, 2021
68c8481
1-bit Adam v2 (#817)
conglongli Mar 16, 2021
10c0bea
consistent checkpoint filenaming (#865)
stas00 Mar 18, 2021
9e9f8cb
[doc] launcher (#868)
stas00 Mar 18, 2021
22d5a1f
[doc] pipeline (#888)
stas00 Mar 24, 2021
7f03282
[debug utils] see_memory_usage fixes (#890)
stas00 Mar 25, 2021
7531c6b
full fp32 weights reconstruction for zero 2+3 (#892)
stas00 Mar 26, 2021
39013dd
save_fp16_model consolidated for zero3 (#893)
stas00 Mar 27, 2021
7fcc891
Fix zero stage2 cpu_offload when some model trainable parameters skip…
ghosthamlet Mar 27, 2021
af2d8fc
update kramdown (#901)
jeffra Mar 30, 2021
23ff6cb
update backward api doc (#903)
jeffra Mar 30, 2021
c042264
Bump kramdown from 2.3.0 to 2.3.1 in /docs (#905)
dependabot[bot] Mar 30, 2021
8c9e16e
We're hiring! + integration posts
jeffra Mar 31, 2021
c6b497d
[website] We're hiring! + integration posts
jeffra Mar 31, 2021
c814abd
[website] we're hiring!
jeffra Mar 31, 2021
5d721e0
zero.Init() clarification (#880)
stas00 Apr 1, 2021
8db4fdf
disable pipe test (#915)
jeffra Apr 2, 2021
ab5534f
Add link to AML examples. (#916)
awan-10 Apr 2, 2021
c574788
Merge branch 'master' of https://github.com/microsoft/DeepSpeed into …
Apr 6, 2021
b58a8fa
Merge branch 'microsoft-master' into stella
Apr 6, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Allow args to be optional in deepspeed.initialize (microsoft#825)
  • Loading branch information
jeffra authored Mar 16, 2021
commit 871f3048ad0d05e79f8835849b7a00656a14b3f4
12 changes: 8 additions & 4 deletions deepspeed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def _parse_version(version_str):
sys.modules['deepspeed.pt.loss_scaler'] = deepspeed.runtime.fp16.loss_scaler


def initialize(args,
model,
def initialize(args=None,
model=None,
optimizer=None,
model_parameters=None,
training_data=None,
Expand All @@ -62,8 +62,7 @@ def initialize(args,
"""Initialize the DeepSpeed Engine.

Arguments:
args: a dictionary containing local_rank and deepspeed_config
file location
args: an object containing local_rank and deepspeed_config fields. This is optional if `config_params` is passed.

model: Required: nn.module class before apply any wrappers

Expand All @@ -88,6 +87,9 @@ def initialize(args,
mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.

config_params: Optional: Instead of requiring args.deepspeed_config you can pass your deepspeed config
as a dictionary instead.

Returns:
A tuple of ``engine``, ``optimizer``, ``training_dataloader``, ``lr_scheduler``

Expand All @@ -108,6 +110,8 @@ def initialize(args,
__git_branch__),
ranks=[0])

assert model is not None, "deepspeed.initialize requires a model"

if not isinstance(model, PipelineModule):
engine = DeepSpeedEngine(args=args,
model=model,
Expand Down
20 changes: 10 additions & 10 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,9 +495,10 @@ def _configure_with_arguments(self, args, mpu):
# After the distributed backend is initialized we are guaranteed the LOCAL_RANK
# environment variable is set. We must align args.local_rank to this value for
# backwards compatability with scripts relying on [args|self].local_rank containing
# the correct local rank info.
args.local_rank = int(os.environ['LOCAL_RANK'])
self.local_rank = args.local_rank
# the correct local rank info. _do_args_sanity_check will ensure this is the case.
self.local_rank = int(os.environ['LOCAL_RANK'])
if hasattr(args, 'local_rank'):
args.local_rank = self.local_rank

config_file = args.deepspeed_config if hasattr(args,
'deepspeed_config') else None
Expand All @@ -513,15 +514,14 @@ def _do_args_sanity_check(self, args):
assert args.deepspeed_config is None, "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
args.deepspeed_config = args.deepscale_config

local_rank_err = "DeepSpeed requires a command line parameter of --local_rank [int] and/or setting the LOCAL_RANK environment variable."
if hasattr(args, 'local_rank'):
assert type(args.local_rank) == int, local_rank_err
if "LOCAL_RANK" in os.environ and args.local_rank >= 0:
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
assert "LOCAL_RANK" in os.environ, "DeepSpeed requires the LOCAL_RANK environment variable, it is set by the deepspeed launcher, " \
"deepspeed.init_distributed, or the torch.distributed launcher. If using a different launcher please ensure LOCAL_RANK is set prior to initializing deepspeed."
if hasattr(args, 'local_rank') and args.local_rank != None:
assert isinstance(args.local_rank, int), f"args.local_rank of {args.local_rank} is an unknown type {type(args.local_rank)}"
if args.local_rank >= 0:
env_local_rank = int(os.environ.get("LOCAL_RANK"))
assert env_local_rank == args.local_rank, \
f"Mismatch in local rank setting, args.local_rank={args.local_rank} but env['LOCAL_RANK']={env_local_rank}."
else:
assert "LOCAL_RANK" in os.environ, local_rank_err

if self.config_params is None:
assert hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None, \
Expand Down
80 changes: 80 additions & 0 deletions tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,83 @@ def _helper():
model.step()

_helper()


def test_none_args(tmpdir):
config_dict = {
"train_batch_size": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"fp16": {
"enabled": True
}
}

@distributed_test(world_size=1)
def _helper():
model = SimpleModel(hidden_dim=10)
model, _, _, _ = deepspeed.initialize(args=None, model=model, config_params=config_dict)
data_loader = random_dataloader(model=model,
total_samples=5,
hidden_dim=10,
device=model.device)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])

_helper()


def test_no_args(tmpdir):
config_dict = {
"train_batch_size": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"fp16": {
"enabled": True
}
}

@distributed_test(world_size=1)
def _helper():
model = SimpleModel(hidden_dim=10)
model, _, _, _ = deepspeed.initialize(model=model, config_params=config_dict)
data_loader = random_dataloader(model=model,
total_samples=5,
hidden_dim=10,
device=model.device)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])

_helper()


def test_no_model(tmpdir):
config_dict = {
"train_batch_size": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"fp16": {
"enabled": True
}
}

@distributed_test(world_size=1)
def _helper():
model = SimpleModel(hidden_dim=10)
with pytest.raises(AssertionError):
model, _, _, _ = deepspeed.initialize(model=None, config_params=config_dict)

with pytest.raises(AssertionError):
model, _, _, _ = deepspeed.initialize(model, config_params=config_dict)