5
5
# This source code is licensed under the BSD-style license found in the
6
6
# LICENSE file in the root directory of this source tree.
7
7
8
- import json
9
8
from dataclasses import asdict
10
9
from typing import Dict , Generator
11
10
from unittest .mock import patch
12
11
13
12
import pytest
14
13
14
+ import yaml
15
+
15
16
from _pytest .fixtures import SubRequest # @manual
16
17
17
18
from torchsnapshot .manifest import (
28
29
)
29
30
from torchsnapshot .manifest_ops import _insert_entry , get_manifest_for_rank
30
31
32
+ try :
33
+ from yaml import CSafeDumper as Dumper
34
+ except ImportError :
35
+ from yaml import SafeDumper as Dumper
36
+
37
+
31
38
_WORLD_SIZE = 2
32
39
_MANIFEST_0 : Dict [str , Entry ] = {
33
40
"0/foo" : DictEntry (
227
234
@pytest .fixture (params = [True , False ])
228
235
def use_cyaml (request : SubRequest ) -> Generator [None , None , None ]:
229
236
if request .param :
230
- from yaml import CSafeDumper , CSafeLoader
237
+ from yaml import CSafeLoader
231
238
232
- with patch ("torchsnapshot.manifest.Dumper" , CSafeDumper ):
233
- with patch ("torchsnapshot.manifest.Loader" , CSafeLoader ):
234
- yield
239
+ with patch ("torchsnapshot.manifest.Loader" , CSafeLoader ):
240
+ yield
235
241
else :
236
- from yaml import SafeDumper , SafeLoader
242
+ from yaml import SafeLoader
237
243
238
- with patch ("torchsnapshot.manifest.Dumper" , SafeDumper ):
239
- with patch ("torchsnapshot.manifest.Loader" , SafeLoader ):
240
- yield
244
+ with patch ("torchsnapshot.manifest.Loader" , SafeLoader ):
245
+ yield
241
246
242
247
243
248
@pytest .mark .usefixtures ("use_cyaml" )
@@ -255,26 +260,19 @@ def test_manifest_yaml_serialization(manifest: Dict[str, Entry]) -> None:
255
260
256
261
@pytest .mark .usefixtures ("use_cyaml" )
257
262
@pytest .mark .parametrize ("manifest" , [_MANIFEST_0 , _MANIFEST_1 ])
258
- def test_manifest_json_serialization (manifest : Dict [str , Entry ]) -> None :
263
+ def test_manifest_yaml_dumper (manifest : Dict [str , Entry ]) -> None :
259
264
"""
260
- Verify that when the metadata is serialized via json, it is load-able with
261
- the yaml loader.
262
-
263
- When the number of entries in the snapshot metadata is very large, yaml
264
- serialization becomes a bottleneck and there's little we can do to
265
- optimize. We likely need to switch to json to overcome this. Fortunately,
266
- when our metadata is serialized via json, it is compatible with the yaml
267
- loader, so we can make the switch in a backward compatible fashion. This
268
- test makes sure that we don't do anything crazy with the metadata to break
269
- this compatibility.
265
+ :func:`SnapshotMetadata.to_yaml` switched to :func:`json.dumps`` to help
266
+ with the serialization performance. This test verifies that old snapshot
267
+ metadata serialized with :func:`yaml.dump` are still loadable.
270
268
"""
271
269
metadata = SnapshotMetadata (
272
270
version = "0.0.0" ,
273
271
world_size = _WORLD_SIZE ,
274
272
manifest = manifest ,
275
273
)
276
- yaml_str = metadata . to_yaml ( )
277
- json_str = json . dumps ( asdict ( metadata ) )
274
+ yaml_str = yaml . dump ( asdict ( metadata ), sort_keys = False , Dumper = Dumper )
275
+ json_str = metadata . to_yaml ( )
278
276
metadata_from_yaml = SnapshotMetadata .from_yaml (yaml_str = yaml_str )
279
277
metadata_from_json = SnapshotMetadata .from_yaml (yaml_str = json_str )
280
278
assert metadata_from_json == metadata_from_yaml
0 commit comments