@@ -42,6 +42,27 @@ def log_execution_time(logger: logging.Logger, operation_name: str):
        logger.info(f'{operation_name} took {end_time - start_time:.2f} seconds')


+@contextmanager
+def get_full_state_dict(model: torch.nn.Module):
+    """Context manager to temporarily get the full state dict regardless of the should_save_peft_only setting for HuggingFace models.
+
+    PEFT models with LoRA have an updated state_dict fn (in composer/models/huggingface.py) that
+    returns the state_dict with only the LoRA params if should_save_peft_only is True.
+    But when we're syncing module states, we need the full state dict, so we temporarily set
+    should_save_peft_only to False.
+    """
+    # TODO: Since sharding PEFT/LoRA weights can be inefficient due to their small sizes (leading to communication overhead
+    # outweighing memory savings), we should provide an interface that allows users to avoid sharding these weights.
+    original_setting = getattr(model, 'should_save_peft_only', None)
+    if original_setting is not None:
+        model.should_save_peft_only = False  # type: ignore
+    try:
+        yield
+    finally:
+        if original_setting is not None:
+            model.should_save_peft_only = original_setting  # type: ignore
+
+
def _check_duplicate_modules(model: torch.nn.Module) -> None:
    """Checks whether the model has duplicate module references.
@@ -98,7 +119,8 @@ def _parallelize_model_helper(
        full_state_dict=True,
        cpu_offload=True,
    )
-    full_state_dict = get_model_state_dict(model, options=options)
+    with get_full_state_dict(model):
+        full_state_dict = get_model_state_dict(model, options=options)

    with log_execution_time(log, 'Prepare FSDP2'):
        prepare_fully_shard(model, config, precision, fsdp_wrap_policy)
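
For reference, below is a minimal, self-contained sketch of the toggle-and-restore behavior the new context manager relies on. The DummyPeftModel class is hypothetical (not part of this PR) and only stands in for a PEFT-wrapped HuggingFace model carrying a should_save_peft_only flag; the context manager body mirrors the one added above.

import torch
from contextlib import contextmanager

@contextmanager
def get_full_state_dict(model: torch.nn.Module):
    # Same pattern as the PR: temporarily disable should_save_peft_only, restore it on exit.
    original_setting = getattr(model, 'should_save_peft_only', None)
    if original_setting is not None:
        model.should_save_peft_only = False
    try:
        yield
    finally:
        if original_setting is not None:
            model.should_save_peft_only = original_setting

class DummyPeftModel(torch.nn.Module):
    """Hypothetical stand-in for a PEFT-wrapped HuggingFace model with the should_save_peft_only flag."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)
        self.should_save_peft_only = True

model = DummyPeftModel()
with get_full_state_dict(model):
    # Inside the block the flag is cleared, so state_dict() would return all parameters.
    assert model.should_save_peft_only is False
# On exit the original setting is restored, even if the body raised.
assert model.should_save_peft_only is True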