Avoid unused parameters assert by default #1039

Merged (8 commits) on May 7, 2021
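This PR renames the ZeRO stage 2 option find_unused_parameters to ignore_unused_parameters and flips its default to True, so the unused-parameters assert in the CPU-offload gradient-norm path no longer fires out of the box. Below is a minimal sketch of how a user would set the option after this change; the placeholder model, the optimizer settings, and the keyword used to hand the config dict to deepspeed.initialize (config vs. config_params, depending on the DeepSpeed version) are assumptions rather than part of the PR, and ZeRO stage 2 may require further settings such as fp16 on some versions.

# Sketch only: the model and batch size are placeholders; the renamed
# "ignore_unused_parameters" key is the part introduced by this PR.
import torch
import deepspeed

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {
        "stage": 2,
        "cpu_offload": True,
        # True (the new default) tolerates parameters that never receive a
        # gradient; set False to restore the old hard assert.
        "ignore_unused_parameters": True,
    },
}

model = torch.nn.Linear(16, 16)  # placeholder model
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)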
deepspeed/runtime/engine.py (6 changes: 3 additions & 3 deletions)

@@ -377,8 +377,8 @@ def zero_param_persistence_threshold(self):
     def zero_gather_fp16_weights_on_model_save(self):
         return self._config.zero_config.gather_fp16_weights_on_model_save

-    def zero_find_unused_parameters(self):
-        return self._config.zero_config.find_unused_parameters
+    def zero_ignore_unused_parameters(self):
+        return self._config.zero_config.ignore_unused_parameters

     def fp16_enabled(self):
         return self._config.fp16_enabled

@@ -786,7 +786,7 @@ def _configure_zero_optimizer(self, optimizer):
                 postscale_gradients=self.postscale_gradients(),
                 gradient_predivide_factor=self.gradient_predivide_factor(),
                 gradient_accumulation_steps=self.gradient_accumulation_steps(),
-                find_unused_parameters=self.zero_find_unused_parameters())
+                ignore_unused_parameters=self.zero_ignore_unused_parameters())
         elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS:
             print("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
             from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
deepspeed/runtime/zero/config.py (8 changes: 4 additions & 4 deletions)

@@ -38,7 +38,7 @@ def __init__(self, param_dict):
         self.max_reuse_distance = None
         self.gather_fp16_weights_on_model_save = None

-        self.find_unused_parameters = None
+        self.ignore_unused_parameters = None

         if ZERO_OPTIMIZATION in param_dict.keys():
             zero_config_dict = param_dict[ZERO_OPTIMIZATION]

@@ -178,7 +178,7 @@ def _initialize(self, zero_config_dict):
             ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE,
             ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT)

-        self.find_unused_parameters = get_scalar_param(
+        self.ignore_unused_parameters = get_scalar_param(
             zero_config_dict,
-            ZERO_OPTIMIZATION_FIND_UNUSED_PARAMETERS,
-            ZERO_OPTIMIZATION_FIND_UNUSED_PARAMETERS_DEFAULT)
+            ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS,
+            ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS_DEFAULT)
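The new key is resolved through get_scalar_param, so a config that never mentions ignore_unused_parameters silently picks up the new default of True. The snippet below is a simplified stand-in for that lookup (the real helper lives in deepspeed.runtime.config_utils and may do more validation); it only illustrates the fallback behavior.

# Simplified stand-in for DeepSpeed's get_scalar_param; illustration only.
def get_scalar_param(param_dict, param_name, param_default):
    return param_dict.get(param_name, param_default)

ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS = 'ignore_unused_parameters'
ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS_DEFAULT = True

# A zero_optimization block that does not set the option resolves to True,
# so the assert in stage2.py is skipped by default.
zero_config_dict = {"stage": 2, "cpu_offload": True}
ignore_unused = get_scalar_param(zero_config_dict,
                                 ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS,
                                 ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS_DEFAULT)
print(ignore_unused)  # True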
deepspeed/runtime/zero/constants.py (12 changes: 5 additions & 7 deletions)

@@ -30,7 +30,7 @@
     "sub_group_size" : 1000000000000,
     "offload_param": {...},
     "offload_optimizer": {...},
-    "find_unused_parameters": [true|false]
+    "ignore_unused_parameters": [true|false]
   }
 }
 '''

@@ -117,10 +117,8 @@
 # Now just used in stage2 complete_grad_norm_calculation_for_cpu_offload
 # Enable this option to avoid:
 # https://github.com/microsoft/DeepSpeed/issues/707
-# torch.nn.parallel.DistributedDataParallel has the same option with
-# similar usage
-ZERO_OPTIMIZATION_FIND_UNUSED_PARAMETERS = 'find_unused_parameters'
-ZERO_OPTIMIZATION_FIND_UNUSED_PARAMETERS_DEFAULT = False
+ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS = 'ignore_unused_parameters'
+ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS_DEFAULT = True

 ZERO_OPTIMIZATION_DEFAULT = {
     ZERO_OPTIMIZATION_STAGE:

@@ -155,6 +153,6 @@
     ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT,
     ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE:
     ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT,
-    ZERO_OPTIMIZATION_FIND_UNUSED_PARAMETERS:
-    ZERO_OPTIMIZATION_FIND_UNUSED_PARAMETERS_DEFAULT
+    ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS:
+    ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS_DEFAULT
 }
deepspeed/runtime/zero/stage2.py (13 changes: 6 additions & 7 deletions)

@@ -96,7 +96,7 @@ def __init__(self,
                  postscale_gradients=True,
                  gradient_predivide_factor=1.0,
                  gradient_accumulation_steps=1,
-                 find_unused_parameters=False):
+                 ignore_unused_parameters=True):

         if dist.get_rank() == 0:
             logger.info(f"Reduce bucket size {reduce_bucket_size}")

@@ -150,7 +150,7 @@ def __init__(self,
         self.postscale_gradients = postscale_gradients
         self.gradient_accumulation_steps = gradient_accumulation_steps
         self.micro_step_id = 0
-        self.find_unused_parameters = find_unused_parameters
+        self.ignore_unused_parameters = ignore_unused_parameters

         if self.reduce_scatter:
             assert not self.allreduce_always_fp32, "allreduce_always_fp32 is not yet supported with ZeRO-2 with reduce scatter enabled"

@@ -896,12 +896,11 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
                 # As unused parameters in modules may not be expected sometimes,
                 # add an explicit error msg when it occurred and an option to
                 # avoid the error
-                # Error msg adapted from torch.nn.parallel.DistributedDataParallel
-                assert self.find_unused_parameters, """
-                    This error indicates that your module has parameters that
+                assert self.ignore_unused_parameters, """
+                    This assert indicates that your module has parameters that
                     were not used in producing loss.
-                    You can avoid this error by
-                    (1) enable find_unused_parameters option in zero_optimization config;
+                    You can avoid this assert by
+                    (1) enable ignore_unused_parameters option in zero_optimization config;
                     (2) making sure all trainable parameters and `forward` function
                     outputs participate in calculating loss.
                     """
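The comment above the assert describes the condition being guarded: a parameter that never takes part in producing the loss gets no gradient, which complete_grad_norm_calculation_for_cpu_offload then trips over when ignore_unused_parameters is False. The snippet below is plain PyTorch rather than DeepSpeed code and only illustrates what such an unused parameter looks like after backward().

# Plain-PyTorch illustration of the "unused parameter" condition the assert
# guards against; this is not DeepSpeed code.
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self, hidden_dim=4):
        super().__init__()
        self.used = nn.Linear(hidden_dim, hidden_dim)
        self.unused = nn.Linear(hidden_dim, hidden_dim)  # never touched in forward

    def forward(self, x):
        return self.used(x)

model = ToyModel()
loss = model(torch.randn(2, 4)).sum()
loss.backward()

# Only the branch that fed the loss receives a gradient.
print(model.used.weight.grad is not None)   # True
print(model.unused.weight.grad is None)     # True
# Under ZeRO stage 2 with cpu_offload, this is the case where the assert
# fires if ignore_unused_parameters is set to False.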
Unit test for the renamed option:

@@ -10,8 +10,8 @@
 import deepspeed


-@pytest.mark.parametrize('find_unused_parameters', [False, True])
-def test_stage2_find_unused_parameters(tmpdir, find_unused_parameters):
+@pytest.mark.parametrize('ignore_unused_parameters', [False, True])
+def test_stage2_ignore_unused_parameters(tmpdir, ignore_unused_parameters):
     use_cpu_offload = True

     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:

@@ -24,7 +24,7 @@ def test_stage2_find_unused_parameters(tmpdir, find_unused_parameters):
         "zero_optimization": {
             "stage": 2,
             "cpu_offload": use_cpu_offload,
-            "find_unused_parameters": find_unused_parameters
+            "ignore_unused_parameters": ignore_unused_parameters
         },
         "optimizer": {
             "type": "Adam",

@@ -44,7 +44,7 @@ def test_stage2_find_unused_parameters(tmpdir, find_unused_parameters):
     model = UnusedParametersModel(hidden_dim=hidden_dim)

     @distributed_test(world_size=[1])
-    def _test_stage2_find_unused_parameters(args, model, hidden_dim):
+    def _test_stage2_ignore_unused_parameters(args, model, hidden_dim):
         model, _, _, _ = deepspeed.initialize(args=args,
                                               model=model,
                                               model_parameters=model.parameters())

@@ -60,11 +60,11 @@ def _loop():
                 model.backward(loss)
                 model.step()

-        if not find_unused_parameters:
+        if ignore_unused_parameters:
+            _loop()
+        else:
             with pytest.raises(AssertionError) as e:
                 _loop()
-            assert e.value.args and 'find_unused_parameters' in e.value.args[0]
-        else:
-            _loop()
+            assert e.value.args and 'ignore_unused_parameters' in e.value.args[0]

-    _test_stage2_find_unused_parameters(args=args, model=model, hidden_dim=hidden_dim)
+    _test_stage2_ignore_unused_parameters(args=args, model=model, hidden_dim=hidden_dim)
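UnusedParametersModel is imported from the test utilities and is not part of this diff. The following is a hypothetical reconstruction, consistent only with how the test constructs it (UnusedParametersModel(hidden_dim=...), called as model(batch[0], batch[1])) and with the requirement that at least one parameter stays out of the loss; the real helper may differ.

# Hypothetical reconstruction of UnusedParametersModel (not from this PR):
# one branch feeds the loss, the other is deliberately left out of forward().
import torch
import torch.nn as nn

class UnusedParametersModel(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.linear = nn.Linear(hidden_dim, hidden_dim)
        self.unused_linear = nn.Linear(hidden_dim, hidden_dim)  # never used below
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x, y):
        hidden = self.linear(x)
        return self.loss_fn(hidden, y)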