Add unit stage3 test for running model twice in one step
If the model is run more than once in one training step, there may be issues.
Add a unit test to catch these kinds of problems.

Signed-off-by: Wenbin Chen <wenbin.chen@intel.com>
wenbinc-Bin committed Nov 13, 2024
1 parent 43433ac commit 93ef46a
Showing 1 changed file with 51 additions and 0 deletions.
tests/unit/runtime/zero/test_zero_multiple_run.py (51 additions, 0 deletions)
@@ -0,0 +1,51 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed
import pytest
import torch
from unit.simple_model import SimpleModel, random_dataloader


def test_stage3_multiple_model_run():
    config_dict = {
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 1,
        "steps_per_print": 1,
        "zero_optimization": {
            "stage": 3
        },
        "fp16": {
            "enabled": True,
            "initial_scale_power": 8
        },
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 1e-3
            }
        },
    }
    hidden_dim, nlayers = 2048, 3
    model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayers)
    model_engine, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())
    data_loader = iter(random_dataloader(model=model_engine,
                                         total_samples=10,
                                         hidden_dim=hidden_dim,
                                         device=model_engine.device))

    for n, batch in enumerate(data_loader):
        # Run the model twice in the same training step: once normally and
        # once under no_grad, then backpropagate through the combined loss.
        loss1 = model_engine(batch[0], batch[1])
        with torch.no_grad():
            loss2 = model_engine(batch[0], batch[1])
        loss = loss1 + loss2
        model_engine.backward(loss)
        # After backward, each linear submodule's ds_grads_remaining counter
        # should be back to 0; a non-zero value means the repeated forward
        # passes confused the ZeRO stage 3 gradient bookkeeping.
        for name, submodule in model_engine.module.linears._modules.items():
            assert hasattr(submodule, "ds_grads_remaining"), \
                f"linears.{name} does not have variable ds_grads_remaining"
            assert submodule.ds_grads_remaining == 0, \
                f"ds_grads_remaining of linears.{name} is not 0 ({submodule.ds_grads_remaining})"
        model_engine.step()
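
A minimal way to run just this test locally, assuming a DeepSpeed development checkout with the unit-test dependencies installed and at least one accelerator available, is to invoke pytest on the new file from the tests/ directory (so that the unit.* imports resolve):

import pytest

# Sketch only: selects the single new test file and runs it verbosely.
# Assumes the working directory is the repository's tests/ folder.
pytest.main(["unit/runtime/zero/test_zero_multiple_run.py", "-v"])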
