Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merge changes from deepspeed master #24

Merged
merged 31 commits into from
Apr 6, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
18a26f3
[WarmupDecayLR] fix log(0) & 1/log(1) bugs (#772)
stas00 Mar 12, 2021
35fd7cc
bump to v0.3.12
jeffra Mar 12, 2021
458ff02
Bug fix: Remove client optimizer param_group list item that does not …
cli99 Mar 12, 2021
73d762c
[doc] pipeline doc typos/improvements (#659)
stas00 Mar 14, 2021
4601885
Samyamr/inference hook fix (#851)
samyam Mar 15, 2021
a75d971
ZeRO Stage 2: Clear reduced gradients (#856)
tjruwase Mar 15, 2021
24335d4
[runner/launch] propagate the error (#854)
stas00 Mar 16, 2021
547d1c5
docs: minor spelling tweaks (#858)
brettkoonce Mar 16, 2021
871f304
Allow args to be optional in deepspeed.initialize (#825)
jeffra Mar 16, 2021
fa87a73
Fix ZeRO3 save_checkpoint (#857)
tjruwase Mar 16, 2021
7bcd72a
Make config objects json serializable (#862)
tjruwase Mar 16, 2021
12a53b4
bump version 0.3.13
jeffra Mar 16, 2021
68c8481
1-bit Adam v2 (#817)
conglongli Mar 16, 2021
10c0bea
consistent checkpoint filenaming (#865)
stas00 Mar 18, 2021
9e9f8cb
[doc] launcher (#868)
stas00 Mar 18, 2021
22d5a1f
[doc] pipeline (#888)
stas00 Mar 24, 2021
7f03282
[debug utils] see_memory_usage fixes (#890)
stas00 Mar 25, 2021
7531c6b
full fp32 weights reconstruction for zero 2+3 (#892)
stas00 Mar 26, 2021
39013dd
save_fp16_model consolidated for zero3 (#893)
stas00 Mar 27, 2021
7fcc891
Fix zero stage2 cpu_offload when some model trainable parameters skip…
ghosthamlet Mar 27, 2021
af2d8fc
update kramdown (#901)
jeffra Mar 30, 2021
23ff6cb
update backward api doc (#903)
jeffra Mar 30, 2021
c042264
Bump kramdown from 2.3.0 to 2.3.1 in /docs (#905)
dependabot[bot] Mar 30, 2021
8c9e16e
We're hiring! + integration posts
jeffra Mar 31, 2021
c6b497d
[website] We're hiring! + integration posts
jeffra Mar 31, 2021
c814abd
[website] we're hiring!
jeffra Mar 31, 2021
5d721e0
zero.Init() clarification (#880)
stas00 Apr 1, 2021
8db4fdf
disable pipe test (#915)
jeffra Apr 2, 2021
ab5534f
Add link to AML examples. (#916)
awan-10 Apr 2, 2021
c574788
Merge branch 'master' of https://github.com/microsoft/DeepSpeed into …
Apr 6, 2021
b58a8fa
Merge branch 'microsoft-master' into stella
Apr 6, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Samyamr/inference hook fix (microsoft#851)
* Fix mis-aligned-grad

When a parameter is not divisible by world size, the partitioned gradients are mis-aligned due to incorrect padding handling. This PR should fix for that.

* Formatting fix

* Adding static_scale test back for Z3, and also changing hidden size to be not divisile by world_size

* also removing alignment from flat fp16 buffers

* Testing for hidden dim alignment

* inference hook fix

* Update stage3.py

* formatting

* [bug-fix] move params to gpu if offload params is turned off

Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
  • Loading branch information
3 people authored Mar 15, 2021
commit 4601885972be96373066662084ce1bf9c49448b8
8 changes: 6 additions & 2 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,8 +807,12 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
if start < param.ds_numel:
elements = min(param.ds_numel - start, partition_size)

dest_tensor = partition_buffer.view(-1).narrow(0, 0, elements)
dest_tensor_full_buffer = partition_buffer.view(-1).narrow(
0,
0,
partition_size)

dest_tensor = dest_tensor_full_buffer.narrow(0, 0, elements)
src_tensor = param.grad.view(-1).narrow(0, start, elements)

# just copy the grad partition to the buffer
Expand Down Expand Up @@ -841,7 +845,7 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
# elements))

#print("after partition gradients")
param.grad.data = dest_tensor.data
param.grad.data = dest_tensor_full_buffer.data
see_memory_usage("After partitioning gradients", force=False)


Expand Down
21 changes: 15 additions & 6 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,10 +961,9 @@ def _create_fp16_partitions_with_defragmentation(self):

#create flat buffer in CPU and move to GPU
self.fp16_partitioned_groups_flat.append(
flatten_dense_tensors_aligned(
self.fp16_partitioned_groups[i],
dist.get_world_size(group=self.dp_process_group)).cuda(
torch.cuda.current_device()))
flatten_dense_tensors_aligned(self.fp16_partitioned_groups[i],
1).cuda(
torch.cuda.current_device()))
see_memory_usage(
f"After flattening and moving param group {i} to GPU",
force=False)
Expand All @@ -976,10 +975,12 @@ def _create_fp16_partitions_with_defragmentation(self):
flat_offset,
total_elements)
self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat)
self._move_to_flat_buffer(self.fp16_partitioned_groups[i],
self.fp16_partitioned_groups_flat[i])
flat_offset += total_elements

# move param to flat buffer for both param offload on/off
self._move_to_flat_buffer(self.fp16_partitioned_groups[i],
self.fp16_partitioned_groups_flat[i])

see_memory_usage(f"After Flattening param group {i}", force=False)

def _create_fp32_partitions(self):
Expand Down Expand Up @@ -1036,6 +1037,14 @@ def setup_zero_stage3_hooks(self):
self.hierarchy = 0
self._register_hooks_recursively(self.module)

#reset step if in inference mode
def _end_of_forward_hook(module, *args):

if not torch._C.is_grad_enabled():
self.param_coordinator.reset_step()

self.module.register_forward_hook(_end_of_forward_hook)

def persistent_parameters(self):
persistent_params = []
total_persistent_parameters = 0
Expand Down
13 changes: 7 additions & 6 deletions tests/unit/test_fp16.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,6 @@ def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload):
if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
pytest.skip("cpu-adam is not compatible")

if zero_stage == 3:
pytest.skip("skip for now")

config_dict = {
"train_batch_size": 4,
"steps_per_print": 1,
Expand All @@ -371,8 +368,9 @@ def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload):
args = args_from_dict(tmpdir, config_dict)

@distributed_test(world_size=2)
def _test_zero_static_scale(args, zero_stage):
hidden_dim = 10
def _test_zero_static_scale(args, zero_stage, hidden_dim):
#making hidden size not divisible by DP for covering this scenario
hidden_dim = hidden_dim
model = SimpleModel(hidden_dim)

model, optim, _, _ = deepspeed.initialize(args=args,
Expand All @@ -393,7 +391,10 @@ def _test_zero_static_scale(args, zero_stage):
model.backward(loss)
model.step()

_test_zero_static_scale(args=args, zero_stage=zero_stage)
#test when hidden_dim is not aligned with world size
_test_zero_static_scale(args=args, zero_stage=zero_stage, hidden_dim=9)
#test when hidden_dim is aligned with world size
_test_zero_static_scale(args=args, zero_stage=zero_stage, hidden_dim=10)


def test_zero_static_scale_deprecated_format(tmpdir):
Expand Down