Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merge changes from deepspeed master #24

Merged
merged 31 commits into from
Apr 6, 2021
Merged
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
18a26f3
[WarmupDecayLR] fix log(0) & 1/log(1) bugs (#772)
stas00 Mar 12, 2021
35fd7cc
bump to v0.3.12
jeffra Mar 12, 2021
458ff02
Bug fix: Remove client optimizer param_group list item that does not …
cli99 Mar 12, 2021
73d762c
[doc] pipeline doc typos/improvements (#659)
stas00 Mar 14, 2021
4601885
Samyamr/inference hook fix (#851)
samyam Mar 15, 2021
a75d971
ZeRO Stage 2: Clear reduced gradients (#856)
tjruwase Mar 15, 2021
24335d4
[runner/launch] propagate the error (#854)
stas00 Mar 16, 2021
547d1c5
docs: minor spelling tweaks (#858)
brettkoonce Mar 16, 2021
871f304
Allow args to be optional in deepspeed.initialize (#825)
jeffra Mar 16, 2021
fa87a73
Fix ZeRO3 save_checkpoint (#857)
tjruwase Mar 16, 2021
7bcd72a
Make config objects json serializable (#862)
tjruwase Mar 16, 2021
12a53b4
bump version 0.3.13
jeffra Mar 16, 2021
68c8481
1-bit Adam v2 (#817)
conglongli Mar 16, 2021
10c0bea
consistent checkpoint filenaming (#865)
stas00 Mar 18, 2021
9e9f8cb
[doc] launcher (#868)
stas00 Mar 18, 2021
22d5a1f
[doc] pipeline (#888)
stas00 Mar 24, 2021
7f03282
[debug utils] see_memory_usage fixes (#890)
stas00 Mar 25, 2021
7531c6b
full fp32 weights reconstruction for zero 2+3 (#892)
stas00 Mar 26, 2021
39013dd
save_fp16_model consolidated for zero3 (#893)
stas00 Mar 27, 2021
7fcc891
Fix zero stage2 cpu_offload when some model trainable parameters skip…
ghosthamlet Mar 27, 2021
af2d8fc
update kramdown (#901)
jeffra Mar 30, 2021
23ff6cb
update backward api doc (#903)
jeffra Mar 30, 2021
c042264
Bump kramdown from 2.3.0 to 2.3.1 in /docs (#905)
dependabot[bot] Mar 30, 2021
8c9e16e
We're hiring! + integration posts
jeffra Mar 31, 2021
c6b497d
[website] We're hiring! + integration posts
jeffra Mar 31, 2021
c814abd
[website] we're hiring!
jeffra Mar 31, 2021
5d721e0
zero.Init() clarification (#880)
stas00 Apr 1, 2021
8db4fdf
disable pipe test (#915)
jeffra Apr 2, 2021
ab5534f
Add link to AML examples. (#916)
awan-10 Apr 2, 2021
c574788
Merge branch 'master' of https://github.com/microsoft/DeepSpeed into …
Apr 6, 2021
b58a8fa
Merge branch 'microsoft-master' into stella
Apr 6, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
ZeRO Stage 2: Clear reduced gradients (microsoft#856)
* Ensure gradients of other partitions are cleared after reduction

* Remove redundant code

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
  • Loading branch information
tjruwase and jeffra authored Mar 15, 2021
commit a75d971bc2f1300c10331ed3b5f6026ecabe1821
23 changes: 15 additions & 8 deletions deepspeed/runtime/zero/stage2.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def split_half_float_double(tensors):
]
buckets = []
for i, dtype in enumerate(dtypes):
bucket = [t for t in tensors if t is not None and t.type() == dtype]
bucket = [t for t in tensors if t.type() == dtype]
if bucket:
buckets.append(bucket)
return buckets
Expand Down Expand Up @@ -477,6 +477,8 @@ def independent_gradient_partition_epilogue(self):

if self.overlap_comm:
torch.cuda.synchronize()
# It is safe to clear previously reduced grads of other partitions
self._clear_previous_reduced_grads()

if self.cpu_offload is False:
for i, _ in enumerate(self.fp16_groups):
Expand Down Expand Up @@ -638,6 +640,9 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
param.grad.data = new_grad_tensor.data.view_as(param.grad)

self.elements_in_ipg_bucket += param.numel()

assert param.grad is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"

self.grads_in_ipg_bucket.append(param.grad)
self.params_in_ipg_bucket.append((i, param, param_id))

Expand Down Expand Up @@ -965,7 +970,7 @@ def reduce_ipg_grads(self):

if not self.is_param_in_current_partition[param_id]:
if self.overlap_comm and self.contiguous_gradients is False:
# Clear the previous grads during the next reduction
# Clear grads of other partitions during the next reduction
# to avoid clearing them before the reduction is complete.
if self.previous_reduced_grads is None:
self.previous_reduced_grads = []
Expand Down Expand Up @@ -1078,16 +1083,18 @@ def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=N

return tensor

def _clear_previous_reduced_grads(self):
if self.previous_reduced_grads is not None:
for param in self.previous_reduced_grads:
param.grad = None
self.previous_reduced_grads = None

#if rank is specified do a reduction instead of an allreduce
def allreduce_and_copy(self, small_bucket, rank=None, log=None):
if self.overlap_comm:
torch.cuda.synchronize()
if self.previous_reduced_grads is not None:
# previous_reduced_grads has the previous reduced grads,
# now it is safe to clear.
for param in self.previous_reduced_grads:
param.grad = None
self.previous_reduced_grads = None
# It is safe to clear the previously reduced grads of other partitions
self._clear_previous_reduced_grads()
stream = self.reduction_stream
else:
stream = torch.cuda.current_stream()
Expand Down