Add save_ignore_keys (#2868)
* comment

* add it

* debug

* add the keys

* debug

* debug

* remove print statement

* docs and tests

* fix tests

---------

Co-authored-by: Daniel King <daniel@mosaicml.com>
mvpatel2000 and dakinggg authored Jan 16, 2024
1 parent 1bc8d0a commit 31ea664
Showing 5 changed files with 104 additions and 6 deletions.
32 changes: 28 additions & 4 deletions composer/callbacks/checkpoint_saver.py
@@ -199,10 +199,6 @@ class CheckpointSaver(Callback): # noqa: D101
progress). It should return ``True`` if a checkpoint should be saved given the current state and
event.
weights_only (bool): If ``True``, save only the model weights instead of the entire training state.
This parameter must be ``False`` when using DeepSpeed. Default: ``False``.
num_checkpoints_to_keep (int, optional): The number of checkpoints to keep locally. The oldest checkpoints
are removed first. Set to ``-1`` to keep all checkpoints locally. Default: ``-1``.
@@ -214,6 +210,31 @@ class CheckpointSaver(Callback): # noqa: D101
This parameter only controls how many checkpoints are kept locally; checkpoints are not deleted from
remote file systems.
weights_only (bool): If ``True``, save only the model weights instead of the entire training state.
This parameter must be ``False`` when using DeepSpeed. Default: ``False``.
ignore_keys (List[str] | (Dict) -> None, optional): A list of paths into the checkpoint's ``state_dict``
which, when provided, will be removed from the state_dict before the checkpoint is saved. Each path is a
sequence of keys into ``state_dict``, joined with `/` as a separator (since PyTorch uses `.` in parameter
names). If a prefix is provided, all of its children are also ignored (see Example 2).
See :mod:`composer.core.state` for the structure of state_dict.
Example 1: ``save_ignore_keys = ["state/model/layer1.weights", "state/model/layer1.bias"]`` would ignore
layer 1 weights and bias.
Example 2: ``save_ignore_keys = ["state/model/*"]`` would ignore the entire model, which would have the same
effect as the previous example if there was only 1 layer.
Example 3: ``save_ignore_keys = ["state/model/layer*.weights"]`` would ignore all weights in the model.
Example 4: ``save_ignore_keys = ["state/rank_zero_seed", "rng"]`` would reset all randomness when
saving the checkpoint.
If a callable, it should take a single argument, the state_dict, and is free to arbitrarily modify the
state_dict in place before the checkpoint is saved.
(default: ``None``)
Attributes:
saved_checkpoints (List[Tuple[Timestamp, List[pathlib.Path]]]): The checkpoint timestamps and filepaths.
@@ -243,6 +264,7 @@ def __init__(
overwrite: bool = False,
num_checkpoints_to_keep: int = -1,
weights_only: bool = False,
ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]] = None,
):
folder = str(folder)
filename = str(filename)
@@ -267,6 +289,7 @@ def __init__(
self.all_saved_checkpoints_to_timestamp: Dict[str, Timestamp] = {}
self.num_checkpoints_to_keep = num_checkpoints_to_keep
self.weights_only = weights_only
self.ignore_keys = ignore_keys

self.start_batch = None

@@ -363,6 +386,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
state=state,
filename=filename_with_placeholders,
weights_only=self.weights_only,
ignore_keys=self.ignore_keys,
)
log.debug(f'Checkpoint locally saved to {saved_path}')

23 changes: 23 additions & 0 deletions composer/trainer/trainer.py
@@ -700,6 +700,27 @@ class Trainer:
state. This parameter has no effect if ``save_folder`` is ``None``. (default: ``False``)
.. seealso:: :class:`~.CheckpointSaver`
save_ignore_keys (List[str] | (Dict) -> None, optional): A list of paths into the checkpoint's ``state_dict``
which, when provided, will be removed from the state_dict before a checkpoint is saved. Each path is a
sequence of keys into ``state_dict``, joined with `/` as a separator (since PyTorch uses `.` in parameter
names). If a prefix is provided, all of its children are also ignored (see Example 2).
See :mod:`composer.core.state` for the structure of state_dict.
Example 1: ``save_ignore_keys = ["state/model/layer1.weights", "state/model/layer1.bias"]`` would ignore
layer 1 weights and bias.
Example 2: ``save_ignore_keys = ["state/model/*"]`` would ignore the entire model, which would have the same
effect as the previous example if there was only 1 layer.
Example 3: ``save_ignore_keys = ["state/model/layer*.weights"]`` would ignore all weights in the model.
Example 4: ``save_ignore_keys = ["state/rank_zero_seed", "rng"]`` would reset all randomness when
saving the checkpoint.
If a callable, it should take a single argument, the state_dict, and is free to arbitrarily modify the
state_dict in place before the checkpoint is saved.
(default: ``None``)
save_num_checkpoints_to_keep (int, optional): The number of checkpoints to keep locally. The oldest checkpoints
are removed first. Set to ``-1`` to keep all checkpoints locally. (default: ``-1``)
@@ -866,6 +887,7 @@ def __init__(
save_overwrite: bool = False,
save_interval: Union[str, int, Time, Callable[[State, Event], bool]] = '1ep',
save_weights_only: bool = False,
save_ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]] = None,
save_num_checkpoints_to_keep: int = -1,
save_metrics: bool = False,

@@ -1150,6 +1172,7 @@ def __init__(
latest_remote_file_name=latest_remote_file_name,
overwrite=save_overwrite,
weights_only=save_weights_only,
ignore_keys=save_ignore_keys,
save_interval=save_interval,
num_checkpoints_to_keep=save_num_checkpoints_to_keep,
)
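
Not part of the commit: a hedged sketch of passing ``save_ignore_keys`` through the ``Trainer``, mirroring Example 2 above. ``model`` and ``train_dataloader`` are placeholders for a real ComposerModel and DataLoader.

from composer import Trainer

# Sketch only: `model` and `train_dataloader` are placeholders, not defined here.
# With 'state/model/*', saved checkpoints keep optimizer, timestamp, and RNG state but no model weights.
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration='2ep',
    save_folder='checkpoints',
    save_interval='1ep',
    save_ignore_keys=['state/model/*'],
)
trainer.fit()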
15 changes: 13 additions & 2 deletions composer/utils/checkpoint.py
@@ -16,7 +16,7 @@
import warnings
from importlib import import_module
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import torch
from packaging import version
@@ -938,6 +938,7 @@ def _save_checkpoint(
save_filename: str,
*,
weights_only: bool = False,
ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]] = None,
) -> Union[str, None]: # noqa: D103

is_deepspeed = is_model_deepspeed(state.model)
@@ -957,6 +958,15 @@ def _save_checkpoint(
'rng': reproducibility.get_rng_state(),
}

if ignore_keys:
# Filter provided list of key paths
if not callable(ignore_keys):
ignore_keys = glob_filter(ignore_keys)
# Call function to modify state_dict
ignore_keys(state_dict)
# Ensure state exists
state_dict['state'] = state_dict.get('state', {})

if state.fsdp_sharded_state_dict_enabled:
# To load optimizer states with 2.0 <= torch < 2.1.3 , the optimizer state must be at the top
# level of the state dict because the load_sharded_optimizer_state_dict function
@@ -1087,9 +1097,10 @@ def save_checkpoint(
filename: str = 'ep{epoch}-ba{batch}-rank{rank}',
*,
weights_only: bool = False,
ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]] = None,
) -> Union[str, None]: # noqa: D103
save_filename = get_save_filename(state, filename)
return _save_checkpoint(state, save_filename, weights_only=weights_only)
return _save_checkpoint(state, save_filename, weights_only=weights_only, ignore_keys=ignore_keys)


save_checkpoint.__doc__ = f"""Checkpoint the training ``state``.
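
Not part of the commit: ``ignore_keys`` may also be a callable that mutates the state_dict in place. A minimal sketch of that form, assuming only the top-level ``'rng'`` entry built in ``_save_checkpoint`` above:

from typing import Any, Dict

def drop_rng(state_dict: Dict[str, Any]) -> None:
    """Illustrative ``ignore_keys`` callable: strip RNG state in place.

    ``_save_checkpoint`` calls the callable on the full state_dict and ignores its return
    value, so modifications must happen in place.
    """
    state_dict.pop('rng', None)

# e.g. Trainer(..., save_ignore_keys=drop_rng) or CheckpointSaver(..., ignore_keys=drop_rng)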
39 changes: 39 additions & 0 deletions tests/trainer/test_checkpoint.py
@@ -285,6 +285,7 @@ def test_checkpoint_saver_properly_constructed(self, save_folder: str, expected_
'weights_only': False,
'save_interval': '1ep',
'num_checkpoints_to_keep': -1,
'ignore_keys': None,
}
expected_folder = expected_path.rstrip('/') if expected_path != '' else '.'
mock_checkpoint_saver.assert_called_once_with(folder=expected_folder, **rest_of_checkpoint_saver_kwargs)
@@ -790,6 +791,44 @@ def test_load_ignore_keys(self, load_ignore_keys, weights_equal, callbacks_equal
assert trainer_1_rng_state is not None
deep_compare(trainer_1_rng_state, trainer_2._rng_state)

@pytest.mark.parametrize('save_ignore_keys,weights_equal,callbacks_equal,rng_equal', [
['*', False, False, False],
['state/model/*', False, True, True],
['state/callbacks/*', True, False, True],
['rng', True, True, False],
])
@pytest.mark.filterwarnings('ignore:.* is not in the state_dict.*:UserWarning')
def test_save_ignore_keys(self, save_ignore_keys, weights_equal, callbacks_equal, rng_equal):

trainer_1 = self.get_trainer(save_folder='first', save_ignore_keys=[save_ignore_keys])
trainer_1.fit()
trainer_1_rng_state = reproducibility.get_rng_state()
trainer_1.close()

last_checkpoint = os.path.join('first', 'ep2.pt')
trainer_2 = self.get_trainer(load_path=last_checkpoint)

# Check weights loaded properly
with contextlib.nullcontext() if weights_equal else pytest.raises(AssertionError):
self._assert_weights_equivalent(
trainer_1.state.model,
trainer_2.state.model,
)

# Check callbacks state
stateful_callbacks_equal = self._stateful_callbacks_equal(
trainer_1.state.callbacks,
trainer_2.state.callbacks,
)
if callbacks_equal:
assert stateful_callbacks_equal
else:
assert not stateful_callbacks_equal

if rng_equal:
assert trainer_1_rng_state is not None
deep_compare(trainer_1_rng_state, trainer_2._rng_state)

@pytest.mark.remote
@device('cpu')
@pytest.mark.parametrize('load_weights_only', [True, False])
1 change: 1 addition & 0 deletions tests/utils/test_autolog_hparams.py
@@ -164,6 +164,7 @@ def test_extract_hparams_trainer():
'save_overwrite': False,
'save_interval': '1ep',
'save_weights_only': False,
'save_ignore_keys': None,
'save_num_checkpoints_to_keep': -1,
'save_metrics': False,

