From adec99121b411709e1b185a486d18aa846c82c64 Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Date: Wed, 9 Oct 2024 19:59:26 -0700
Subject: [PATCH] Add API to get devices of offload states (#6586)

This PR adds an API `deepspeed.runtime.zero.offload_states.get_state_devices`,
which gets devices of offload states as suggested in this
[comment](https://github.com/microsoft/DeepSpeed/pull/6011#issuecomment-2358068777).

We could lift this up to `deepspeed.utils` but would need to resolve a circular import:
User code -> `deepspeed.utils` -> `deepspeed.utils.offload_states` -> `deepspeed.runtime.zero`
-> `deepspeed.runtime.zero.partition_parameters` -> `deepspeed.utils`

This will require a significant refactoring as long as we have
`OffloadStateTypeEnum` in `deepspeed.runtime.zero`.
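
For reference, a minimal usage sketch of the new helper together with the existing
`offload_states()` / `reload_states()` engine APIs (the `model`/`config` setup and the
`ds_engine` name are illustrative, not part of this patch; a ZeRO stage-3 config is assumed):

```python
import deepspeed
from deepspeed.runtime.zero.offload_config import OffloadStateTypeEnum
from deepspeed.runtime.zero.offload_states import get_state_devices

# Assumed: `model` and a ZeRO stage-3 `config` are defined by the user.
ds_engine, _, _, _ = deepspeed.initialize(model=model,
                                          model_parameters=model.parameters(),
                                          config=config)

# ... forward / backward / step ...

ds_engine.offload_states()  # move ZeRO-3 states to CPU (API from #6011)

# Inspect where each offloadable state currently lives.
for state in OffloadStateTypeEnum:
    print(state, get_state_devices(ds_engine, state))

ds_engine.reload_states()  # bring the states back to the accelerator
```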
---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase
---
 deepspeed/runtime/utils.py                    | 10 +--
 deepspeed/runtime/zero/offload_states.py      | 74 +++++++++++++++++++
 deepspeed/runtime/zero/stage3.py              |  3 +-
 docs/code-docs/source/zero3.rst               | 16 ++++
 .../unit/runtime/zero/test_offload_states.py  | 23 +++---
 5 files changed, 110 insertions(+), 16 deletions(-)
 create mode 100644 deepspeed/runtime/zero/offload_states.py

diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py
index adcadd349803..b9617d3e632f 100755
--- a/deepspeed/runtime/utils.py
+++ b/deepspeed/runtime/utils.py
@@ -9,28 +9,28 @@
 """
 
 from collections.abc import Iterable
-from deepspeed.moe.utils import is_moe_param
 import os
 import psutil
 import gc
 from math import sqrt
+from numpy import prod
+
 import torch
-from deepspeed import comm as dist
+from torch.nn import functional as F
 
 try:
     from torch._six import inf
 except ModuleNotFoundError:
     from torch import inf
 
+from deepspeed import comm as dist
+from deepspeed.moe.utils import is_moe_param
 from deepspeed.utils import groups, logger
 from deepspeed.utils.bwc import (bwc_tensor_model_parallel_rank, bwc_pipeline_parallel_world_size,
                                  bwc_pipeline_parallel_group)
 from deepspeed.runtime.constants import PIPE_REPLICATED
-from numpy import prod
 from deepspeed.accelerator import get_accelerator
-
 from deepspeed.module_inject.policy import transpose
-from torch.nn import functional as F
 
 torch_memory_reserved = get_accelerator().memory_reserved
 torch_max_memory_reserved = get_accelerator().max_memory_reserved
diff --git a/deepspeed/runtime/zero/offload_states.py b/deepspeed/runtime/zero/offload_states.py
new file mode 100644
index 000000000000..f521a11a7aa4
--- /dev/null
+++ b/deepspeed/runtime/zero/offload_states.py
@@ -0,0 +1,74 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from typing import Set
+import torch
+
+from deepspeed.accelerator import get_accelerator
+from deepspeed.runtime.zero.offload_config import OffloadStateTypeEnum
+
+from deepspeed.utils.tensor_fragment import safe_get_local_fp32_param, safe_get_local_optimizer_state
+
+
+def _make_offload_state_key(key):
+    return f"{key}_offload_buffer"
+
+
+def offload_adam_states(optimizer, device, pin_memory: bool = False, non_blocking: bool = False):
+    """Move optimizer states to device. Note that this assumes the state structure of DeepSpeed Adam."""
+
+    def move_key(state, key):
+        offload_buf_key = _make_offload_state_key(key)
+        if offload_buf_key not in state:
+            state[offload_buf_key] = torch.empty_like(state[key], device=device)
+            if pin_memory:
+                state[offload_buf_key] = get_accelerator().pin_memory(state[offload_buf_key])
+        state[offload_buf_key].copy_(state[key], non_blocking=non_blocking)
+        state[key].data = state[offload_buf_key]
+
+    for _, state in optimizer.state.items():
+        if "exp_avg" in state:
+            move_key(state, "exp_avg")
+        if "exp_avg_sq" in state:
+            move_key(state, "exp_avg_sq")
+
+
+def reload_adam_states(optimizer, device, non_blocking: bool = False):
+    """Move optimizer states to device. Note that this assumes the state structure of DeepSpeed Adam."""
+
+    def move_back_key(state, key):
+        state[key].data = state[_make_offload_state_key(key)].to(device, non_blocking=non_blocking)
+
+    for _, state in optimizer.state.items():
+        if "exp_avg" in state:
+            move_back_key(state, "exp_avg")
+        if "exp_avg_sq" in state:
+            move_back_key(state, "exp_avg_sq")
+
+
+def get_state_devices(model, state: OffloadStateTypeEnum) -> Set[torch.device]:
+    """Retrieve the devices of the specified state of the model.
+
+    Args:
+        model (DeepSpeedEngine): The model whose device allocations are to be checked.
+        state (OffloadStateTypeEnum): The specific state for which the devices should be retrieved.
+
+    Returns:
+        Set[torch.device]: A set of devices of the specified state.
+
+    """
+    if state == OffloadStateTypeEnum.hp_params:
+        return set(safe_get_local_fp32_param(p).device for p in model.parameters())
+    elif state == OffloadStateTypeEnum.lp_params:
+        return set(p.ds_tensor.device for p in model.parameters())
+    elif state == OffloadStateTypeEnum.lp_grads:
+        return {model.optimizer.grad_partitions_flat_buffer.device}
+    elif state == OffloadStateTypeEnum.optim_states:
+        return set(safe_get_local_optimizer_state(p, "exp_avg").device for p in model.parameters()) | \
+            set(safe_get_local_optimizer_state(p, "exp_avg_sq").device for p in model.parameters())
+    elif state == OffloadStateTypeEnum.contiguous_grad_buffer:
+        if model.optimizer._DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer is None:
+            return set()
+        return {model.optimizer._DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer.device}
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index fb75d2bcebd5..6895916783f1 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -18,12 +18,13 @@
 from deepspeed.utils import logger
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
-from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, offload_adam_states, reload_adam_states
+from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
 from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
 from deepspeed.runtime.zero.utils import apply_to_tensors_only, get_mapping_to_flat_buffer
+from deepspeed.runtime.zero.offload_states import offload_adam_states, reload_adam_states
 from deepspeed.ops.adam import DeepSpeedCPUAdam
 from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
 from deepspeed.runtime.swap_tensor.optimizer_utils import OptimizerSwapper
diff --git a/docs/code-docs/source/zero3.rst b/docs/code-docs/source/zero3.rst
index f0974c08c9f3..ae7cedd1a8b3 100644
--- a/docs/code-docs/source/zero3.rst
+++ b/docs/code-docs/source/zero3.rst
@@ -509,3 +509,19 @@ Below is an example code snippet demonstrating how to offload FP32 parameters an
     ...

     # Load states back to device memory
     ds_engine.reload_states()
+
+``deepspeed.runtime.zero.offload_states.get_state_devices`` returns devices of the specified state.
+
+.. code-block:: python
+
+    def get_state_devices(model, state: OffloadStateTypeEnum) -> Set[torch.device]:
+        """Retrieve the devices of the specified state of the model.
+
+        Args:
+            model (DeepSpeedEngine): The model whose device allocations are to be checked.
+            state (OffloadStateTypeEnum): The specific state for which the devices should be retrieved.
+
+        Returns:
+            Set[torch.device]: A set of devices of the specified state.
+
+        """
diff --git a/tests/unit/runtime/zero/test_offload_states.py b/tests/unit/runtime/zero/test_offload_states.py
index cc60908d3c33..9105a54661fa 100644
--- a/tests/unit/runtime/zero/test_offload_states.py
+++ b/tests/unit/runtime/zero/test_offload_states.py
@@ -15,19 +15,22 @@
 import deepspeed
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
 from deepspeed.utils import safe_get_local_fp32_param, safe_get_local_optimizer_state
+from deepspeed.runtime.zero.offload_states import get_state_devices
 
 
 def validate_device(model, device: torch.device, include) -> None:
-    # Make sure the model parameters are offloaded
-    if include is None or OffloadStateTypeEnum.hp_params in include:
-        assert all(safe_get_local_fp32_param(p).device == device for p in model.parameters())
-    if include is None or OffloadStateTypeEnum.lp_params in include:
-        assert all(p.ds_tensor.device == device for p in model.parameters())
-    if include is None or OffloadStateTypeEnum.lp_grads in include:
-        assert model.optimizer.grad_partitions_flat_buffer.device == device
-    if include is None or OffloadStateTypeEnum.optim_states in include:
-        assert all(safe_get_local_optimizer_state(p, "exp_avg").device == device for p in model.parameters())
-        assert all(safe_get_local_optimizer_state(p, "exp_avg_sq").device == device for p in model.parameters())
+
+    def compare_device(state) -> bool:
+        devices = get_state_devices(model, state)
+        return len(devices) == 1 and device in devices
+
+    for state in OffloadStateTypeEnum:
+        if include is None or state in include:
+            if state == OffloadStateTypeEnum.contiguous_grad_buffer and device == torch.device("cpu"):
+                assert len(get_state_devices(model,
+                                             state)) == 0, f"State {state} must be removed after offload_states()"
+            else:
+                assert compare_device(state), f"State {state} is not on device {device}"
 
 
 def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_blocking):
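
A note on the mechanism used by `offload_adam_states` / `reload_adam_states` above: each
optimizer-state tensor is staged in a persistent `<key>_offload_buffer` entry and the original
tensor's `.data` is repointed at that buffer, so downstream code transparently sees the offloaded
copy. A minimal standalone sketch of that pattern on toy tensors (illustration only, not DeepSpeed
code; the helper names here are hypothetical):

```python
import torch


def offload_entry(state: dict, key: str, device: str = "cpu") -> None:
    # Allocate the staging buffer once and reuse it on later offloads.
    buf_key = f"{key}_offload_buffer"
    if buf_key not in state:
        state[buf_key] = torch.empty_like(state[key], device=device)
    # Copy into the buffer, then repoint .data so the original tensor
    # object now refers to the offloaded storage.
    state[buf_key].copy_(state[key])
    state[key].data = state[buf_key]


def reload_entry(state: dict, key: str, device: str) -> None:
    # Move the staged copy back and repoint .data at the device tensor.
    state[key].data = state[f"{key}_offload_buffer"].to(device)


if __name__ == "__main__":
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    state = {"exp_avg": torch.randn(4, device=dev)}
    offload_entry(state, "exp_avg")
    print(state["exp_avg"].device)  # cpu
    reload_entry(state, "exp_avg", dev)
    print(state["exp_avg"].device)  # back on the original device
```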