merge changes from deepspeed master #24

Merged: 31 commits, merged on Apr 6, 2021
Commits (31):
18a26f3  [WarmupDecayLR] fix log(0) & 1/log(1) bugs (#772)  (stas00, Mar 12, 2021)
35fd7cc  bump to v0.3.12  (jeffra, Mar 12, 2021)
458ff02  Bug fix: Remove client optimizer param_group list item that does not …  (cli99, Mar 12, 2021)
73d762c  [doc] pipeline doc typos/improvements (#659)  (stas00, Mar 14, 2021)
4601885  Samyamr/inference hook fix (#851)  (samyam, Mar 15, 2021)
a75d971  ZeRO Stage 2: Clear reduced gradients (#856)  (tjruwase, Mar 15, 2021)
24335d4  [runner/launch] propagate the error (#854)  (stas00, Mar 16, 2021)
547d1c5  docs: minor spelling tweaks (#858)  (brettkoonce, Mar 16, 2021)
871f304  Allow args to be optional in deepspeed.initialize (#825)  (jeffra, Mar 16, 2021)
fa87a73  Fix ZeRO3 save_checkpoint (#857)  (tjruwase, Mar 16, 2021)
7bcd72a  Make config objects json serializable (#862)  (tjruwase, Mar 16, 2021)
12a53b4  bump version 0.3.13  (jeffra, Mar 16, 2021)
68c8481  1-bit Adam v2 (#817)  (conglongli, Mar 16, 2021)
10c0bea  consistent checkpoint filenaming (#865)  (stas00, Mar 18, 2021)
9e9f8cb  [doc] launcher (#868)  (stas00, Mar 18, 2021)
22d5a1f  [doc] pipeline (#888)  (stas00, Mar 24, 2021)
7f03282  [debug utils] see_memory_usage fixes (#890)  (stas00, Mar 25, 2021)
7531c6b  full fp32 weights reconstruction for zero 2+3 (#892)  (stas00, Mar 26, 2021)
39013dd  save_fp16_model consolidated for zero3 (#893)  (stas00, Mar 27, 2021)
7fcc891  Fix zero stage2 cpu_offload when some model trainable parameters skip…  (ghosthamlet, Mar 27, 2021)
af2d8fc  update kramdown (#901)  (jeffra, Mar 30, 2021)
23ff6cb  update backward api doc (#903)  (jeffra, Mar 30, 2021)
c042264  Bump kramdown from 2.3.0 to 2.3.1 in /docs (#905)  (dependabot[bot], Mar 30, 2021)
8c9e16e  We're hiring! + integration posts  (jeffra, Mar 31, 2021)
c6b497d  [website] We're hiring! + integration posts  (jeffra, Mar 31, 2021)
c814abd  [website] we're hiring!  (jeffra, Mar 31, 2021)
5d721e0  zero.Init() clarification (#880)  (stas00, Apr 1, 2021)
8db4fdf  disable pipe test (#915)  (jeffra, Apr 2, 2021)
ab5534f  Add link to AML examples. (#916)  (awan-10, Apr 2, 2021)
c574788  Merge branch 'master' of https://github.com/microsoft/DeepSpeed into …  (Apr 6, 2021)
b58a8fa  Merge branch 'microsoft-master' into stella  (Apr 6, 2021)
Changes from 1 commit:

Fix ZeRO3 save_checkpoint (microsoft#857)
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
tjruwase and jeffra authored Mar 16, 2021
commit fa87a73a8a3bead24ad9ea52090646fa620d74e8
deepspeed/runtime/zero/stage3.py (12 changes: 5 additions & 7 deletions)

@@ -2269,7 +2269,7 @@ def _prepare_fp32_grad_for_sub_group(self, sub_group_id):

         assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[sub_group_id].numel(), \
             "averaged gradients have different number of elements that partition size {} {} {} {}".format(
-                single_grad_partition.numel(), self.partition_size[sub_group_id], sub_group_id, partition_id)
+                single_grad_partition.numel(), self.fp32_partitioned_groups_flat[sub_group_id].numel(), sub_group_id, partition_id)

         self.fp32_partitioned_groups_flat[sub_group_id].grad = single_grad_partition

@@ -2638,14 +2638,12 @@ def get_groups_without_padding(self, groups_with_padding):
     def _set_fp32_optimizer_param_groups(self):
         for sub_group_id, _ in enumerate(self.fp16_groups):
             param_group_id = self.sub_group_to_group_id[sub_group_id]
-            self.optimizer.param_groups[param_group_id]['params'] = [
-                self.fp32_partitioned_groups_flat[sub_group_id]
-            ]
+            self.optimizer.param_groups[param_group_id]['params'].append(
+                self.fp32_partitioned_groups_flat[sub_group_id])

     def _clear_fp32_optimizer_param_groups(self):
-        for sub_group_id, _ in enumerate(self.fp16_groups):
-            param_group_id = self.sub_group_to_group_id[sub_group_id]
-            self.optimizer.param_groups[param_group_id]['params'] = []
+        for param_group in self.optimizer.param_groups:
+            param_group['params'] = []

     def _rigid_state_dict(self):
         state_dict = {}
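Note on the second hunk, with a minimal standalone sketch (plain PyTorch on a toy setup, not DeepSpeed code; the names below are illustrative): assigning to a param group's 'params' replaces whatever the group already holds, so when several fp16 sub-groups map to the same optimizer param group only the last fp32 partition would survive, whereas appending keeps them all; clearing now walks every param group rather than only the ones reachable through sub_group_to_group_id.

    import torch

    partitions = [torch.nn.Parameter(torch.zeros(4)) for _ in range(3)]
    opt = torch.optim.Adam([{'params': []}])          # a single client param group
    sub_group_to_group_id = {0: 0, 1: 0, 2: 0}        # three sub-groups map to one group

    # Old behaviour (plain assignment): each write replaces the list, so the
    # group ends up holding only the last sub-group's partition.
    for sub_group_id, flat in enumerate(partitions):
        gid = sub_group_to_group_id[sub_group_id]
        opt.param_groups[gid]['params'] = [flat]
    assert len(opt.param_groups[0]['params']) == 1

    # New behaviour (append): every sub-group's partition stays in the group.
    opt.param_groups[0]['params'] = []
    for sub_group_id, flat in enumerate(partitions):
        gid = sub_group_to_group_id[sub_group_id]
        opt.param_groups[gid]['params'].append(flat)
    assert len(opt.param_groups[0]['params']) == 3

    # Clearing now empties every param group the optimizer has, mirroring the
    # new _clear_fp32_optimizer_param_groups.
    for param_group in opt.param_groups:
        param_group['params'] = []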
tests/unit/test_checkpointing.py (44 changes: 24 additions & 20 deletions)

@@ -47,7 +47,7 @@ def compare_model_states(saved_model, loaded_model, compare_optimizer=True):
     if FP16_DeepSpeedZeroOptimizer_Stage3 is not None and isinstance(
             saved_model.optimizer,
             FP16_DeepSpeedZeroOptimizer_Stage3):
-        for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat):
+        for p0, p1 in zip(saved_model.optimizer.fp32_partitioned_groups_flat, loaded_model.optimizer.fp32_partitioned_groups_flat):
             assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"

     elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer):
@@ -303,12 +303,13 @@ def _test_checkpoint_fused_optimizer(args,
                              'deepspeed_adam'),
                             (3,
                              False,
-                             'Adam')])
+                             'Adam'),
+                            (3,
+                             True,
+                             'deepspeed_adam')])
 def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
         pytest.skip("cpu-adam is not compatible")
-    if zero_stage == 3:
-        pytest.skip('Skip checkpointing tests for ZeRO3')

     config_dict = {
         "train_batch_size": 2,
@@ -324,8 +325,10 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_opt
             }
         },
         "fp16": {
-            "enabled": True
+            "enabled": True,
+            "initial_scale_power": 8
         },
+        "wall_clock_breakdown": True,
         "zero_optimization": {
             "stage": zero_stage,
             "cpu_offload": use_cpu_offload
@@ -340,9 +343,7 @@ def _test_checkpoint_zero_optimizer(args,
                                          hidden_dim,
                                          load_optimizer_states):
         if zero_stage == 3:
-            global FP16_DeepSpeedZeroOptimizer_Stage3
-            from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
-            with deepspeed.ScatteredParameters(zero_modules=True):
+            with deepspeed.zero.Init():
                 models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
         else:
             models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
@@ -371,15 +372,16 @@ def _test_checkpoint_zero_optimizer(args,
                              'deepspeed_adam'),
                             (3,
                              False,
-                             'Adam')])
+                             'Adam'),
+                            (3,
+                             True,
+                             'deepspeed_adam')])
 def test_checkpoint_zero_no_optimizer(tmpdir,
                                       zero_stage,
                                       use_cpu_offload,
                                       adam_optimizer):
     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
         pytest.skip("cpu-adam is not compatible")
-    if zero_stage == 3:
-        pytest.skip('Skip checkpointing tests for ZeRO3')

     config_dict = {
         "train_batch_size": 2,
@@ -413,7 +415,7 @@ def _test_checkpoint_zero_no_optimizer(args,
         if zero_stage == 3:
             global FP16_DeepSpeedZeroOptimizer_Stage3
             from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
-            with deepspeed.ScatteredParameters(zero_modules=True):
+            with deepspeed.zero.Init():
                 models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
         else:
             models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
@@ -445,12 +447,13 @@ def _test_checkpoint_zero_no_optimizer(args,
                              'deepspeed_adam'),
                             (3,
                              False,
-                             'Adam')])
+                             'Adam'),
+                            (3,
+                             True,
+                             'deepspeed_adam')])
 def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
         pytest.skip("cpu-adam is not compatible")
-    if zero_stage == 3:
-        pytest.skip('Skip checkpointing tests for ZeRO3')

     config_dict = {
         "train_batch_size": 2,
@@ -493,7 +496,7 @@ def _test_checkpoint_lr_scheduler(args,
         if zero_stage == 3:
             global FP16_DeepSpeedZeroOptimizer_Stage3
             from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
-            with deepspeed.ScatteredParameters(zero_modules=True):
+            with deepspeed.zero.Init():
                 models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
         else:
             models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
@@ -526,14 +529,15 @@ def _test_checkpoint_lr_scheduler(args,
                             (2,
                              True,
                              'deepspeed_adam'),
+                            (3,
+                             False,
+                             'Adam'),
                             (3,
                              True,
-                             'Adam')])
+                             'deepspeed_adam')])
 def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
         pytest.skip("cpu-adam is not compatible")
-    if zero_stage == 3:
-        pytest.skip('Skip checkpointing tests for ZeRO3')

     config_dict = {
         "train_batch_size": 2,
@@ -570,7 +574,7 @@ def _test_checkpoint_no_lr_scheduler(args,
                                          load_optimizer_states,
                                          load_lr_scheduler_states):
         if zero_stage == 3:
-            with deepspeed.ScatteredParameters(zero_modules=True):
+            with deepspeed.zero.Init():
                 models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
         else:
             models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
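For orientation, the save/load round trip these parametrized tests exercise looks roughly like the sketch below. This is a hedged illustration, not the test code: the config mirrors the keys touched in this commit, torch.nn.Linear stands in for the repo's SimpleModel fixture, the "checkpoints" directory and "step0" tag are made up, and a distributed launch (e.g. via the deepspeed launcher) is assumed.

    import torch
    import deepspeed

    # Config mirroring the keys used in the updated tests.
    config = {
        "train_batch_size": 2,
        "optimizer": {"type": "Adam", "params": {"lr": 1e-5}},
        "fp16": {"enabled": True, "initial_scale_power": 8},
        "wall_clock_breakdown": True,
        "zero_optimization": {"stage": 3, "cpu_offload": False},
    }

    # Under ZeRO stage 3, the model is built inside deepspeed.zero.Init() so
    # its parameters are partitioned across ranks at construction time.
    with deepspeed.zero.Init():
        model = torch.nn.Linear(10, 10)

    engine, _, _, _ = deepspeed.initialize(model=model,
                                           model_parameters=model.parameters(),
                                           config_params=config)

    # Checkpoint round trip: save, then reload optimizer and model state.
    engine.save_checkpoint("checkpoints", tag="step0")
    engine.load_checkpoint("checkpoints", tag="step0")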