Switch to json for snapshot metadata #120

Closed · wants to merge 1 commit
tests/test_manifest.py (20 additions, 22 deletions)
@@ -5,13 +5,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
from dataclasses import asdict
from typing import Dict, Generator
from unittest.mock import patch

import pytest

import yaml

from _pytest.fixtures import SubRequest # @manual

from torchsnapshot.manifest import (
@@ -28,6 +29,12 @@
)
from torchsnapshot.manifest_ops import _insert_entry, get_manifest_for_rank

try:
    from yaml import CSafeDumper as Dumper
except ImportError:
    from yaml import SafeDumper as Dumper


_WORLD_SIZE = 2
_MANIFEST_0: Dict[str, Entry] = {
"0/foo": DictEntry(
@@ -227,17 +234,15 @@
@pytest.fixture(params=[True, False])
def use_cyaml(request: SubRequest) -> Generator[None, None, None]:
    if request.param:
        from yaml import CSafeDumper, CSafeLoader
        from yaml import CSafeLoader

        with patch("torchsnapshot.manifest.Dumper", CSafeDumper):
            with patch("torchsnapshot.manifest.Loader", CSafeLoader):
                yield
        with patch("torchsnapshot.manifest.Loader", CSafeLoader):
            yield
    else:
        from yaml import SafeDumper, SafeLoader
        from yaml import SafeLoader

        with patch("torchsnapshot.manifest.Dumper", SafeDumper):
            with patch("torchsnapshot.manifest.Loader", SafeLoader):
                yield
        with patch("torchsnapshot.manifest.Loader", SafeLoader):
            yield


@pytest.mark.usefixtures("use_cyaml")
@@ -255,26 +260,19 @@ def test_manifest_yaml_serialization(manifest: Dict[str, Entry]) -> None:

@pytest.mark.usefixtures("use_cyaml")
@pytest.mark.parametrize("manifest", [_MANIFEST_0, _MANIFEST_1])
def test_manifest_json_serialization(manifest: Dict[str, Entry]) -> None:
def test_manifest_yaml_dumper(manifest: Dict[str, Entry]) -> None:
"""
Verify that when the metadata is serialized via json, it is load-able with
the yaml loader.

When the number of entries in the snapshot metadata is very large, yaml
serialization becomes a bottleneck and there's little we can do to
optimize. We likely need to switch to json to overcome this. Fortunately,
when our metadata is serialized via json, it is compatible with the yaml
loader, so we can make the switch in a backward compatible fashion. This
test makes sure that we don't do anything crazy with the metadata to break
this compatibility.
:func:`SnapshotMetadata.to_yaml` switched to :func:`json.dumps`` to help
with the serialization performance. This test verifies that old snapshot
metadata serialized with :func:`yaml.dump` are still loadable.
"""
    metadata = SnapshotMetadata(
        version="0.0.0",
        world_size=_WORLD_SIZE,
        manifest=manifest,
    )
    yaml_str = metadata.to_yaml()
    json_str = json.dumps(asdict(metadata))
    yaml_str = yaml.dump(asdict(metadata), sort_keys=False, Dumper=Dumper)
    json_str = metadata.to_yaml()
    metadata_from_yaml = SnapshotMetadata.from_yaml(yaml_str=yaml_str)
    metadata_from_json = SnapshotMetadata.from_yaml(yaml_str=json_str)
    assert metadata_from_json == metadata_from_yaml
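For reference, a minimal standalone sketch (not part of this PR) of the compatibility the test exercises: a JSON-serialized metadata dict is valid YAML, so PyYAML's safe loader reads it back unchanged. The dict below is only illustrative and is not the real SnapshotMetadata/Entry schema.

# Standalone sketch (not from the PR): JSON output is parseable by PyYAML's
# safe loader, which is what keeps the json.dumps() switch backward compatible.
import json

import yaml

try:
    from yaml import CSafeLoader as Loader  # C loader, if libyaml is available
except ImportError:
    from yaml import SafeLoader as Loader

# Illustrative stand-in for snapshot metadata, not the real Entry schema.
metadata = {
    "version": "0.0.0",
    "world_size": 2,
    "manifest": {"0/foo": {"type": "dict", "keys": ["bar", "baz"]}},
}

json_str = json.dumps(metadata, sort_keys=False, indent=2)
# The JSON string loads with the YAML loader and round-trips unchanged.
assert yaml.load(json_str, Loader=Loader) == metadata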
torchsnapshot/manifest.py (9 additions, 3 deletions)
@@ -8,6 +8,7 @@
# pyre-ignore-all-errors[2]: Allow `Any` in type annotations

import base64
import json
import logging
import struct
from dataclasses import asdict, dataclass
@@ -17,9 +18,9 @@
import yaml

try:
    from yaml import CSafeDumper as Dumper, CSafeLoader as Loader
    from yaml import CSafeLoader as Loader
except ImportError:
    from yaml import SafeDumper as Dumper, SafeLoader as Loader
    from yaml import SafeLoader as Loader

logger: logging.Logger = logging.getLogger(__name__)

@@ -280,7 +281,12 @@ class SnapshotMetadata:
    manifest: Manifest

    def to_yaml(self) -> str:
        return yaml.dump(asdict(self), sort_keys=False, Dumper=Dumper)
        # When the number of entries in the snapshot metadata is large, yaml
        # serialization becomes slow and there's little room for optimization.
        # Since the snapshot metadata can be dumped as json and json is a
        # subset of yaml, json.dumps() is used here to improve serialization
        # performance without needing to deprecate yaml.
        return json.dumps(asdict(self), sort_keys=False, indent=2)

    @classmethod
    def from_yaml(cls, yaml_str: str) -> "SnapshotMetadata":
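As a rough illustration of the performance motivation in the to_yaml() comment above, here is a benchmark sketch (not part of this PR; the entry fields are made up, not the real manifest schema) that times yaml.dump against json.dumps on a synthetic manifest with many entries:

# Rough benchmark sketch (assumption: synthetic entries, not the real manifest
# schema) comparing yaml.dump with json.dumps on a large metadata dict.
import json
import timeit

import yaml

try:
    from yaml import CSafeDumper as Dumper
except ImportError:
    from yaml import SafeDumper as Dumper

# Build a metadata-like dict with many entries to mimic a large snapshot.
manifest = {
    f"{rank}/model.layer_{i}.weight": {
        "type": "Tensor",
        "dtype": "float32",
        "shape": [128, 128],
        "replicated": False,
    }
    for rank in range(2)
    for i in range(20_000)
}
metadata = {"version": "0.0.0", "world_size": 2, "manifest": manifest}

yaml_sec = timeit.timeit(
    lambda: yaml.dump(metadata, sort_keys=False, Dumper=Dumper), number=1
)
json_sec = timeit.timeit(
    lambda: json.dumps(metadata, sort_keys=False, indent=2), number=1
)
print(f"yaml.dump: {yaml_sec:.2f}s   json.dumps: {json_sec:.2f}s")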
torchsnapshot/snapshot.py (1 addition, 8 deletions)
@@ -31,24 +31,17 @@
from .dist_store import get_or_create_store, LinearBarrier

from .flatten import flatten, inflate
from .io_preparer import (
    ObjectBufferConsumer,
    prepare_read,
    prepare_write,
    TensorIOPreparer,
)
from .io_preparer import prepare_read, prepare_write
from .io_types import ReadIO, ReadReq, StoragePlugin, WriteIO, WriteReq
from .knobs import is_batching_disabled

from .manifest import (
    ChunkedTensorEntry,
    Entry,
    is_container_entry,
    Manifest,
    PrimitiveEntry,
    ShardedTensorEntry,
    SnapshotMetadata,
    TensorEntry,
)
from .manifest_ops import get_manifest_for_rank
from .partitioner import consolidate_replicated_entries, partition_write_reqs