@@ -360,24 +360,20 @@ def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0:
360
360
destination = hook_result
361
361
return destination
362
362
363
- def _save_to_state_dict (self , destination , prefix , keep_vars , only_rank_0 = True ):
364
- r"""Saves module state to `destination` dictionary, containing a state
365
- of the module, but not its descendants. This is called on every
366
- submodule in :meth:`~torch.nn.Module.state_dict`.
367
-
368
- In rare cases, subclasses can achieve class-specific behavior by
369
- overriding this method with custom logic.
363
+ def _get_param_to_save_data (self , param_list : List [torch .nn .Parameter ], only_rank_0 : bool ) -> Dict :
364
+ """
365
+ get param content from chunks.
370
366
371
367
Args:
372
- destination (dict): a dict where state will be stored
373
- prefix (str): the prefix for parameters and buffers used in this
374
- module
375
- """
376
- assert keep_vars is False , "`state_dict` with parameter, `keep_vars=True`, is not supported now."
368
+ param_list (List[torch.nn.Parameter]): the parameters whose content should be fetched from their chunks
369
+ only_rank_0 (bool): the rank-0-only flag passed in by the caller (TODO: document its exact effect on the returned dict)
377
370
371
+ Returns:
372
+ Dict: a dict whose key is param name and value is param with correct payload
373
+ """
378
374
# save parameters
379
375
param_to_save_data = dict ()
380
- chunk_list = self .chunk_manager .get_chunks (self . fp32_params )
376
+ chunk_list = self .chunk_manager .get_chunks (param_list )
381
377
for chunk in chunk_list :
382
378
temp_chunk = get_temp_total_chunk_on_cuda (chunk )
383
379
@@ -391,7 +387,37 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
391
387
param_to_save_data [tensor ] = record_tensor
392
388
393
389
del temp_chunk
390
+ return param_to_save_data
391
+
392
def torch_named_parameters(self):
    """
    Iterate over ``(name, payload)`` pairs for the parameters of this module.

    Names come from ``self.named_parameters(recurse=True)``, in that order. The
    second item of each pair is not the parameter object itself but the entry
    that ``self._get_param_to_save_data`` recorded for it, i.e. the parameter
    with its real payload.

    Yields:
        tuple: the parameter name and the recorded payload for that parameter.

    Raises:
        AssertionError: if a parameter has no entry in the recorded data.
    """
    all_params = list(self.parameters(recurse=True))
    # The second argument is ``only_rank_0``; it is explicitly disabled here.
    payload_of = self._get_param_to_save_data(all_params, False)
    named_iter = self.named_parameters(recurse=True)
    for (param_name, _), param in zip(named_iter, all_params):
        if param is None:
            continue
        assert param in payload_of, "Parameter '{}' is neglected in the chunk list".format(param_name)
        yield param_name, payload_of[param]
404
+
405
+ def _save_to_state_dict (self , destination , prefix , keep_vars , only_rank_0 = True ):
406
+ r"""Saves module state to `destination` dictionary, containing a state
407
+ of the module, but not its descendants. This is called on every
408
+ submodule in :meth:`~torch.nn.Module.state_dict`.
409
+
410
+ In rare cases, subclasses can achieve class-specific behavior by
411
+ overriding this method with custom logic.
412
+
413
+ Args:
414
+ destination (dict): a dict where state will be stored
415
+ prefix (str): the prefix for parameters and buffers used in this
416
+ module
417
+ """
418
+ assert keep_vars is False , "`state_dict` with parameter, `keep_vars=True`, is not supported now."
394
419
420
+ param_to_save_data = self ._get_param_to_save_data (self .fp32_params , only_rank_0 )
395
421
for (name , p ), fp32_p in zip (self .named_parameters (), self .fp32_params ):
396
422
if p is not None :
397
423
assert fp32_p in param_to_save_data , "Parameter '{}' is neglected in the chunk list" .format (name )
0 commit comments