Skip to content

Commit 546ec43

Browse files
authored
[ET-VK][AOT] Serialize constant tensors via NamedDataMap
Differential Revision: D80460034 Pull Request resolved: #13473
1 parent 150afe4 commit 546ec43

File tree

3 files changed

+51
-5
lines changed

3 files changed

+51
-5
lines changed

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import ctypes
8+
import hashlib
79
import logging
810
import operator
911
from types import NoneType
@@ -25,6 +27,7 @@
2527
is_symint_node,
2628
TensorRepr,
2729
)
30+
from executorch.exir._serialize._named_data_store import NamedDataStore
2831
from executorch.exir.backend.utils import DelegateMappingBuilder
2932

3033
from executorch.exir.tensor import TensorSpec
@@ -56,6 +59,7 @@ def __init__(
5659
self.input_ids = []
5760
self.output_ids = []
5861
self.const_tensors = []
62+
self.named_data_store = NamedDataStore()
5963

6064
# Mapping from Node to VkValue id
6165
self.node_to_value_ids = {}
@@ -129,8 +133,36 @@ def get_param_tensor(self, node: Node) -> torch.Tensor:
129133
def maybe_add_constant_tensor(self, node: Node) -> int:
130134
constant_id = -1
131135
if is_param_node(self.program, node):
132-
constant_id = len(self.const_tensors)
133-
self.const_tensors.append(self.get_param_tensor(node))
136+
tensor = self.get_param_tensor(node)
137+
138+
# Serialize tensor data to bytes
139+
tensor = tensor.contiguous()
140+
size = tensor.untyped_storage().nbytes()
141+
142+
if size > 0:
143+
array_type = ctypes.c_char * size
144+
array = ctypes.cast(
145+
tensor.untyped_storage().data_ptr(),
146+
ctypes.POINTER(array_type),
147+
).contents
148+
149+
# Generate SHA256 hash as the named key
150+
tensor_bytes = bytes(array)
151+
sha256_hash = hashlib.sha256(tensor_bytes)
152+
named_key = sha256_hash.hexdigest()
153+
154+
# Add to named data store with 16-byte alignment (matching XNNPACK)
155+
self.named_data_store.add_named_data(
156+
named_key, tensor_bytes, alignment=16
157+
)
158+
159+
# Create VkBytes entry with named_key and set offset to indicate named data usage
160+
constant_id = len(self.const_tensors)
161+
self.const_tensors.append((named_key, size))
162+
else:
163+
# Handle empty tensors
164+
constant_id = len(self.const_tensors)
165+
self.const_tensors.append(None)
134166

135167
return constant_id
136168

backends/vulkan/serialization/vulkan_graph_serialize.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,21 @@ def serialize_constant_tensors(
191191

192192
current_offset = len(raw_bytes)
193193
for tensor in const_tensors:
194-
if tensor.numel() == 0:
194+
# The tensor data is stored in the named data map
195+
if isinstance(tensor, tuple):
196+
named_key, size = tensor
197+
vk_graph.constants.append(
198+
VkBytes(
199+
offset=18446744073709551615, # UINT64_MAX to indicate named data
200+
length=size,
201+
named_key=named_key,
202+
)
203+
)
204+
elif tensor is None or (
205+
isinstance(tensor, torch.Tensor) and tensor.numel() == 0
206+
):
195207
vk_graph.constants.append(VkBytes(current_offset, 0))
196-
continue
197-
else:
208+
elif isinstance(tensor, torch.Tensor):
198209
array_type = ctypes.c_char * tensor.untyped_storage().nbytes()
199210
array = ctypes.cast(
200211
tensor.untyped_storage().data_ptr(),
@@ -208,6 +219,8 @@ def serialize_constant_tensors(
208219

209220
vk_graph.constants.append(VkBytes(current_offset, len(tensor_bytes)))
210221
current_offset += aligned_size(len(tensor_bytes))
222+
else:
223+
raise ValueError(f"Unsupported constant tensor type: {type(tensor)}")
211224

212225

213226
def serialize_custom_shaders(

backends/vulkan/vulkan_preprocess.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,4 +229,5 @@ def preprocess( # noqa: C901
229229
vk_graph, graph_builder.const_tensors, []
230230
),
231231
debug_handle_map=graph_builder.delegate_mapping_builder.get_delegate_mapping(),
232+
data_store_output=graph_builder.named_data_store.get_named_data_store_output(),
232233
)

0 commit comments

Comments
 (0)