@@ -189,7 +189,7 @@ def unwrap(self):
         return module


-def get_param_info(optim: Optimizer, model: torch.nn.Module):
+def get_param_info(optim: Optimizer):
     # Get a backup of necessary information of parameters for future use, which includes:
     # 1. A complete param_group, with params in the form of param_id
     # 2. A mapping from param address (obtained using id(param)) to integer param_id
@@ -204,8 +204,6 @@ def get_param_info(optim: Optimizer, model: torch.nn.Module):
204204 "param2id" : {},
205205 "id2param" : {},
206206 "param2shape" : {},
207- "old_input_embedding_param_id" : None ,
208- "old_output_embedding_param_id" : None ,
209207 }
210208 start_index = 0
211209 for group in optim .param_groups :
@@ -222,13 +220,6 @@ def get_param_info(optim: Optimizer, model: torch.nn.Module):
         param_info["param_groups"].append(packed_group)
         start_index += len(group["params"])

-    input_embedding = model.get_input_embeddings()
-    if input_embedding is not None:
-        param_info["old_input_embedding_param_id"] = id(input_embedding.weight)
-    output_embedding = model.get_output_embeddings()
-    if output_embedding is not None:
-        param_info["old_output_embedding_param_id"] = id(output_embedding.weight)
-
     return param_info


@@ -1090,32 +1081,6 @@ def __del__(self):
10901081 """Destroy the process groups in ProcessGroupMesh"""
10911082 self .pg_mesh .destroy_mesh_process_groups ()
10921083
1093- def set_resized_embedding_to_optimizer (self , model , optimizer , param_info ):
1094- old_input_embedding_param_id = param_info ["old_input_embedding_param_id" ]
1095- if old_input_embedding_param_id is not None :
1096- for param_group in optimizer .param_groups :
1097- group_params = param_group ["params" ]
1098- new_params = []
1099- for param in group_params :
1100- if id (param ) == old_input_embedding_param_id :
1101- new_input_embeddings = model .module .get_input_embeddings ()
1102- new_params .append (new_input_embeddings .weight )
1103- else :
1104- new_params .append (param )
1105- param_group ["params" ] = new_params
1106- old_output_embedding_param_id = param_info ["old_output_embedding_param_id" ]
1107- if old_output_embedding_param_id is not None :
1108- for param_group in optimizer .param_groups :
1109- group_params = param_group ["params" ]
1110- new_params = []
1111- for param in group_params :
1112- if id (param ) == old_output_embedding_param_id :
1113- new_output_embeddings = model .module .get_output_embeddings ()
1114- new_params .append (new_output_embeddings .weight )
1115- else :
1116- new_params .append (param )
1117- param_group ["params" ] = new_params
1118-
11191084 @property
11201085 def enable_pipeline_parallelism (self ) -> bool :
11211086 return self .pp_size > 1
@@ -1146,7 +1111,7 @@ def configure(
         dataloader: Optional[DataLoader] = None,
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
-        param_info = get_param_info(optimizer, model)
+        param_info = get_param_info(optimizer)
         if not isinstance(model, ModelWrapper):
             use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
             model = HybridParallelModule(
@@ -1160,7 +1125,6 @@ def configure(
                 custom_policy=self.custom_policy,
             )

-        self.set_resized_embedding_to_optimizer(model, optimizer, param_info)
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if self.zero_stage == 0:
                 if self.precision in ["fp16", "bf16"]:
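
Note on the data structure above: the loop body of get_param_info is not shown in these hunks, so the snippet below is only an illustrative sketch of how an id-based backup of optimizer param groups can be built using the keys that do appear in the diff (param_groups, param2id, id2param, param2shape). The nn.Linear/SGD setup is a hypothetical stand-in, not the plugin's real model or optimizer.

# Illustrative sketch only (toy model/optimizer), not the commit's exact code:
# record each parameter as an integer id so groups can be reconstructed later.
import torch
from torch.optim import SGD

toy_model = torch.nn.Linear(4, 2)
optim = SGD(toy_model.parameters(), lr=0.1)

param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
start_index = 0
for group in optim.param_groups:
    # Keep every group hyperparameter, but replace tensors with integer ids.
    packed_group = {k: v for k, v in group.items() if k != "params"}
    packed_group["params"] = []
    for param_id, param in enumerate(group["params"], start_index):
        packed_group["params"].append(param_id)
        param_info["param2id"][id(param)] = param_id   # address -> integer id
        param_info["id2param"][param_id] = id(param)   # integer id -> address
        param_info["param2shape"][id(param)] = param.shape
    param_info["param_groups"].append(packed_group)
    start_index += len(group["params"])

print(param_info["param_groups"][0]["params"])  # e.g. [0, 1] for weight and bias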