Commit 698fcbf

Few other fixes
1 parent: 0cd0054

5 files changed: 9 additions & 12 deletions


ignite/handlers/checkpoint.py

Lines changed: 6 additions & 5 deletions
@@ -6,7 +6,7 @@
 from abc import ABCMeta, abstractmethod
 from collections import OrderedDict
 from pathlib import Path
-from typing import Any, Callable, cast, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Mapping, NamedTuple, Optional, Union

 import torch
 import torch.nn as nn
@@ -277,7 +277,7 @@ class Checkpoint(Serializable):
     """

     Item = NamedTuple("Item", [("priority", int), ("filename", str)])
-    _state_dict_all_req_keys = ("saved",)
+    _state_dict_all_req_keys = ("_saved",)

     def __init__(
         self,
@@ -707,11 +707,12 @@ def reload_objects(self, to_load: Mapping, load_kwargs: Optional[Dict] = None, *

         Checkpoint.load_objects(to_load=to_load, checkpoint=path, **load_kwargs)

-    def state_dict(self) -> OrderedDict[str, List[Tuple[int, str]]]:
+    def state_dict(self) -> OrderedDict:
         """Method returns state dict with saved items: list of ``(priority, filename)`` pairs.
         Can be used to save internal state of the class.
         """
-        return OrderedDict([("saved", [(p, f) for p, f in self._saved])])
+        # TODO: this method should use _state_dict_all_req_keys
+        return OrderedDict([("_saved", [(p, f) for p, f in self._saved])])

     def load_state_dict(self, state_dict: Mapping) -> None:
         """Method replaces internal state of the class with provided state dict data.
@@ -720,7 +721,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
             state_dict: a dict with "saved" key and list of ``(priority, filename)`` pairs as values.
         """
         super().load_state_dict(state_dict)
-        self._saved = [Checkpoint.Item(p, f) for p, f in state_dict["saved"]]
+        self._saved = [Checkpoint.Item(p, f) for p, f in state_dict["_saved"]]

     @staticmethod
     def get_default_score_fn(metric_name: str, score_sign: float = 1.0) -> Callable:
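
For context, a minimal sketch of the state-dict round-trip these changes affect. It is not part of the commit; the toy model, the DiskSaver target directory, and the n_saved value are illustrative assumptions.

import torch.nn as nn

from ignite.handlers import Checkpoint, DiskSaver

# Illustration only: after this commit, Checkpoint serializes its list of
# saved items under the "_saved" key (matching the private attribute and
# _state_dict_all_req_keys) instead of "saved".
model = nn.Linear(2, 2)  # hypothetical toy model
checkpointer = Checkpoint(
    {"model": model},
    save_handler=DiskSaver("/tmp/ckpts", create_dir=True, require_empty=False),
    n_saved=2,
)

# state_dict() now emits the "_saved" key ...
assert list(checkpointer.state_dict().keys()) == ["_saved"]

# ... and load_state_dict() expects the same key, rebuilding the
# (priority, filename) Checkpoint.Item pairs from it.
checkpointer.load_state_dict({"_saved": [(0, "model_0.pt"), (10, "model_10.pt")]})
assert [tuple(item) for item in checkpointer._saved] == [(0, "model_0.pt"), (10, "model_10.pt")]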

ignite/metrics/metrics_lambda.py

Lines changed: 0 additions & 2 deletions
@@ -1,6 +1,4 @@
 import itertools
-from collections import OrderedDict
-from collections.abc import Mapping
 from typing import Any, Callable, Optional, Union

 import torch

ignite/utils.py

Lines changed: 1 addition & 3 deletions
@@ -111,9 +111,7 @@ def value(self) -> Any:
         return self.collection[self.key]  # type: ignore[index]

     @staticmethod
-    def wrap(
-        object: Union[Dict, List], key: Union[int, str], value: Any
-    ) -> Union[Any, "_CollectionItem"]:
+    def wrap(object: Union[Dict, List], key: Union[int, str], value: Any) -> Union[Any, "_CollectionItem"]:
         return (
             _CollectionItem(object, key)
             if value is None or isinstance(value, _CollectionItem.types_as_collection_item)

tests/ignite/handlers/test_checkpoint.py

Lines changed: 1 addition & 1 deletion
@@ -1643,7 +1643,7 @@ def test_checkpoint_load_state_dict():
     to_save = {"model": model}
     checkpointer = Checkpoint(to_save, save_handler=save_handler, n_saved=None)

-    sd = {"saved": [(0, "model_0.pt"), (10, "model_10.pt"), (20, "model_20.pt")]}
+    sd = {"_saved": [(0, "model_0.pt"), (10, "model_10.pt"), (20, "model_20.pt")]}
     checkpointer.load_state_dict(sd)
     assert checkpointer._saved == true_checkpointer._saved

tests/ignite/metrics/test_metric.py

Lines changed: 1 addition & 1 deletion
@@ -710,7 +710,7 @@ def _test_creating_on_xla_fails(device):
 @pytest.mark.distributed
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
-def test_distrib_nccl_gpu(distributed):
+def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
     device = idist.device()
     _test_distrib_sync_all_reduce_decorator(device)
     _test_invalid_sync_all_reduce(device)
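
As background for the one-line change above: pytest injects fixtures by matching test parameter names, so renaming the argument makes the test request the distributed_context_single_node_nccl fixture from ignite's test conftest. A minimal self-contained sketch of that mechanism, where the fixture body is a hypothetical stand-in rather than ignite's actual conftest code:

import pytest


@pytest.fixture()
def distributed_context_single_node_nccl():
    # Hypothetical stand-in: in ignite's tests the fixture of this name sets
    # up (and tears down) a single-node NCCL distributed context.
    yield {"backend": "nccl", "nproc": 1}


def test_uses_nccl_context(distributed_context_single_node_nccl):
    # The parameter name alone is what selects the fixture above.
    assert distributed_context_single_node_nccl["backend"] == "nccl"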
