|
4 | 4 | # This source code is licensed under the BSD-style license found in the
|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
| 7 | +import ctypes |
| 8 | +import hashlib |
7 | 9 | import logging
|
8 | 10 | import operator
|
9 | 11 | from types import NoneType
|
|
25 | 27 | is_symint_node,
|
26 | 28 | TensorRepr,
|
27 | 29 | )
|
| 30 | +from executorch.exir._serialize._named_data_store import NamedDataStore |
28 | 31 | from executorch.exir.backend.utils import DelegateMappingBuilder
|
29 | 32 |
|
30 | 33 | from executorch.exir.tensor import TensorSpec
|
@@ -56,6 +59,7 @@ def __init__(
|
56 | 59 | self.input_ids = []
|
57 | 60 | self.output_ids = []
|
58 | 61 | self.const_tensors = []
|
| 62 | + self.named_data_store = NamedDataStore() |
59 | 63 |
|
60 | 64 | # Mapping from Node to VkValue id
|
61 | 65 | self.node_to_value_ids = {}
|
@@ -129,8 +133,36 @@ def get_param_tensor(self, node: Node) -> torch.Tensor:
|
129 | 133 | def maybe_add_constant_tensor(self, node: Node) -> int:
|
130 | 134 | constant_id = -1
|
131 | 135 | if is_param_node(self.program, node):
|
132 |
| - constant_id = len(self.const_tensors) |
133 |
| - self.const_tensors.append(self.get_param_tensor(node)) |
| 136 | + tensor = self.get_param_tensor(node) |
| 137 | + |
| 138 | + # Serialize tensor data to bytes |
| 139 | + tensor = tensor.contiguous() |
| 140 | + size = tensor.untyped_storage().nbytes() |
| 141 | + |
| 142 | + if size > 0: |
| 143 | + array_type = ctypes.c_char * size |
| 144 | + array = ctypes.cast( |
| 145 | + tensor.untyped_storage().data_ptr(), |
| 146 | + ctypes.POINTER(array_type), |
| 147 | + ).contents |
| 148 | + |
| 149 | + # Generate SHA256 hash as the named key |
| 150 | + tensor_bytes = bytes(array) |
| 151 | + sha256_hash = hashlib.sha256(tensor_bytes) |
| 152 | + named_key = sha256_hash.hexdigest() |
| 153 | + |
| 154 | + # Add to named data store with 16-byte alignment (matching XNNPACK) |
| 155 | + self.named_data_store.add_named_data( |
| 156 | + named_key, tensor_bytes, alignment=16 |
| 157 | + ) |
| 158 | + |
| 159 | + # Create VkBytes entry with named_key and set offset to indicate named data usage |
| 160 | + constant_id = len(self.const_tensors) |
| 161 | + self.const_tensors.append((named_key, size)) |
| 162 | + else: |
| 163 | + # Handle empty tensors |
| 164 | + constant_id = len(self.const_tensors) |
| 165 | + self.const_tensors.append(None) |
134 | 166 |
|
135 | 167 | return constant_id
|
136 | 168 |
|
|
0 commit comments