2222#include < executorch/runtime/core/event_tracer_hooks_delegate.h>
2323#endif // ET_EVENT_TRACER_ENABLED
2424#include < executorch/runtime/core/exec_aten/util/tensor_util.h>
25+ #include < executorch/runtime/core/named_data_map.h>
2526#include < executorch/runtime/platform/compiler.h>
2627#include < executorch/runtime/platform/profiler.h>
2728
@@ -47,6 +48,7 @@ using executorch::runtime::Error;
4748using executorch::runtime::EValue;
4849using executorch::runtime::FreeableBuffer;
4950using executorch::runtime::kTensorDimensionLimit ;
51+ using executorch::runtime::NamedDataMap;
5052using executorch::runtime::Result;
5153using executorch::runtime::Span;
5254
@@ -66,14 +68,6 @@ using BytesVector =
6668 const flatbuffers::Vector<flatbuffers::Offset<vkgraph::VkBytes>>*;
6769using UIntVector = const flatbuffers::Vector<uint32_t >*;
6870
69- const uint8_t * get_constant_data_ptr (
70- VkGraphPtr flatbuffer_graph,
71- const int32_t buffer_idx,
72- const uint8_t * constant_data) {
73- VkBytesPtr constant_bytes = flatbuffer_graph->constants ()->Get (buffer_idx);
74- return constant_data + constant_bytes->offset ();
75- }
76-
7771vkapi::ScalarType get_scalar_type (const vkgraph::VkDataType& vk_datatype) {
7872 switch (vk_datatype) {
7973 case vkgraph::VkDataType::BOOL:
@@ -166,17 +160,22 @@ class GraphBuilder {
166160 ComputeGraph* compute_graph_;
167161 VkGraphPtr flatbuffer_;
168162 const uint8_t * constant_data_;
163+ const NamedDataMap* named_data_map_;
164+ std::vector<FreeableBuffer> loaded_buffers_from_map_;
169165
170166 std::vector<ValueRef> ref_mapping_;
171167
172168 public:
173169 explicit GraphBuilder (
174170 ComputeGraph* compute_graph,
175171 VkGraphPtr flatbuffer,
176- const uint8_t * constant_data)
172+ const uint8_t * constant_data,
173+ const NamedDataMap* named_data_map)
177174 : compute_graph_(compute_graph),
178175 flatbuffer_(flatbuffer),
179176 constant_data_(constant_data),
177+ named_data_map_(named_data_map),
178+ loaded_buffers_from_map_(),
180179 ref_mapping_() {}
181180
182181 void resize (uint32_t size) {
@@ -212,10 +211,27 @@ class GraphBuilder {
212211
213212 ValueRef ref;
214213 if (tensor_fb->constant_id () >= 0 ) {
215- const uint8_t * tensor_data = get_constant_data_ptr (
216- flatbuffer_, tensor_fb->constant_id (), constant_data_ );
214+ VkBytesPtr constant_bytes =
215+ flatbuffer_-> constants ()-> Get ( tensor_fb->constant_id ());
217216
218- ref = compute_graph_->add_tensorref (dims_vector, dtype, tensor_data);
217+ if (constant_bytes->named_key () != nullptr &&
218+ constant_bytes->offset () == UINT64_MAX &&
219+ named_data_map_ != nullptr ) {
220+ const std::string& data_name = constant_bytes->named_key ()->str ();
221+ Result<FreeableBuffer> buffer =
222+ named_data_map_->get_data (data_name.c_str ());
223+
224+ VK_CHECK_COND (
225+ buffer.ok (),
226+ " Failed to get constant data for key %s from named_data_map. Error code: %u" ,
227+ data_name.c_str (),
228+ static_cast <uint32_t >(buffer.error ()));
229+ ref = compute_graph_->add_tensorref (
230+ dims_vector, dtype, std::move (buffer.get ()));
231+ } else {
232+ const uint8_t * tensor_data = constant_data_ + constant_bytes->offset ();
233+ ref = compute_graph_->add_tensorref (dims_vector, dtype, tensor_data);
234+ }
219235 } else {
220236 ref = compute_graph_->add_tensor (
221237 dims_vector,
@@ -479,8 +495,10 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
479495 return true ;
480496 }
481497
482- ET_NODISCARD Error
483- compileModel (const void * buffer_pointer, ComputeGraph* compute_graph) const {
498+ ET_NODISCARD Error compileModel (
499+ const void * buffer_pointer,
500+ ComputeGraph* compute_graph,
501+ const NamedDataMap* named_data_map) const {
484502 Result<VulkanDelegateHeader> header =
485503 VulkanDelegateHeader::parse (buffer_pointer);
486504
@@ -506,7 +524,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
506524
507525 VkGraphPtr flatbuffer_graph = vkgraph::GetVkGraph (flatbuffer_data);
508526
509- GraphBuilder builder (compute_graph, flatbuffer_graph, constant_data);
527+ GraphBuilder builder (
528+ compute_graph, flatbuffer_graph, constant_data, named_data_map);
510529
511530 builder.build_graph ();
512531
@@ -532,7 +551,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
532551 graph_config.external_adapter = vkapi::set_and_get_external_adapter ();
533552 new (compute_graph) ComputeGraph (graph_config);
534553
535- Error err = compileModel (processed->data (), compute_graph);
554+ const NamedDataMap* named_data_map = context.get_named_data_map ();
555+ Error err = compileModel (processed->data (), compute_graph, named_data_map);
536556
537557 // This backend does not need its processed data after compiling the
538558 // model.
0 commit comments