
Commit beed9b1

Merge branch 'main' into release/v0.32.0
2 parents 912e412 + 0e59217

File tree

6 files changed (+64, -30 lines)


composer/distributed/prepare_distributed.py

Lines changed: 23 additions & 1 deletion
@@ -42,6 +42,27 @@ def log_execution_time(logger: logging.Logger, operation_name: str):
         logger.info(f'{operation_name} took {end_time - start_time:.2f} seconds')
 
 
+@contextmanager
+def get_full_state_dict(model: torch.nn.Module):
+    """Context manager to temporarily get full state dict regardless of should_save_peft_only setting for huggingface models.
+
+    PEFT models with lora have an updated state_dict fn (in composer/models/huggingface.py) that
+    returns the state_dict with only the lora params if should_save_peft_only is True.
+    But when we're syncing module states, we need the full state dict, so we temporarily set
+    should_save_peft_only to False.
+    """
+    # TODO: Since sharding peft/lora weights can be inefficient due to their small sizes (leading to communication overhead
+    # outweighing memory savings), we should provide an interface that allows users to avoid sharding these weights.
+    original_setting = getattr(model, 'should_save_peft_only', None)
+    if original_setting is not None:
+        model.should_save_peft_only = False  # type: ignore
+    try:
+        yield
+    finally:
+        if original_setting is not None:
+            model.should_save_peft_only = original_setting  # type: ignore
+
+
 def _check_duplicate_modules(model: torch.nn.Module) -> None:
     """Checks whether the model has duplicate module references.
@@ -98,7 +119,8 @@ def _parallelize_model_helper(
            full_state_dict=True,
            cpu_offload=True,
        )
-       full_state_dict = get_model_state_dict(model, options=options)
+       with get_full_state_dict(model):
+           full_state_dict = get_model_state_dict(model, options=options)
 
    with log_execution_time(log, 'Prepare FSDP2'):
        prepare_fully_shard(model, config, precision, fsdp_wrap_policy)
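
For context, here is a minimal standalone sketch of the toggle-and-restore pattern the new get_full_state_dict helper relies on. _StubPeftModel is a stand-in for illustration only, not Composer's HuggingFaceModel; the real helper operates on torch.nn.Module as shown in the diff above.

from contextlib import contextmanager


class _StubPeftModel:
    """Stand-in for a PEFT-wrapped HuggingFaceModel (illustration only)."""

    def __init__(self):
        self.should_save_peft_only = True

    def state_dict(self):
        # Composer's HuggingFaceModel filters to adapter-only params when
        # should_save_peft_only is True; the stub just records the flag.
        return {'filtered_to_peft_params': self.should_save_peft_only}


@contextmanager
def get_full_state_dict(model):
    # Same toggle-and-restore logic as the new helper in prepare_distributed.py.
    original_setting = getattr(model, 'should_save_peft_only', None)
    if original_setting is not None:
        model.should_save_peft_only = False
    try:
        yield
    finally:
        if original_setting is not None:
            model.should_save_peft_only = original_setting


model = _StubPeftModel()
with get_full_state_dict(model):
    assert model.state_dict() == {'filtered_to_peft_params': False}  # full weights visible here
assert model.should_save_peft_only is True  # original flag restored, even if an exception is raised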

composer/trainer/trainer.py

Lines changed: 6 additions & 0 deletions
@@ -1701,6 +1701,12 @@ def __init__(
             log.info('No previous autoresume checkpoint found')
         # Actually load the checkpoint from potentially updated arguments
         if load_path is not None:
+            # If we are using FSDP and load_monolith_rank0_only is True, then the state_dict must be `full`
+            # when we are loading a checkpoint
+            if self.state.fsdp_config and self.state.fsdp_config.load_monolith_rank0_only:  # type: ignore
+                err_msg = 'state_dict_type must be `full` when load_monolith_rank0_only is True when loading a checkpoint'
+                assert self.state.fsdp_config.state_dict_type == 'full', err_msg  # type: ignore
+
             log.info(f'Loading checkpoint from {load_path}')
             if load_object_store is None:
                 load_object_store = maybe_create_object_store_from_uri(load_path)
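
A standalone sketch of the guard's behavior, using a stub config object rather than Composer's real Trainer/State (the names and checkpoint path below are illustrative): the combination is validated only when a checkpoint is actually loaded.

from dataclasses import dataclass
from typing import Optional


@dataclass
class _StubFSDPConfig:
    # Mirrors the two fields the guard inspects on state.fsdp_config.
    state_dict_type: str = 'sharded'
    load_monolith_rank0_only: bool = False


def _validate_load(fsdp_config: Optional[_StubFSDPConfig], load_path: Optional[str]) -> None:
    if load_path is None:
        return  # nothing to load, so nothing to validate
    if fsdp_config and fsdp_config.load_monolith_rank0_only:
        assert fsdp_config.state_dict_type == 'full', (
            'state_dict_type must be `full` when load_monolith_rank0_only is True when loading a checkpoint'
        )


# OK: monolithic rank-0 loading with a full state dict.
_validate_load(_StubFSDPConfig(state_dict_type='full', load_monolith_rank0_only=True), 's3://bucket/ckpt.pt')

# OK: the sharded + monolithic combination is tolerated when no checkpoint is loaded.
_validate_load(_StubFSDPConfig(load_monolith_rank0_only=True), None)

# AssertionError: a sharded state dict cannot be loaded monolithically on rank 0.
# _validate_load(_StubFSDPConfig(load_monolith_rank0_only=True), 's3://bucket/ckpt.pt')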

composer/utils/parallelism.py

Lines changed: 5 additions & 10 deletions
@@ -73,8 +73,12 @@ class FSDP2Config:
         reshard_after_forward (Union[bool, int]): Controls parameter behavior after forward.
         activation_checkpointing (bool): Whether to use activation checkpointing. Defaults to False.
         activation_cpu_offload (bool): Whether to use activation CPU offloading. Defaults to False.
-        load_monolith_rank0_only (bool): Whether to load monolithic checkpoints on rank 0 only. Defaults to False.
         state_dict_type (str): Type of state dict to use. Can be 'full' or 'sharded'. Defaults to 'sharded'.
+            - Note: In cases where `load_path` is not set in Trainer, `state_dict_type` indicates how a model will be saved.
+            - Note: In cases where `load_path` is set in Trainer, `state_dict_type` indicates how a model will be loaded and also saved.
+        load_monolith_rank0_only (bool): Whether to load monolithic checkpoints on rank 0 only. Defaults to False.
+            - Note: when `load_monolith_rank0_only` is True and `load_path` is set in `Trainer`, `state_dict_type` must be 'full'.
+        mixed_precision (str): Mixed precision to use. Can be 'DEFAULT', 'PURE', or 'FULL'. Defaults to 'DEFAULT'.
         verbose (bool): Whether to print verbose output. Defaults to False.
     """
 
@@ -169,15 +173,6 @@ def use_orig_params(self) -> bool:
     def __post_init__(self):
         warnings.warn('FSDP2 Config/APIs are experimental and subject to heavy changes', UserWarning)
 
-        # TODO: We might not need `load_monolith_rank0_only` as we can theoretically use
-        # self.monolith_rank0_only = self.state_dict_type == 'full' assuming that saving
-        # the model doesn't get affected by `load_monolith_rank0_only`
-        if self.load_monolith_rank0_only and self.state_dict_type != 'full':
-            raise ValueError(
-                'load_monolith_rank0_only=True requires state_dict_type="full". '
-                f'Got state_dict_type="{self.state_dict_type}"',
-            )
-
 
 @dataclass
 class TPConfig:
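
A short sketch of what moves where, assuming a composer build that includes this change: the config below used to raise ValueError in __post_init__, but now constructs cleanly, and the constraint is enforced by the Trainer guard above only when load_path is set.

import warnings

from composer.utils.parallelism import FSDP2Config

with warnings.catch_warnings():
    warnings.simplefilter('ignore', UserWarning)  # FSDP2 APIs are experimental
    # Previously rejected at construction time; now only rejected by Trainer
    # when this config is combined with load_path (monolithic checkpoint load).
    config = FSDP2Config(
        state_dict_type='sharded',
        load_monolith_rank0_only=True,
    )

assert config.state_dict_type == 'sharded'
assert config.load_monolith_rank0_only is True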

setup.py

Lines changed: 1 addition & 1 deletion
@@ -174,7 +174,7 @@ def package_files(prefix: str, directory: str, extension: str):
 
 extra_deps['nlp'] = [
     'transformers>=4.11,!=4.34.0,<4.54',
-    'datasets>=2.4,<4',
+    'datasets>=2.4,<5',
     'huggingface-hub>=0.21.2,<0.34',
 ]
 
tests/trainer/test_fsdp2.py

Lines changed: 29 additions & 0 deletions
@@ -10,8 +10,10 @@
 import torch.distributed.fsdp
 from torch.distributed._tensor import DTensor
 from torch.utils.data import DataLoader
+from transformers.models.gpt2.modeling_gpt2 import GPT2Block
 
 from composer.models import ComposerClassifier
+from composer.models.huggingface import HuggingFaceModel
 from composer.trainer.trainer import Trainer
 from composer.utils import dist, load_checkpoint
 from composer.utils.parallelism import FSDP2Config, FSDPConfig, ParallelismConfig
@@ -815,3 +817,30 @@ def validate_reduce_dtype(module):
 
     for handle in hook_handles:
         handle.remove()
+
+
+@pytest.mark.gpu
+@world_size(2)
+def test_fsdp2_with_peft_model_and_mixed_init(
+    world_size: int,
+    tiny_gpt2_model,
+    tiny_gpt2_tokenizer,
+    gpt2_peft_config,
+):
+    del world_size
+    resolved_device = 'cuda' if dist.get_local_rank() == 0 else 'meta'
+    model = HuggingFaceModel(
+        tiny_gpt2_model,
+        tokenizer=tiny_gpt2_tokenizer,
+        peft_config=gpt2_peft_config,
+        should_save_peft_only=True,
+    )
+    for module in model.model.modules():
+        if isinstance(module, GPT2Block):
+            module._fsdp_wrap = True  # type: ignore
+    model.to(resolved_device)
+
+    create_trainer_with_model(
+        model=model,  # type: ignore
+        use_fsdp2=True,
+    )
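
As background for the mixed-init pattern the new test exercises, here is a plain-PyTorch sketch (no Composer, single process, rank hard-coded for illustration; full_state_dict_from_rank0 is a hypothetical placeholder) of why non-rank-0 workers start on the meta device and later need a full state dict to materialize real weights.

import torch
import torch.nn as nn

rank = 1  # stand-in for dist.get_local_rank(); the test uses 'cuda' on rank 0
device = 'cpu' if rank == 0 else 'meta'

with torch.device(device):
    model = nn.Linear(8, 8)  # on 'meta', parameters are placeholders with no storage

if rank != 0:
    # Allocate real (uninitialized) storage first, then fill it from the full
    # state dict synced from rank 0 -- which is why get_full_state_dict above
    # must not hand back a PEFT-only subset of the weights.
    model = model.to_empty(device='cpu')
    # model.load_state_dict(full_state_dict_from_rank0)  # synced/broadcast elsewhere

print(next(model.parameters()).device)  # meta at init, cpu after to_empty()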

tests/trainer/test_fsdp2_config.py

Lines changed: 0 additions & 18 deletions
@@ -87,21 +87,3 @@ def test_fsdp2config_from_fsdp1_multiple_invalid_attributes():
     assert any('invalid_attribute2: value2' in msg for msg in warning_messages)
     assert any('auto_wrap: True' in msg for msg in warning_messages)
     assert any('sync_module_states: True' in msg for msg in warning_messages)
-
-
-def test_fsdp2_config_monolithic_validation():
-    """Test FSDP2Config validation for monolithic checkpointing."""
-    # Test valid monolithic config
-    config = FSDP2Config(
-        state_dict_type='full',
-        load_monolith_rank0_only=True,
-    )
-    assert config.state_dict_type == 'full'
-    assert config.load_monolith_rank0_only is True
-
-    # Test invalid monolithic config
-    with pytest.raises(ValueError, match='load_monolith_rank0_only=True requires state_dict_type="full"'):
-        FSDP2Config(
-            state_dict_type='sharded',
-            load_monolith_rank0_only=True,
-        )
