Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split parameter offload from z3 #2009

Merged
merged 16 commits into from
Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 46 additions & 29 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,12 @@
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload

from deepspeed.runtime.activation_checkpointing import (
checkpointing as activation_checkpointing,
)

from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.bf16_optimizer import BF16_Optimizer
Expand Down Expand Up @@ -328,7 +331,8 @@ def __init__(

self.save_non_zero_checkpoint = False
self.save_zero_checkpoint = False
self._configure_checkpointing(dist_init_required)
if not isinstance(self.optimizer, DeepSpeedZeRoOffload):
self._configure_checkpointing(dist_init_required)

if self.eigenvalue_enabled():
self.eigenvalue = self._configure_eigenvalue()
Expand Down Expand Up @@ -1385,7 +1389,6 @@ def _configure_zero_optimizer(self, optimizer):
"Pipeline parallelism does not support overlapped communication, will be disabled."
)
overlap_comm = False

optimizer = DeepSpeedZeroOptimizer(
optimizer,
timers=timers,
Expand Down Expand Up @@ -1422,33 +1425,47 @@ def _configure_zero_optimizer(self, optimizer):
logger.info("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3

optimizer = DeepSpeedZeroOptimizer_Stage3(
self.module,
optimizer,
timers=timers,
ds_config=self.config,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
clip_grad=self.gradient_clipping(),
contiguous_gradients=self.zero_contiguous_gradients(),
reduce_bucket_size=self.zero_reduce_bucket_size(),
prefetch_bucket_size=self.zero_prefetch_bucket_size(),
max_reuse_distance=self.zero_max_reuse_distance(),
max_live_parameters=self.zero_max_live_parameters(),
param_persistence_threshold=self.zero_param_persistence_threshold(),
dp_process_group=self.data_parallel_group,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=self.zero_overlap_comm(),
offload_optimizer_config=self.zero_offload_optimizer(),
offload_param_config=self.zero_offload_param(),
sub_group_size=self.zero_sub_group_size(),
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps(),
aio_config=self.aio_config(),
communication_data_type=self.communication_data_type)
if isinstance(optimizer, DummyOptim):
optimizer = DeepSpeedZeRoOffload(
self.module,
timers=timers,
ds_config=self.config,
overlap_comm=self.zero_overlap_comm(),
prefetch_bucket_size=self.zero_prefetch_bucket_size(),
max_reuse_distance=self.zero_max_reuse_distance(),
max_live_parameters=self.zero_max_live_parameters(),
param_persistence_threshold=self.zero_param_persistence_threshold(),
offload_param_config=self.zero_offload_param(),
mpu=self.mpu)
else:

optimizer = DeepSpeedZeroOptimizer_Stage3(
self.module,
optimizer,
timers=timers,
ds_config=self.config,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
clip_grad=self.gradient_clipping(),
contiguous_gradients=self.zero_contiguous_gradients(),
reduce_bucket_size=self.zero_reduce_bucket_size(),
prefetch_bucket_size=self.zero_prefetch_bucket_size(),
max_reuse_distance=self.zero_max_reuse_distance(),
max_live_parameters=self.zero_max_live_parameters(),
param_persistence_threshold=self.zero_param_persistence_threshold(),
dp_process_group=self.data_parallel_group,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=self.zero_overlap_comm(),
offload_optimizer_config=self.zero_offload_optimizer(),
offload_param_config=self.zero_offload_param(),
sub_group_size=self.zero_sub_group_size(),
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps(),
aio_config=self.aio_config(),
communication_data_type=self.communication_data_type)

else:
raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))
Expand Down
Loading