Skip to content

Commit

Permalink
[ZeRO-1] fix bug w. cpu-offload + > 1 GPU (microsoft#1841)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffra authored Mar 16, 2022
1 parent 18ea8b7 commit 28434c0
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
3 changes: 2 additions & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,7 +1244,8 @@ def reduce_ipg_grads(self):
elif self.contiguous_gradients:
self.copy_grads_in_partition(param)
else: # zero stage 1 - partition only optimizer state
- if self.contiguous_gradients:
+ if self.contiguous_gradients and self.is_param_in_current_partition[
+ param_id]:
self.copy_grads_in_partition(param)

self.grads_in_ipg_bucket = []
Expand Down
43 changes: 43 additions & 0 deletions tests/unit/test_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,3 +1179,46 @@ def create_tensor(vals):
_assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE})

_test_zero3_param_partitioning()


def test_zero_offload_stage1():
    """Smoke-test ZeRO stage 1 training with the optimizer offloaded to CPU.

    Builds an fp16 Adam configuration with ``zero_optimization.stage`` set to
    1 and ``offload_optimizer.device`` set to ``"cpu"``, then runs forward,
    backward and step over a random dataloader inside a ``distributed_test``
    launched with ``world_size=[2]``. The test makes no explicit assertions;
    it passes when the training loop completes without raising.
    """
    # Sub-sections are named separately so the top-level config reads flat.
    optimizer_section = {"type": "Adam", "params": {"lr": 1e-4}}
    zero_section = {"stage": 1, "offload_optimizer": {"device": "cpu"}}
    config_dict = {
        "train_batch_size": 4,
        "gradient_accumulation_steps": 2,
        "steps_per_print": 1,
        "optimizer": optimizer_section,
        "fp16": {"enabled": True},
        "zero_optimization": zero_section,
    }

    hidden_dim = 10
    model = SimpleModel(hidden_dim)

    @distributed_test(world_size=[2])
    def _go(model, hidden_dim):
        # deepspeed.initialize returns a 4-tuple; only the first element
        # (the wrapped model/engine) is used here.
        engine, _, _, _ = deepspeed.initialize(
            model=model,
            model_parameters=model.parameters(),
            config=config_dict,
        )
        batches = random_dataloader(
            model=engine,
            total_samples=50,
            hidden_dim=hidden_dim,
            device=engine.device,
        )
        # Line the ranks up before the training loop starts.
        torch.distributed.barrier()
        for batch in batches:
            loss = engine(batch[0], batch[1])
            engine.backward(loss)
            engine.step()

    _go(model=model, hidden_dim=hidden_dim)

0 comments on commit 28434c0

Please sign in to comment.