Skip to content

Commit 864f2cd

Browse files
Enhance load_method to support an optional planned memory allocator
- Updated the load_method signature to accept an optional runtime::HierarchicalAllocator* parameter (defaulting to nullptr). When the caller passes an allocator, it is used directly as the method's planned memory; when it is nullptr, load_method preserves the previous behavior and allocates the planned buffers itself from the method metadata.
1 parent d99970b commit 864f2cd

File tree

2 files changed

+34
-22
lines changed

2 files changed

+34
-22
lines changed

extension/module/module.cpp

+32-21
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
#include <executorch/extension/data_loader/file_data_loader.h>
1212
#include <executorch/extension/data_loader/mmap_data_loader.h>
1313
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
14+
#include <executorch/runtime/core/hierarchical_allocator.h>
1415
#include <executorch/runtime/platform/runtime.h>
16+
#include <memory>
1517

1618
/**
1719
* Unwrap a Result to obtain its value (direct object, not a pointer).
@@ -125,33 +127,42 @@ runtime::Result<std::unordered_set<std::string>> Module::method_names() {
125127

126128
runtime::Error Module::load_method(
127129
const std::string& method_name,
128-
torch::executor::EventTracer* event_tracer) {
130+
torch::executor::EventTracer* event_tracer,
131+
runtime::HierarchicalAllocator* planned_memory_allocator) {
129132
if (!is_method_loaded(method_name)) {
130133
ET_CHECK_OK_OR_RETURN_ERROR(load());
131134

132135
MethodHolder method_holder;
133-
const auto method_metadata =
134-
ET_UNWRAP(program_->method_meta(method_name.c_str()));
135-
const auto planned_buffersCount =
136-
method_metadata.num_memory_planned_buffers();
137-
method_holder.planned_buffers.reserve(planned_buffersCount);
138-
method_holder.planned_spans.reserve(planned_buffersCount);
139-
140-
for (auto index = 0; index < planned_buffersCount; ++index) {
141-
const auto buffer_size =
142-
method_metadata.memory_planned_buffer_size(index).get();
143-
method_holder.planned_buffers.emplace_back(buffer_size);
144-
method_holder.planned_spans.emplace_back(
145-
method_holder.planned_buffers.back().data(), buffer_size);
136+
runtime::HierarchicalAllocator* planned_memory = nullptr;
137+
138+
// we were not given a planned memory allocator, so we need to create one:
139+
if (planned_memory_allocator == nullptr) {
140+
const auto method_metadata =
141+
ET_UNWRAP(program_->method_meta(method_name.c_str()));
142+
const auto planned_buffersCount =
143+
method_metadata.num_memory_planned_buffers();
144+
method_holder.planned_buffers.reserve(planned_buffersCount);
145+
method_holder.planned_spans.reserve(planned_buffersCount);
146+
147+
for (auto index = 0; index < planned_buffersCount; ++index) {
148+
const auto buffer_size =
149+
method_metadata.memory_planned_buffer_size(index).get();
150+
method_holder.planned_buffers.emplace_back(buffer_size);
151+
method_holder.planned_spans.emplace_back(
152+
method_holder.planned_buffers.back().data(), buffer_size);
153+
}
154+
method_holder.planned_memory =
155+
std::make_unique<runtime::HierarchicalAllocator>(runtime::Span(
156+
method_holder.planned_spans.data(),
157+
method_holder.planned_spans.size()));
158+
planned_memory = method_holder.planned_memory.get();
159+
} else {
160+
// we were given a planned memory allocator, so we use it:
161+
planned_memory = planned_memory_allocator;
146162
}
147-
method_holder.planned_memory =
148-
std::make_unique<runtime::HierarchicalAllocator>(runtime::Span(
149-
method_holder.planned_spans.data(),
150-
method_holder.planned_spans.size()));
163+
151164
method_holder.memory_manager = std::make_unique<runtime::MemoryManager>(
152-
memory_allocator_.get(),
153-
method_holder.planned_memory.get(),
154-
temp_allocator_.get());
165+
memory_allocator_.get(), planned_memory, temp_allocator_.get());
155166
method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
156167
method_name.c_str(),
157168
method_holder.memory_manager.get(),

extension/module/module.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ class Module {
143143
ET_NODISCARD
144144
runtime::Error load_method(
145145
const std::string& method_name,
146-
torch::executor::EventTracer* event_tracer = nullptr);
146+
torch::executor::EventTracer* event_tracer = nullptr,
147+
runtime::HierarchicalAllocator* planned_memory_allocator = nullptr);
147148

148149
/**
149150
* Load the 'forward' method from the program and set up memory management if

0 commit comments

Comments (0)