@@ -5,7 +5,6 @@
 import torch
 from torch.distributed._shard.metadata import ShardMetadata
 
-
 class MEM_FORMAT_ENCODING(Enum):
     TORCH_CONTIGUOUS_FORMAT = 0
     TORCH_CHANNELS_LAST = 1
@@ -22,28 +21,9 @@ class TensorProperties(object):
     memory_format: torch.memory_format = field(default=torch.contiguous_format)
     pin_memory: bool = False
 
-@dataclass
-class ShardedTensorMetadata(object):
-    """
-    Represents metadata for :class:`ShardedTensor`
-    """
-
-    # Metadata about each shard of the Tensor
-    shards_metadata: List[ShardMetadata] = field(default_factory=list)
-
-    # Size of each dim of the overall Tensor.
-    size: torch.Size = field(default=torch.Size([]))
-
-    tensor_properties: TensorProperties = field(
-        default=TensorProperties(dtype=torch.get_default_dtype(),
-                                 layout=torch.strided,
-                                 requires_grad=False,
-                                 memory_format=torch.contiguous_format,
-                                 pin_memory=False))
-
     def __getstate__(self):
         # Since torch.memory_format cannot be pickled!
-        memory_format = self.tensor_properties.memory_format
+        memory_format = self.memory_format
         if memory_format == torch.contiguous_format:
             mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT
         elif memory_format == torch.channels_last:
@@ -53,22 +33,19 @@ def __getstate__(self):
         else:
             raise RuntimeError(f'Invalid torch.memory_format: {memory_format}')
 
-        # Keep old serialization to ensure backward compatibility
         return (
-            self.shards_metadata,
-            self.size,
-            self.tensor_properties.dtype,
-            self.tensor_properties.layout,
-            self.tensor_properties.requires_grad,
+            self.dtype,
+            self.layout,
+            self.requires_grad,
             mem_format_encoding,
-            self.tensor_properties.pin_memory,
+            self.pin_memory,
         )
 
     def __setstate__(
         self,
         state,
     ):
-        (self.shards_metadata, self.size, dtype, layout, requires_grad, mem_format_encoding, pin_memory) = state
+        (self.dtype, self.layout, self.requires_grad, mem_format_encoding, self.pin_memory) = state
 
         if mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT:
             memory_format = torch.contiguous_format
@@ -79,6 +56,18 @@ def __setstate__(
         else:
             raise RuntimeError(f'Invalid torch.memory_format encoding: {mem_format_encoding}')
 
-        self.tensor_properties = TensorProperties(
-            dtype=dtype, layout=layout, requires_grad=requires_grad,
-            memory_format=memory_format, pin_memory=pin_memory, )
+        self.memory_format = memory_format
+
+@dataclass
+class ShardedTensorMetadata(object):
+    """
+    Represents metadata for :class:`ShardedTensor`
+    """
+
+    # Metadata about each shard of the Tensor
+    shards_metadata: List[ShardMetadata] = field(default_factory=list)
+
+    # Size of each dim of the overall Tensor.
+    size: torch.Size = field(default=torch.Size([]))
+
+    tensor_properties: TensorProperties = field(default=TensorProperties())
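
The `__getstate__`/`__setstate__` pair moves onto `TensorProperties` because, as the in-code comment notes, `torch.memory_format` cannot be pickled: the format is swapped for a `MEM_FORMAT_ENCODING` enum member on serialization and decoded back on deserialization. Below is a minimal round-trip sketch of what this enables; it assumes `TensorProperties` is importable from `torch.distributed._shard.sharded_tensor.metadata`, the module this diff touches.

```python
import io
import pickle

import torch
# Assumed import path for the module shown in this diff.
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties

props = TensorProperties(
    dtype=torch.float16,
    layout=torch.strided,
    requires_grad=True,
    memory_format=torch.channels_last,  # not directly picklable
    pin_memory=False,
)

buf = io.BytesIO()
# __getstate__ replaces torch.channels_last with MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST
pickle.dump(props, buf)
buf.seek(0)
# __setstate__ decodes the enum back into a torch.memory_format
restored = pickle.load(buf)

assert restored.memory_format == torch.channels_last
assert restored.dtype == torch.float16 and restored.requires_grad
```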
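With the serialization logic gone, `ShardedTensorMetadata` becomes a plain dataclass, and its `tensor_properties` default shrinks to `TensorProperties()`, whose field defaults (the default dtype, `torch.strided`, `requires_grad=False`, `torch.contiguous_format`, `pin_memory=False`) reproduce the old explicit default. A hypothetical construction for a 10x5 tensor split row-wise across two ranks follows; the `ShardMetadata` arguments and placement strings follow the usual `torch.distributed._shard` conventions but are illustrative only.

```python
import torch
from torch.distributed._shard.metadata import ShardMetadata
# Assumed import path for the classes shown in this diff.
from torch.distributed._shard.sharded_tensor.metadata import (
    ShardedTensorMetadata,
    TensorProperties,
)

# Two row-wise shards of a 10x5 tensor, one per rank (illustrative placements).
meta = ShardedTensorMetadata(
    shards_metadata=[
        ShardMetadata(shard_offsets=[0, 0], shard_sizes=[5, 5], placement="rank:0/cuda:0"),
        ShardMetadata(shard_offsets=[5, 0], shard_sizes=[5, 5], placement="rank:1/cuda:1"),
    ],
    size=torch.Size([10, 5]),
    tensor_properties=TensorProperties(dtype=torch.float32),
)
```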