
Empty ZeRO3 partition cache #3060

Merged
merged 11 commits on Mar 24, 2023
10 changes: 10 additions & 0 deletions deepspeed/runtime/engine.py
@@ -9,6 +9,7 @@
import hashlib
from collections import defaultdict, OrderedDict, deque
from shutil import copyfile
import gc

from torch.nn.modules import Module
from torch.nn.parameter import Parameter
@@ -3546,3 +3547,12 @@ def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"):
        self.checkpoint_engine.commit(tag)

        return True

    def empty_partition_cache(self):
        """
        Release GPU memory consumed by offloaded model parameters.
        """
        if hasattr(self.optimizer, 'empty_partition_cache'):
            self.optimizer.empty_partition_cache()
            gc.collect()
            get_accelerator().empty_cache()
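To illustrate the new engine-level hook, here is a minimal usage sketch (not part of the diff; it assumes a model instance and a ZeRO stage 3 config dict named ds_config are already defined). When the active optimizer does not expose empty_partition_cache, the hasattr guard skips the optimizer call; under ZeRO-3 the cached parameters are re-partitioned and the accelerator cache is emptied:

import deepspeed
from deepspeed.accelerator import get_accelerator

# Assumes `model` and `ds_config` ({"zero_optimization": {"stage": 3}, ...}) are defined.
engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)

before = get_accelerator().memory_allocated()
engine.empty_partition_cache()  # partition cached params, then gc.collect() + empty_cache()
after = get_accelerator().memory_allocated()
print(f"Released {before - after} bytes held by gathered parameters")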
5 changes: 4 additions & 1 deletion deepspeed/runtime/zero/parameter_offload.py
@@ -259,6 +259,9 @@ def get_param_coordinator(self, training):

        return self.param_coordinators[training]

    def empty_partition_cache(self):
        self.partition_all_parameters()

    def _convert_to_zero_parameters(self, ds_config, module, mpu):
        non_zero_params = [p for p in module.parameters() if not is_zero_param(p)]
        if non_zero_params:
@@ -321,7 +324,7 @@ def mark_persistent_parameters(self, param_threshold, model_threshold):
            if param.ds_numel + total_persistent_parameters > model_threshold:
                continue

            if param.ds_numel < param_threshold:
            if param.ds_numel <= param_threshold:
                params_count += 1
                param.ds_persist = True
                persistent_params.append(param)
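The change from < to <= above makes the persistence threshold inclusive: a parameter whose element count is exactly stage3_param_persistence_threshold is now marked ds_persist and kept gathered. A hypothetical config fragment illustrating the boundary case (values are illustrative only):

# With the inclusive comparison, a bias vector of exactly 10 elements
# (e.g. hidden_dim == 10) now stays persistent instead of being re-partitioned after use.
zero3_config = {
    "zero_optimization": {
        "stage": 3,
        "stage3_param_persistence_threshold": 10
    }
}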
3 changes: 3 additions & 0 deletions deepspeed/runtime/zero/stage3.py
@@ -2467,6 +2467,9 @@ def checkpoint_event_epilogue(self):
        if len(self.persistent_parameters) > 0:
            self.persistent_parameters[0].all_gather(self.persistent_parameters)

    def empty_partition_cache(self):
        self.parameter_offload.empty_partition_cache()


def _handle_overflow(cpu_sum, x, i):
    import math
26 changes: 26 additions & 0 deletions docs/code-docs/source/zero3.rst
@@ -331,3 +331,29 @@ These routines can be used in a training loop as shown in the following snippet.

[...]
optimizer.step()


GPU Memory Management
---------------------

By default, at the end of training with ZeRO stage 3, some parameters may remain unpartitioned and continue to hold GPU memory.
This is done deliberately as an optimization in case you resume training. If you would like to release the GPU memory held by
these cached parameters, call the ``empty_partition_cache`` method of the DeepSpeed engine.

.. autofunction:: deepspeed.DeepSpeedEngine.empty_partition_cache

The following code snippet illustrates this functionality.

.. code-block:: python

    with zero.Init():
        model = MyLargeModel()

    ds_engine, _, _, _ = deepspeed.initialize(model=model, ...)
    for batch in ...:
        loss = ds_engine(batch)
        ds_engine.backward(loss)
        ds_engine.step()

    # Free GPU memory consumed by model parameters
    ds_engine.empty_partition_cache()
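Calling ``empty_partition_cache`` does not invalidate the engine: if training or evaluation continues afterwards, the ZeRO-3
hooks should re-gather parameters on demand during the next forward pass. A hedged continuation of the snippet above:

.. code-block:: python

    # Cached parameter memory has been released at this point, but the
    # engine remains usable; parameters are re-gathered on demand.
    for batch in ...:
        loss = ds_engine(batch)
        ds_engine.backward(loss)
        ds_engine.step()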
46 changes: 46 additions & 0 deletions tests/unit/runtime/zero/test_zero.py
@@ -1422,3 +1422,49 @@ def test(self, force_ds_optim):
model, _, _, _ = deepspeed.initialize(model=model,
optimizer=optimizer,
config=config_dict)


@pytest.mark.parametrize('training', [True, False])
class TestZeroPartitionCache(DistributedTest):
    world_size = 1

    def test_training_partition_cache(self, training):
        hidden_dim = 10
        config_dict = {
            "train_batch_size": 2,
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            },
            "zero_optimization": {
                "stage": 3,
                "stage3_param_persistence_threshold": hidden_dim
            }
        }
        if training:
            config_dict["optimizer"] = {"type": "Adam"}

        with deepspeed.zero.Init(config_dict_or_path=config_dict):
            model = SimpleModel(hidden_dim, empty_grad=False)

        model, _, _, _ = deepspeed.initialize(model=model, config=config_dict)

        dtype = torch.half
        data_loader = random_dataloader(model=model,
                                        total_samples=6,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=dtype)

        for _, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            if training:
                model.backward(loss)
                model.step()

        persist_param_size = sum([p.numel() for p in model.parameters() if p.ds_persist])

        assert persist_param_size >= sum([p.numel() for p in model.parameters()])

        model.empty_partition_cache()
        assert sum([p.numel() for p in model.parameters()]) == 0
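For context on the final assertion: after empty_partition_cache() every parameter, including those marked ds_persist, is partitioned again, so its locally visible tensor is empty until it is re-gathered. A short sketch (not part of the test) using deepspeed.zero.GatheredParameters to gather the weights explicitly afterwards:

import deepspeed

# `model` is the ZeRO-3 engine built in the test above.
model.empty_partition_cache()
assert all(p.numel() == 0 for p in model.parameters())

# The full weights can still be inspected by gathering them explicitly.
with deepspeed.zero.GatheredParameters(list(model.parameters()), modifier_rank=None):
    total = sum(p.numel() for p in model.parameters())
    print(f"Gathered parameter count: {total}")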