Commit 55e90d1

Rodrigo Kumpera authored and pytorchmergebot committed
Handle torch.memory_format serialization in TensorProperties.
The current code handles TensorProperties serialization inside ShardedTensorMetadata, which forces any other use of TensorProperties to copy the serialization workaround for torch.memory_format. Moving the workaround into TensorProperties itself makes the type more modular and reusable.

Pull Request resolved: pytorch#76679
Approved by: https://github.com/pritamdamania87
1 parent bc3c7a6 commit 55e90d1
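
As an illustration of the payoff (not part of the commit itself), TensorProperties can now round-trip through pickle on its own rather than only as a field of ShardedTensorMetadata. A minimal sketch, assuming the import path of the file touched below; the field values are illustrative:

import pickle

import torch
# Assumed import path, derived from the file changed in this commit.
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties

# torch.memory_format values cannot be pickled directly, which is why
# TensorProperties now encodes memory_format as MEM_FORMAT_ENCODING in
# __getstate__ and decodes it back in __setstate__.
props = TensorProperties(dtype=torch.float16, memory_format=torch.channels_last)
restored = pickle.loads(pickle.dumps(props))
assert restored.dtype == torch.float16
assert restored.memory_format == torch.channels_last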

1 file changed

torch/distributed/_shard/sharded_tensor/metadata.py

Lines changed: 21 additions & 32 deletions
@@ -5,7 +5,6 @@
 import torch
 from torch.distributed._shard.metadata import ShardMetadata
 
-
 class MEM_FORMAT_ENCODING(Enum):
     TORCH_CONTIGUOUS_FORMAT = 0
     TORCH_CHANNELS_LAST = 1
@@ -22,28 +21,9 @@ class TensorProperties(object):
     memory_format: torch.memory_format = field(default=torch.contiguous_format)
     pin_memory: bool = False
 
-@dataclass
-class ShardedTensorMetadata(object):
-    """
-    Represents metadata for :class:`ShardedTensor`
-    """
-
-    # Metadata about each shard of the Tensor
-    shards_metadata: List[ShardMetadata] = field(default_factory=list)
-
-    # Size of each dim of the overall Tensor.
-    size: torch.Size = field(default=torch.Size([]))
-
-    tensor_properties: TensorProperties = field(
-        default=TensorProperties(dtype=torch.get_default_dtype(),
-                                 layout=torch.strided,
-                                 requires_grad=False,
-                                 memory_format=torch.contiguous_format,
-                                 pin_memory=False))
-
     def __getstate__(self):
         # Since torch.memory_format cannot be pickled!
-        memory_format = self.tensor_properties.memory_format
+        memory_format = self.memory_format
         if memory_format == torch.contiguous_format:
             mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT
         elif memory_format == torch.channels_last:
@@ -53,22 +33,19 @@ def __getstate__(self):
         else:
             raise RuntimeError(f'Invalid torch.memory_format: {memory_format}')
 
-        # Keep old serialization to ensure backward compatibility
         return (
-            self.shards_metadata,
-            self.size,
-            self.tensor_properties.dtype,
-            self.tensor_properties.layout,
-            self.tensor_properties.requires_grad,
+            self.dtype,
+            self.layout,
+            self.requires_grad,
             mem_format_encoding,
-            self.tensor_properties.pin_memory,
+            self.pin_memory,
         )
 
     def __setstate__(
         self,
         state,
     ):
-        (self.shards_metadata, self.size, dtype, layout, requires_grad, mem_format_encoding, pin_memory) = state
+        (self.dtype, self.layout, self.requires_grad, mem_format_encoding, self.pin_memory) = state
 
         if mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT:
             memory_format = torch.contiguous_format
@@ -79,6 +56,18 @@ def __setstate__(
         else:
             raise RuntimeError(f'Invalid torch.memory_format encoding: {mem_format_encoding}')
 
-        self.tensor_properties = TensorProperties(
-            dtype=dtype, layout=layout, requires_grad=requires_grad,
-            memory_format=memory_format, pin_memory=pin_memory, )
+        self.memory_format = memory_format
+
+@dataclass
+class ShardedTensorMetadata(object):
+    """
+    Represents metadata for :class:`ShardedTensor`
+    """
+
+    # Metadata about each shard of the Tensor
+    shards_metadata: List[ShardMetadata] = field(default_factory=list)
+
+    # Size of each dim of the overall Tensor.
+    size: torch.Size = field(default=torch.Size([]))
+
+    tensor_properties: TensorProperties = field(default=TensorProperties())
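
With the workaround moved, the diff shows ShardedTensorMetadata dropping its custom __getstate__/__setstate__, so it falls back to default dataclass pickling, which recurses into tensor_properties. A hedged round-trip sketch under the same import-path assumption; the empty shard list and size are illustrative:

import pickle

import torch
from torch.distributed._shard.sharded_tensor.metadata import (
    ShardedTensorMetadata,
    TensorProperties,
)

# The outer metadata pickles field by field; the memory_format
# handling now lives entirely inside TensorProperties.
meta = ShardedTensorMetadata(
    shards_metadata=[],
    size=torch.Size([4, 4]),
    tensor_properties=TensorProperties(memory_format=torch.channels_last),
)
restored = pickle.loads(pickle.dumps(meta))
assert restored.size == torch.Size([4, 4])
assert restored.tensor_properties.memory_format == torch.channels_last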
