
Commit ba3ac9f

updating flatten/unflatten functions (#3282)
1 parent 9266734

File tree

2 files changed, +35 -21 lines changed

test/prototype/safetensors/test_safetensors_support.py

Lines changed: 6 additions & 0 deletions
@@ -66,6 +66,12 @@ def test_safetensors(self, config, act_pre_scale=False):
 
         with tempfile.NamedTemporaryFile() as f:
             tensors_data_dict, metadata = flatten_tensor_state_dict(model.state_dict())
+
+            for key in tensors_data_dict.keys():
+                assert key.startswith("0._weight_") or key.startswith("0.bias"), (
+                    f"Unexpected key format: {key}"
+                )
+
             save_file(tensors_data_dict, f.name, metadata=metadata)
             tensors_data_dict, metadata = load_data(file_path=f.name, device="cuda")
             reconstructed_dict = unflatten_tensor_state_dict(
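
The assertion above pins down the new key scheme: tensor-subclass attributes are flattened to "{module_fqn}._{weight_name}_{attr}", while plain tensors such as "0.bias" keep their original state-dict name. A minimal sketch of that scheme (the helper flatten_key is illustrative only, not part of this commit):

# Hypothetical helper mirroring the prefix construction this commit introduces.
def flatten_key(tensor_name: str, attr: str) -> str:
    module_fqn, weight_name = tensor_name.rsplit(".", 1)
    return f"{module_fqn}._{weight_name}_{attr}"

assert flatten_key("0.weight", "qdata") == "0._weight_qdata"
assert flatten_key("0.weight", "scale") == "0._weight_scale"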

torchao/prototype/safetensors/safetensors_support.py

Lines changed: 29 additions & 21 deletions
@@ -24,9 +24,9 @@ def unflatten_tensor_state_dict(
 
     For example, given a previously flattened tensors_data_dict and metadata:
     tensors_data_dict = {
-        '0.weight:qdata': torch.Tensor(...),
-        '0.weight:scale': torch.Tensor(...),
-        '0.bias:_data': torch.Tensor(...),
+        '0._weight_qdata': torch.Tensor(...),
+        '0._weight_scale': torch.Tensor(...),
+        '0.bias': torch.Tensor(...),
     }
     metadata = {
         '0.weight': {
@@ -53,7 +53,7 @@ def unflatten_tensor_state_dict(
     }
 
     Args:
-        tensors_data_dict: a dictionary from "tensor_name:tensor_data_attribute_name" to flattened torch.Tensor data for tensor subclass instance
+        tensors_data_dict: a dictionary from "{tensor_name}_{tensor_data_attribute_name}" to flattened torch.Tensor data for tensor subclass instance
         metadata: a dictionary from "tensor_name" to another dictionary that contains type and attributes for tensor subclass instance
 
     Returns:
@@ -68,23 +68,30 @@ def unflatten_tensor_state_dict(
     result = {}
 
     for tensor_name in tensor_names:
+        module_fqn, weight_name = tensor_name.rsplit(".", 1)
+
+        prefix = f"{module_fqn}._{weight_name}_"
         tensor_tensors = {}
         for key, value in combined_data.items():
-            if key.startswith(f"{tensor_name}:"):
+            if key.startswith(prefix):
                 # Remove the prefix
-                tensor_tensors[key[len(tensor_name) + 1 :]] = value
+                tensor_tensors[key[len(prefix) :]] = value
 
         tensor_metadata = json.loads(metadata.get(tensor_name))
         tensor_type = tensor_metadata.get("_type")
 
         if tensor_type in ALLOWED_TENSORS_SUBCLASSES:
+            if not tensor_tensors:
+                # we allow the option of loading in state_dict info for a single tensor
+                # if tensor state dict info is not loaded in yet, we wait for it to be provided
+                # in a future call
+                continue
             tensor_metadata["_data"].update(tensor_tensors)
             result[tensor_name] = object_from_dict(tensor_metadata)
         elif tensor_type == torch.Tensor.__name__:
-            result[tensor_name] = tensor_tensors["_data"]
+            result[tensor_name] = tensors_data_dict[tensor_name]
         else:
             raise ValueError(f"Unsupported tensor type: {tensor_type}")
-
     return result
 
 
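To make the new prefix handling concrete, here is a small self-contained sketch (data values are invented placeholders, not real tensors) of how unflatten_tensor_state_dict now regroups a flattened dict per tensor:

# Illustrative only: mirrors the prefix-stripping loop above with made-up data.
combined_data = {
    "0._weight_qdata": "qdata tensor",
    "0._weight_scale": "scale tensor",
    "0.bias": "bias tensor",
}

tensor_name = "0.weight"
module_fqn, weight_name = tensor_name.rsplit(".", 1)
prefix = f"{module_fqn}._{weight_name}_"  # "0._weight_"

tensor_tensors = {
    key[len(prefix):]: value
    for key, value in combined_data.items()
    if key.startswith(prefix)
}
assert tensor_tensors == {"qdata": "qdata tensor", "scale": "scale tensor"}

# If tensor_tensors were empty, the subclass branch now `continue`s and waits
# for the tensor's data to be provided in a future call instead of failing.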

@@ -108,9 +115,9 @@ def flatten_tensor_state_dict(
 
     We flatten this to:
     tensors_data = {
-        '0.weight:qdata': torch.Tensor(...),
-        '0.weight:scale': torch.Tensor(...),
-        '0.bias:_data': torch.Tensor(...),
+        '0._weight_qdata': torch.Tensor(...),
+        '0._weight_scale': torch.Tensor(...),
+        '0.bias': torch.Tensor(...),
     }
     metadata = {
         '0.weight': {
@@ -152,22 +159,23 @@ def flatten_tensor_state_dict(
                 tensor_dict[tensor_data_name] = getattr(tensor, tensor_data_name)
 
             tensor_metadata = json.dumps(tensor, cls=TensorSubclassAttributeJSONEncoder)
+
+            # Clone tensors to avoid memory sharing issues
+            tensors_dict_to_save = {
+                f"{tensor_name.rsplit('.', 1)[0]}._{tensor_name.rsplit('.', 1)[1]}_{key}": (
+                    value.detach().clone() if isinstance(value, torch.Tensor) else value
+                )
+                for key, value in tensor_dict.items()
+            }
+
         elif type(tensor) is torch.Tensor:
-            tensor_dict = {"_data": tensor}
             tensor_metadata = json.dumps({"_type": torch.Tensor.__name__})
+            tensors_dict_to_save = {tensor_name: tensor}
         else:
             raise ValueError(f"Unsupported tensor type: {type(tensor)}")
 
-        # Clone tensors to avoid memory sharing issues
-        prefixed_tensors_dict = {
-            f"{tensor_name}:{key}": (
-                value.detach().clone() if isinstance(value, torch.Tensor) else value
-            )
-            for key, value in tensor_dict.items()
-        }
-
         metadata[tensor_name] = tensor_metadata
-        tensors_data_dict.update(prefixed_tensors_dict)
+        tensors_data_dict.update(tensors_dict_to_save)
 
     metadata["tensor_names"] = json.dumps(list(tensors_dict.keys()))
     return tensors_data_dict, metadata
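
End to end, the two functions round-trip through safetensors exactly as the test exercises them. A condensed sketch (model, the quantization config, and load_data are assumed to be set up as in test_safetensors_support.py; argument order for unflatten follows the docstring above):

import tempfile

from safetensors.torch import save_file

with tempfile.NamedTemporaryFile() as f:
    # Flatten: plain torch.Tensors keyed by "{module_fqn}._{weight_name}_{attr}"
    # (or the original name for plain tensors), plus JSON metadata.
    tensors_data_dict, metadata = flatten_tensor_state_dict(model.state_dict())
    save_file(tensors_data_dict, f.name, metadata=metadata)

    # Reload and rebuild the tensor subclass instances from data + metadata.
    tensors_data_dict, metadata = load_data(file_path=f.name, device="cuda")
    reconstructed_dict = unflatten_tensor_state_dict(tensors_data_dict, metadata)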
