@@ -87,8 +87,8 @@ class Manager:
8787 def __init__ (
8888 self ,
8989 pg : "ProcessGroup" ,
90- load_state_dict : Callable [[T ], None ],
91- state_dict : Callable [[], T ],
90+ load_state_dict : Optional [ Callable [[T ], None ] ],
91+ state_dict : Optional [ Callable [[], T ] ],
9292 min_replica_size : int ,
9393 use_async_quorum : bool = True ,
9494 timeout : timedelta = timedelta (seconds = 60 ),
@@ -144,7 +144,7 @@ def __init__(
144144 transferring checkpoints to recovering replicas
145145 """
146146 self ._load_state_dict = load_state_dict
147- self ._state_dict = state_dict
147+ self ._user_state_dict = state_dict
148148 self ._pending_state_dict : Optional [Dict [str , object ]] = None
149149 self ._use_async_quorum = use_async_quorum
150150 self ._timeout = timeout
@@ -159,8 +159,6 @@ def __init__(
159159 world_size = world_size or int (os .environ ["WORLD_SIZE" ])
160160 self ._min_replica_size = min_replica_size
161161
162- self ._user_state_dict = state_dict
163-
164162 if checkpoint_transport is None :
165163 checkpoint_transport = CheckpointServer [Dict [str , T ]](
166164 timeout = timeout ,
@@ -226,6 +224,12 @@ def __init__(
226224 self ._participating_rank : Optional [int ] = None
227225 self ._participating_world_size : int = 0
228226
def set_state_dict_fns(
    self,
    load_state_dict: Callable[[T], None],
    state_dict: Callable[[], T],
) -> None:
    """
    Register (or replace) the user's state dict callbacks on this manager.

    The constructor accepts ``None`` for both callbacks, so this method lets
    them be supplied after the manager has been created. The manager asserts
    that each callback is set at the point where it is used.

    Args:
        load_state_dict: callable invoked with the "user" portion of a
            pending state dict when that state dict is applied
        state_dict: callable invoked with no arguments to produce the
            "user" entry of the manager state dict
    """
    # Store both callbacks in one statement; targets are bound left to right.
    self._load_state_dict, self._user_state_dict = load_state_dict, state_dict
232+
229233 def shutdown (self , wait : bool = True ) -> None :
230234 """
231235 Shutdown the manager and checkpoint server.
@@ -531,8 +535,12 @@ def _apply_pending_state_dict(self) -> None:
531535 self ._logger .info ("applying pending state dict" )
532536
533537 assert self ._pending_state_dict is not None , "checkpoint was not staged"
538+ assert (
539+ self ._load_state_dict is not None
540+ ), "user load_state_dict is not initialized."
534541 self ._load_state_dict (self ._pending_state_dict ["user" ])
535542 self ._pending_state_dict = None
543+ self ._logger .info ("Loaded state dict." )
536544
537545 def should_commit (self , timeout : Optional [timedelta ] = None ) -> bool :
538546 """
@@ -602,6 +610,7 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None:
602610 self ._batches_committed = state_dict ["batches_committed" ]
603611
604612 def _manager_state_dict (self ) -> Dict [str , object ]:
613+ assert self ._user_state_dict is not None , "user state_dict is not initialized."
605614 return {
606615 "user" : self ._user_state_dict (),
607616 "torchft" : self .state_dict (),
0 commit comments