@@ -11,7 +11,9 @@
 #include <executorch/extension/data_loader/file_data_loader.h>
 #include <executorch/extension/data_loader/mmap_data_loader.h>
 #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
+#include <executorch/runtime/core/hierarchical_allocator.h>
 #include <executorch/runtime/platform/runtime.h>
+#include <memory>
 
 /**
  * Unwrap a Result to obtain its value (direct object, not a pointer).

@@ -125,33 +127,42 @@ runtime::Result<std::unordered_set<std::string>> Module::method_names() {
 
 runtime::Error Module::load_method(
     const std::string& method_name,
-    torch::executor::EventTracer* event_tracer) {
+    torch::executor::EventTracer* event_tracer,
+    runtime::HierarchicalAllocator* planned_memory_allocator) {
   if (!is_method_loaded(method_name)) {
     ET_CHECK_OK_OR_RETURN_ERROR(load());
 
    MethodHolder method_holder;
-    const auto method_metadata =
-        ET_UNWRAP(program_->method_meta(method_name.c_str()));
-    const auto planned_buffersCount =
-        method_metadata.num_memory_planned_buffers();
-    method_holder.planned_buffers.reserve(planned_buffersCount);
-    method_holder.planned_spans.reserve(planned_buffersCount);
-
-    for (auto index = 0; index < planned_buffersCount; ++index) {
-      const auto buffer_size =
-          method_metadata.memory_planned_buffer_size(index).get();
-      method_holder.planned_buffers.emplace_back(buffer_size);
-      method_holder.planned_spans.emplace_back(
-          method_holder.planned_buffers.back().data(), buffer_size);
+    runtime::HierarchicalAllocator* planned_memory = nullptr;
+
+    // We were not given a planned memory allocator, so we need to create one:
+    if (planned_memory_allocator == nullptr) {
+      const auto method_metadata =
+          ET_UNWRAP(program_->method_meta(method_name.c_str()));
+      const auto planned_buffersCount =
+          method_metadata.num_memory_planned_buffers();
+      method_holder.planned_buffers.reserve(planned_buffersCount);
+      method_holder.planned_spans.reserve(planned_buffersCount);
+
+      for (auto index = 0; index < planned_buffersCount; ++index) {
+        const auto buffer_size =
+            method_metadata.memory_planned_buffer_size(index).get();
+        method_holder.planned_buffers.emplace_back(buffer_size);
+        method_holder.planned_spans.emplace_back(
+            method_holder.planned_buffers.back().data(), buffer_size);
+      }
+      method_holder.planned_memory =
+          std::make_unique<runtime::HierarchicalAllocator>(runtime::Span(
+              method_holder.planned_spans.data(),
+              method_holder.planned_spans.size()));
+      planned_memory = method_holder.planned_memory.get();
+    } else {
+      // We were given a planned memory allocator, so we use it:
+      planned_memory = planned_memory_allocator;
     }
-    method_holder.planned_memory =
-        std::make_unique<runtime::HierarchicalAllocator>(runtime::Span(
-            method_holder.planned_spans.data(),
-            method_holder.planned_spans.size()));
+
     method_holder.memory_manager = std::make_unique<runtime::MemoryManager>(
-        memory_allocator_.get(),
-        method_holder.planned_memory.get(),
-        temp_allocator_.get());
+        memory_allocator_.get(), planned_memory, temp_allocator_.get());
     method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
         method_name.c_str(),
         method_holder.memory_manager.get(),
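For reference, a minimal sketch of how a caller might use the new parameter to supply its own planned memory instead of letting load_method heap-allocate the planned buffers. The exact namespaces (executorch::extension::Module, executorch::runtime::...), the model path, and the buffer sizes below are illustrative assumptions, not taken from this diff; the buffers must be at least as large as the memory-planned buffer sizes reported by the method's metadata.

#include <executorch/extension/module/module.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/hierarchical_allocator.h>
#include <executorch/runtime/core/span.h>

#include <array>
#include <cstdint>

int main() {
  using executorch::extension::Module;
  using executorch::runtime::HierarchicalAllocator;
  using executorch::runtime::Span;

  // Caller-owned planned buffers (sizes are hypothetical; they must cover the
  // memory-planned buffer sizes recorded in the program's method metadata).
  static std::array<uint8_t, 1024 * 1024> buffer0;
  static std::array<uint8_t, 512 * 1024> buffer1;
  std::array<Span<uint8_t>, 2> spans = {
      Span<uint8_t>(buffer0.data(), buffer0.size()),
      Span<uint8_t>(buffer1.data(), buffer1.size())};
  HierarchicalAllocator planned_memory(
      Span<Span<uint8_t>>(spans.data(), spans.size()));

  Module module("/path/to/model.pte");

  // With the new parameter, the module uses the caller-provided allocator and
  // skips allocating its own planned buffers for this method.
  const auto error = module.load_method(
      "forward", /*event_tracer=*/nullptr, &planned_memory);
  return error == executorch::runtime::Error::Ok ? 0 : 1;
}

This keeps the planned memory in storage the caller controls (here static arrays), which is the usual reason to bypass the default heap-backed path created inside load_method.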