Skip to content

Commit 5a15771

Browse files
Yifu Wang, facebook-github-bot
Yifu Wang
authored and committed
Switch to json for snapshot metadata (#120)
Summary: Pull Request resolved: #120. When the number of entries in the snapshot metadata is very large, yaml serialization becomes a bottleneck and there's little we can do to optimize. Fortunately, when the metadata is serialized via json, it is still valid yaml, so technically the snapshot metadata format is still yaml. Reviewed By: raypeng. Differential Revision: D40626156. fbshipit-source-id: 013e81fd42dd4dd1debf9579c4bee854374e1ca2
1 parent cffb0cc commit 5a15771

File tree

3 files changed

+30
-33
lines changed

3 files changed

+30
-33
lines changed

tests/test_manifest.py

Lines changed: 20 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -5,13 +5,14 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
import json
98
from dataclasses import asdict
109
from typing import Dict, Generator
1110
from unittest.mock import patch
1211

1312
import pytest
1413

14+
import yaml
15+
1516
from _pytest.fixtures import SubRequest # @manual
1617

1718
from torchsnapshot.manifest import (
@@ -28,6 +29,12 @@
2829
)
2930
from torchsnapshot.manifest_ops import _insert_entry, get_manifest_for_rank
3031

32+
try:
33+
from yaml import CSafeDumper as Dumper
34+
except ImportError:
35+
from yaml import SafeDumper as Dumper
36+
37+
3138
_WORLD_SIZE = 2
3239
_MANIFEST_0: Dict[str, Entry] = {
3340
"0/foo": DictEntry(
@@ -227,17 +234,15 @@
227234
@pytest.fixture(params=[True, False])
228235
def use_cyaml(request: SubRequest) -> Generator[None, None, None]:
229236
if request.param:
230-
from yaml import CSafeDumper, CSafeLoader
237+
from yaml import CSafeLoader
231238

232-
with patch("torchsnapshot.manifest.Dumper", CSafeDumper):
233-
with patch("torchsnapshot.manifest.Loader", CSafeLoader):
234-
yield
239+
with patch("torchsnapshot.manifest.Loader", CSafeLoader):
240+
yield
235241
else:
236-
from yaml import SafeDumper, SafeLoader
242+
from yaml import SafeLoader
237243

238-
with patch("torchsnapshot.manifest.Dumper", SafeDumper):
239-
with patch("torchsnapshot.manifest.Loader", SafeLoader):
240-
yield
244+
with patch("torchsnapshot.manifest.Loader", SafeLoader):
245+
yield
241246

242247

243248
@pytest.mark.usefixtures("use_cyaml")
@@ -255,26 +260,19 @@ def test_manifest_yaml_serialization(manifest: Dict[str, Entry]) -> None:
255260

256261
@pytest.mark.usefixtures("use_cyaml")
257262
@pytest.mark.parametrize("manifest", [_MANIFEST_0, _MANIFEST_1])
258-
def test_manifest_json_serialization(manifest: Dict[str, Entry]) -> None:
263+
def test_manifest_yaml_dumper(manifest: Dict[str, Entry]) -> None:
259264
"""
260-
Verify that when the metadata is serialized via json, it is load-able with
261-
the yaml loader.
262-
263-
When the number of entries in the snapshot metadata is very large, yaml
264-
serialization becomes a bottleneck and there's little we can do to
265-
optimize. We likely need to switch to json to overcome this. Fortunately,
266-
when our metadata is serialized via json, it is compatible with the yaml
267-
loader, so we can make the switch in a backward compatible fashion. This
268-
test makes sure that we don't do anything crazy with the metadata to break
269-
this compatibility.
265+
:func:`SnapshotMetadata.to_yaml` switched to :func:`json.dumps` to help
266+
with the serialization performance. This test verifies that old snapshot
267+
metadata serialized with :func:`yaml.dump` are still loadable.
270268
"""
271269
metadata = SnapshotMetadata(
272270
version="0.0.0",
273271
world_size=_WORLD_SIZE,
274272
manifest=manifest,
275273
)
276-
yaml_str = metadata.to_yaml()
277-
json_str = json.dumps(asdict(metadata))
274+
yaml_str = yaml.dump(asdict(metadata), sort_keys=False, Dumper=Dumper)
275+
json_str = metadata.to_yaml()
278276
metadata_from_yaml = SnapshotMetadata.from_yaml(yaml_str=yaml_str)
279277
metadata_from_json = SnapshotMetadata.from_yaml(yaml_str=json_str)
280278
assert metadata_from_json == metadata_from_yaml

torchsnapshot/manifest.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
# pyre-ignore-all-errors[2]: Allow `Any` in type annotations
99

1010
import base64
11+
import json
1112
import logging
1213
import struct
1314
from dataclasses import asdict, dataclass
@@ -17,9 +18,9 @@
1718
import yaml
1819

1920
try:
20-
from yaml import CSafeDumper as Dumper, CSafeLoader as Loader
21+
from yaml import CSafeLoader as Loader
2122
except ImportError:
22-
from yaml import SafeDumper as Dumper, SafeLoader as Loader
23+
from yaml import SafeLoader as Loader
2324

2425
logger: logging.Logger = logging.getLogger(__name__)
2526

@@ -280,7 +281,12 @@ class SnapshotMetadata:
280281
manifest: Manifest
281282

282283
def to_yaml(self) -> str:
283-
return yaml.dump(asdict(self), sort_keys=False, Dumper=Dumper)
284+
# When the number of entries in the snapshot metadata is large, yaml
285+
# serialization becomes slow and there's little room for optimization.
286+
# Since the snapshot metadata can be dumped as json and json is a
287+
# subset of yaml, we use json.dumps() here to help with the
288+
# serialization performance without needing to deprecate yaml.
289+
return json.dumps(asdict(self), sort_keys=False, indent=2)
284290

285291
@classmethod
286292
def from_yaml(cls, yaml_str: str) -> "SnapshotMetadata":

torchsnapshot/snapshot.py

Lines changed: 1 addition & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -31,24 +31,17 @@
3131
from .dist_store import get_or_create_store, LinearBarrier
3232

3333
from .flatten import flatten, inflate
34-
from .io_preparer import (
35-
ObjectBufferConsumer,
36-
prepare_read,
37-
prepare_write,
38-
TensorIOPreparer,
39-
)
34+
from .io_preparer import prepare_read, prepare_write
4035
from .io_types import ReadIO, ReadReq, StoragePlugin, WriteIO, WriteReq
4136
from .knobs import is_batching_disabled
4237

4338
from .manifest import (
44-
ChunkedTensorEntry,
4539
Entry,
4640
is_container_entry,
4741
Manifest,
4842
PrimitiveEntry,
4943
ShardedTensorEntry,
5044
SnapshotMetadata,
51-
TensorEntry,
5245
)
5346
from .manifest_ops import get_manifest_for_rank
5447
from .partitioner import consolidate_replicated_entries, partition_write_reqs

0 commit comments

Comments
 (0)