@@ -24,9 +24,9 @@ def unflatten_tensor_state_dict(
2424
2525 For example, given a previously flattened tensors_data_dict and metadata:
2626 tensors_data_dict = {
27- '0.weight:qdata ': torch.Tensor(...),
28- '0.weight:scale ': torch.Tensor(...),
29- '0.bias:_data ': torch.Tensor(...),
27+ '0._weight_qdata ': torch.Tensor(...),
28+ '0._weight_scale ': torch.Tensor(...),
29+ '0.bias': torch.Tensor(...),
3030 }
3131 metadata = {
3232 '0.weight': {
@@ -53,7 +53,7 @@ def unflatten_tensor_state_dict(
5353 }
5454
5555 Args:
56- tensors_data_dict: a dictionary from "tensor_name: tensor_data_attribute_name" to flattened torch.Tensor data for tensor subclass instance
56+ tensors_data_dict: a dictionary from "{module_fqn}._{weight_name}_{tensor_data_attribute_name}" (for tensor subclass instances) or "tensor_name" (for plain torch.Tensor) to flattened torch.Tensor data
5757 metadata: a dictionary from "tensor_name" to another dictionary that contains type and attributes for tensor subclass instance
5858
5959 Returns:
@@ -68,23 +68,30 @@ def unflatten_tensor_state_dict(
6868 result = {}
6969
7070 for tensor_name in tensor_names :
71+ module_fqn , weight_name = tensor_name .rsplit ("." , 1 )
72+
73+ prefix = f"{ module_fqn } ._{ weight_name } _"
7174 tensor_tensors = {}
7275 for key , value in combined_data .items ():
73- if key .startswith (f" { tensor_name } :" ):
76+ if key .startswith (prefix ):
7477 # Remove the prefix
75- tensor_tensors [key [len (tensor_name ) + 1 :]] = value
78+ tensor_tensors [key [len (prefix ) :]] = value
7679
7780 tensor_metadata = json .loads (metadata .get (tensor_name ))
7881 tensor_type = tensor_metadata .get ("_type" )
7982
8083 if tensor_type in ALLOWED_TENSORS_SUBCLASSES :
84+ if not tensor_tensors :
85+ # We allow loading the state_dict info for a single tensor at a time.
86+ # If this tensor's data has not been loaded yet, skip it and wait for it to be provided
87+ # in a future call.
88+ continue
8189 tensor_metadata ["_data" ].update (tensor_tensors )
8290 result [tensor_name ] = object_from_dict (tensor_metadata )
8391 elif tensor_type == torch .Tensor .__name__ :
84- result [tensor_name ] = tensor_tensors [ "_data" ]
92+ result [tensor_name ] = tensors_data_dict [ tensor_name ]
8593 else :
8694 raise ValueError (f"Unsupported tensor type: { tensor_type } " )
87-
8895 return result
8996
9097
@@ -108,9 +115,9 @@ def flatten_tensor_state_dict(
108115
109116 We flatten this to:
110117 tensors_data = {
111- '0.weight:qdata ': torch.Tensor(...),
112- '0.weight:scale ': torch.Tensor(...),
113- '0.bias:_data ': torch.Tensor(...),
118+ '0._weight_qdata ': torch.Tensor(...),
119+ '0._weight_scale ': torch.Tensor(...),
120+ '0.bias': torch.Tensor(...),
114121 }
115122 metadata = {
116123 '0.weight': {
@@ -152,22 +159,23 @@ def flatten_tensor_state_dict(
152159 tensor_dict [tensor_data_name ] = getattr (tensor , tensor_data_name )
153160
154161 tensor_metadata = json .dumps (tensor , cls = TensorSubclassAttributeJSONEncoder )
162+
163+ # Clone tensors to avoid memory sharing issues
164+ tensors_dict_to_save = {
165+ f"{ tensor_name .rsplit ('.' , 1 )[0 ]} ._{ tensor_name .rsplit ('.' , 1 )[1 ]} _{ key } " : (
166+ value .detach ().clone () if isinstance (value , torch .Tensor ) else value
167+ )
168+ for key , value in tensor_dict .items ()
169+ }
170+
155171 elif type (tensor ) is torch .Tensor :
156- tensor_dict = {"_data" : tensor }
157172 tensor_metadata = json .dumps ({"_type" : torch .Tensor .__name__ })
173+ tensors_dict_to_save = {tensor_name : tensor }
158174 else :
159175 raise ValueError (f"Unsupported tensor type: { type (tensor )} " )
160176
161- # Clone tensors to avoid memory sharing issues
162- prefixed_tensors_dict = {
163- f"{ tensor_name } :{ key } " : (
164- value .detach ().clone () if isinstance (value , torch .Tensor ) else value
165- )
166- for key , value in tensor_dict .items ()
167- }
168-
169177 metadata [tensor_name ] = tensor_metadata
170- tensors_data_dict .update (prefixed_tensors_dict )
178+ tensors_data_dict .update (tensors_dict_to_save )
171179
172180 metadata ["tensor_names" ] = json .dumps (list (tensors_dict .keys ()))
173181 return tensors_data_dict , metadata
0 commit comments