6
6
from abc import ABCMeta , abstractmethod
7
7
from collections import OrderedDict
8
8
from pathlib import Path
9
- from typing import Any , Callable , cast , Dict , List , Mapping , NamedTuple , Optional , Tuple , Union
9
+ from typing import Any , Callable , cast , Dict , List , Mapping , NamedTuple , Optional , Union
10
10
11
11
import torch
12
12
import torch .nn as nn
@@ -277,7 +277,7 @@ class Checkpoint(Serializable):
277
277
"""
278
278
279
279
Item = NamedTuple ("Item" , [("priority" , int ), ("filename" , str )])
280
- _state_dict_all_req_keys = ("saved " ,)
280
+ _state_dict_all_req_keys = ("_saved " ,)
281
281
282
282
def __init__ (
283
283
self ,
@@ -707,11 +707,12 @@ def reload_objects(self, to_load: Mapping, load_kwargs: Optional[Dict] = None, *
707
707
708
708
Checkpoint .load_objects (to_load = to_load , checkpoint = path , ** load_kwargs )
709
709
710
- def state_dict (self ) -> OrderedDict [ str , List [ Tuple [ int , str ]]] :
710
+ def state_dict (self ) -> OrderedDict :
711
711
"""Method returns state dict with saved items: list of ``(priority, filename)`` pairs.
712
712
Can be used to save internal state of the class.
713
713
"""
714
- return OrderedDict ([("saved" , [(p , f ) for p , f in self ._saved ])])
714
+ # TODO: this method should use _state_dict_all_req_keys
715
+ return OrderedDict ([("_saved" , [(p , f ) for p , f in self ._saved ])])
715
716
716
717
def load_state_dict (self , state_dict : Mapping ) -> None :
717
718
"""Method replaces internal state of the class with provided state dict data.
@@ -720,7 +721,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
720
721
state_dict: a dict with "saved" key and list of ``(priority, filename)`` pairs as values.
721
722
"""
722
723
super ().load_state_dict (state_dict )
723
- self ._saved = [Checkpoint .Item (p , f ) for p , f in state_dict ["saved " ]]
724
+ self ._saved = [Checkpoint .Item (p , f ) for p , f in state_dict ["_saved " ]]
724
725
725
726
@staticmethod
726
727
def get_default_score_fn (metric_name : str , score_sign : float = 1.0 ) -> Callable :
0 commit comments