Fix potential memory issues when using DeepSpeed ZeRO-3 #6726

Merged (9 commits), Nov 21, 2024
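The scenario this PR targets, as exercised by the new unit test added below, is a ZeRO-3 engine whose forward pass runs more than once per optimization step, for example an extra evaluation pass under torch.no_grad() before a single backward and step. A minimal sketch of that usage pattern, assuming model_engine and data_loader come from a standard deepspeed.initialize(...) setup with ZeRO stage 3 (mirroring the new test):

    # Sketch only; model_engine and data_loader are assumed, as in the test below.
    for batch in data_loader:
        loss1 = model_engine(batch[0], batch[1])        # forward with grad
        with torch.no_grad():
            loss2 = model_engine(batch[0], batch[1])    # second forward, no grad
        model_engine.backward(loss1 + loss2)
        model_engine.step()

Before this change, repeating the forward like this could leave the per-module and coordinator bookkeeping shown below in an inconsistent state, which is the potential memory issue named in the title.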
3 changes: 2 additions & 1 deletion deepspeed/runtime/zero/parameter_offload.py
@@ -392,7 +392,8 @@ def _run_before_forward_function(input):
                                                                _run_after_backward_hook, inputs)

         def _post_backward_module_hook(module, inputs):
-            module.ds_grads_remaining = 0
+            if not hasattr(module, "ds_grads_remaining"):
+                module.ds_grads_remaining = 0

             if not hasattr(module, "post_bwd_fn"):

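A hedged reading of this hunk (most of parameter_offload.py is not shown here): ds_grads_remaining appears to be a per-module counter of grad-producing forward outputs that still await their backward; the surrounding hooks increment it on forward, decrement it during backward, and re-partition the module's parameters once it returns to zero. Unconditionally zeroing it on every forward meant that a second forward through the same module, such as the torch.no_grad() call in the new test, could wipe out an in-flight count, so the release at zero might never fire. The toy below, which is not DeepSpeed code, replays that arithmetic:

    # Toy, self-contained illustration of why the hasattr guard matters:
    # two forwards before one backward, with an unconditional reset, lose
    # track of the pending gradient and the "release" at zero never fires.
    class ToyModule:
        pass

    def forward(m, requires_grad, reset_unconditionally):
        if reset_unconditionally or not hasattr(m, "ds_grads_remaining"):
            m.ds_grads_remaining = 0
        if requires_grad:
            m.ds_grads_remaining += 1

    def backward(m):
        m.ds_grads_remaining -= 1
        return m.ds_grads_remaining == 0   # True -> parameters get re-partitioned

    m = ToyModule()
    forward(m, requires_grad=True, reset_unconditionally=True)
    forward(m, requires_grad=False, reset_unconditionally=True)  # clobbers the count
    assert backward(m) is False   # counter is now -1; release never triggers

    m2 = ToyModule()
    forward(m2, requires_grad=True, reset_unconditionally=False)
    forward(m2, requires_grad=False, reset_unconditionally=False)
    assert backward(m2) is True   # with the guard, the release fires as expected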
3 changes: 1 addition & 2 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -252,7 +252,6 @@ def reset_step(self) -> None:
         self.__most_recent_step_id_param_fetched_for = collections.defaultdict(lambda: int(-1e10))
         self.__step_id_module_fetched_for = collections.defaultdict(lambda: collections.deque())
         self.__step_id = 0
-        self.__n_available_params = 0
         self.__profiler.reset_events()

     def _dump_params(self, tag, sub_module, params, step_id=None):
@@ -430,7 +429,7 @@ def release_and_reset_all(self, module: Module) -> None:
             # there's a hook execution issue
             param.ds_active_sub_modules.clear()
             self.__release_param(param)
-
+        self.__n_available_params = 0
         for param in iter_params(module, recurse=True):
             if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
                 raise RuntimeError(f"{param.ds_summary()} expected to be released")
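A hedged reading of the second change: __n_available_params appears to track how many parameter elements are currently held in full (unpartitioned) form, the quantity the coordinator compares against the stage3_max_live_parameters budget. Zeroing it in reset_step while parameters could still be materialized let the counter drift below the real number of live elements; moving the reset into release_and_reset_all, immediately after every parameter has actually been released, ties the counter to the state it is supposed to mirror. A toy sketch, not DeepSpeed code, of that invariant:

    # Toy sketch of the bookkeeping invariant behind this move: the counter
    # that mirrors a pool of live objects must be cleared exactly where the
    # pool is emptied, otherwise the budget check stops reflecting reality.
    class ToyCoordinator:
        def __init__(self, max_live):
            self.max_live = max_live
            self.live = {}            # param id -> numel currently materialized
            self.n_available = 0      # must always equal sum(self.live.values())

        def can_fetch(self, numel):
            return self.n_available + numel <= self.max_live

        def fetch(self, pid, numel):
            self.live[pid] = numel
            self.n_available += numel

        def release_all(self):
            self.live.clear()
            self.n_available = 0      # reset here, not in a per-step reset,
                                      # so the counter cannot drift from self.live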
53 changes: 53 additions & 0 deletions tests/unit/runtime/zero/test_zero_multiple_run.py
@@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed
import torch
from unit.common import DistributedTest, preferred_dtype
from unit.simple_model import SimpleModel, random_dataloader


class TestZ3MultipleModelCall(DistributedTest):
    world_size = 1

    def test_z3_multiple_model_call(self):
        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "gradient_accumulation_steps": 1,
            "steps_per_print": 1,
            "zero_optimization": {
                "stage": 3
            },
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
        }
        if preferred_dtype() is torch.float16:
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
        elif preferred_dtype() is torch.bfloat16:
            config_dict["bf16"] = {"enabled": True}
        hidden_dim, nlayers = 2048, 3
        model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayers)
        model_engine, _, _, _ = deepspeed.initialize(config=config_dict,
                                                     model=model,
                                                     model_parameters=model.parameters())
        data_loader = iter(
            random_dataloader(model=model_engine, total_samples=10, hidden_dim=hidden_dim, device=model_engine.device))

        for n, batch in enumerate(data_loader):
            loss1 = model_engine(batch[0], batch[1])
            with torch.no_grad():
                loss2 = model_engine(batch[0], batch[1])
            loss = loss1 + loss2
            model_engine.backward(loss)
            for name, submodule in model_engine.module.linears._modules.items():
                assert hasattr(submodule, "ds_grads_remaining"), \
                    f"linears.{name} does not have variable ds_grads_remaining"
                assert submodule.ds_grads_remaining == 0, \
                    f"ds_grads_remaining of linears.{name} is not 0 ({submodule.ds_grads_remaining})"
            model_engine.step()
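Assuming the repository's usual pytest-based unit-test flow, the new test can be run with something like:

    pytest tests/unit/runtime/zero/test_zero_multiple_run.py

After each backward, it checks that every linear submodule still has ds_grads_remaining and that the counter is back at 0, i.e. the extra no-grad forward did not leave a stale count behind.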