Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 83 additions & 31 deletions extension/module/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ Module::Module(
const LoadMode load_mode,
std::unique_ptr<runtime::EventTracer> event_tracer,
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
std::unique_ptr<runtime::MemoryAllocator> temp_allocator)
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
bool share_memory_arenas)
: file_path_(file_path),
load_mode_(load_mode),
memory_allocator_(
Expand All @@ -89,7 +90,8 @@ Module::Module(
temp_allocator_(
temp_allocator ? std::move(temp_allocator)
: std::make_unique<MallocMemoryAllocator>()),
event_tracer_(std::move(event_tracer)) {
event_tracer_(std::move(event_tracer)),
share_memory_arenas_(share_memory_arenas) {
runtime::runtime_init();
}

Expand All @@ -99,7 +101,8 @@ Module::Module(
const LoadMode load_mode,
std::unique_ptr<runtime::EventTracer> event_tracer,
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
std::unique_ptr<runtime::MemoryAllocator> temp_allocator)
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
bool share_memory_arenas)
: file_path_(file_path),
load_mode_(load_mode),
memory_allocator_(
Expand All @@ -108,7 +111,8 @@ Module::Module(
temp_allocator_(
temp_allocator ? std::move(temp_allocator)
: std::make_unique<MallocMemoryAllocator>()),
event_tracer_(std::move(event_tracer)) {
event_tracer_(std::move(event_tracer)),
share_memory_arenas_(share_memory_arenas) {
if (!data_map_path.empty()) {
data_files_.push_back(data_map_path);
}
Expand All @@ -121,7 +125,8 @@ Module::Module(
const LoadMode load_mode,
std::unique_ptr<runtime::EventTracer> event_tracer,
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
std::unique_ptr<runtime::MemoryAllocator> temp_allocator)
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
bool share_memory_arenas)
: file_path_(file_path),
data_files_(std::move(data_files)),
load_mode_(load_mode),
Expand All @@ -131,7 +136,8 @@ Module::Module(
temp_allocator_(
temp_allocator ? std::move(temp_allocator)
: std::make_unique<MallocMemoryAllocator>()),
event_tracer_(std::move(event_tracer)) {
event_tracer_(std::move(event_tracer)),
share_memory_arenas_(share_memory_arenas) {
runtime::runtime_init();
}

Expand All @@ -140,15 +146,17 @@ Module::Module(
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
std::unique_ptr<runtime::EventTracer> event_tracer,
std::unique_ptr<runtime::DataLoader> data_map_loader)
std::unique_ptr<runtime::DataLoader> data_map_loader,
bool share_memory_arenas)
: data_loader_(std::move(data_loader)),
memory_allocator_(
memory_allocator ? std::move(memory_allocator)
: std::make_unique<MallocMemoryAllocator>()),
temp_allocator_(
temp_allocator ? std::move(temp_allocator)
: std::make_unique<MallocMemoryAllocator>()),
event_tracer_(std::move(event_tracer)) {
event_tracer_(std::move(event_tracer)),
share_memory_arenas_(share_memory_arenas) {
if (data_map_loader) {
data_map_loaders_.push_back(std::move(data_map_loader));
}
Expand All @@ -160,15 +168,17 @@ Module::Module(
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
std::unique_ptr<runtime::EventTracer> event_tracer,
std::unique_ptr<runtime::DataLoader> data_map_loader)
std::unique_ptr<runtime::DataLoader> data_map_loader,
bool share_memory_arenas)
: program_(std::move(program)),
memory_allocator_(
memory_allocator ? std::move(memory_allocator)
: std::make_unique<MallocMemoryAllocator>()),
temp_allocator_(
temp_allocator ? std::move(temp_allocator)
: std::make_unique<MallocMemoryAllocator>()),
event_tracer_(std::move(event_tracer)) {
event_tracer_(std::move(event_tracer)),
share_memory_arenas_(share_memory_arenas) {
if (data_map_loader) {
data_map_loaders_.push_back(std::move(data_map_loader));
}
Expand Down Expand Up @@ -253,6 +263,56 @@ runtime::Result<std::unordered_set<std::string>> Module::method_names() {
return result;
}

std::shared_ptr<Module::PlannedMemory> Module::make_planned_memory(
const std::vector<size_t>& buffer_sizes) {
auto planned = std::make_shared<PlannedMemory>();
planned->planned_buffers.reserve(buffer_sizes.size());
planned->planned_spans.reserve(buffer_sizes.size());
for (size_t size : buffer_sizes) {
planned->planned_buffers.emplace_back(size);
planned->planned_spans.emplace_back(
planned->planned_buffers.back().data(), size);
}
planned->planned_memory =
std::make_unique<runtime::HierarchicalAllocator>(runtime::Span(
planned->planned_spans.data(), planned->planned_spans.size()));
return planned;
}

/// Returns the memory-planned buffer sizes declared by the named method's
/// metadata, in buffer-index order. Propagates any error from metadata
/// retrieval or per-buffer size lookup.
runtime::Result<std::vector<size_t>> Module::get_mem_planned_buffer_sizes(
    const std::string& method_name) {
  auto meta_res = program_->method_meta(method_name.c_str());
  ET_CHECK_OK_OR_RETURN_ERROR(meta_res.error());
  const auto meta = meta_res.get();
  // Hoist the loop-invariant count instead of re-querying it every iteration.
  const size_t buffer_count = meta.num_memory_planned_buffers();
  std::vector<size_t> sizes;
  sizes.reserve(buffer_count);
  for (size_t i = 0; i < buffer_count; ++i) {
    auto size = meta.memory_planned_buffer_size(i);
    ET_CHECK_OK_OR_RETURN_ERROR(size.error());
    sizes.push_back(size.get());
  }
  return sizes;
}

/// Computes the element-wise maximum of planned-buffer sizes across all
/// methods in the program, so one shared arena set is large enough for any
/// of them. Propagates any error from method enumeration or size lookup.
runtime::Result<std::vector<size_t>>
Module::get_max_mem_planned_buffer_sizes() {
  auto names_res = method_names();
  ET_CHECK_OK_OR_RETURN_ERROR(names_res.error());
  std::vector<size_t> max_sizes;
  for (const auto& method_name : names_res.get()) {
    auto sizes_res = get_mem_planned_buffer_sizes(method_name);
    ET_CHECK_OK_OR_RETURN_ERROR(sizes_res.error());
    const auto& method_sizes = sizes_res.get();
    // Grow the accumulator if this method plans more buffers than seen so far.
    if (method_sizes.size() > max_sizes.size()) {
      max_sizes.resize(method_sizes.size(), 0);
    }
    for (size_t i = 0; i < method_sizes.size(); ++i) {
      if (method_sizes[i] > max_sizes[i]) {
        max_sizes[i] = method_sizes[i];
      }
    }
  }
  return max_sizes;
}

runtime::Error Module::load_method(
const std::string& method_name,
runtime::HierarchicalAllocator* planned_memory,
Expand All @@ -263,29 +323,21 @@ runtime::Error Module::load_method(
MethodHolder method_holder;

if (!planned_memory) {
auto method_metadata_result = program_->method_meta(method_name.c_str());
if (!method_metadata_result.ok()) {
return method_metadata_result.error();
}
const auto method_metadata = std::move(*method_metadata_result);
const auto planned_buffers_count =
method_metadata.num_memory_planned_buffers();
method_holder.planned_buffers.reserve(planned_buffers_count);
method_holder.planned_spans.reserve(planned_buffers_count);

for (auto index = 0; index < planned_buffers_count; ++index) {
const auto buffer_size =
method_metadata.memory_planned_buffer_size(index).get();
method_holder.planned_buffers.emplace_back(buffer_size);
method_holder.planned_spans.emplace_back(
method_holder.planned_buffers.back().data(), buffer_size);
if (!share_memory_arenas_) {
auto sizes_res = get_mem_planned_buffer_sizes(method_name);
ET_CHECK_OK_OR_RETURN_ERROR(sizes_res.error());
method_holder.planned_memory = make_planned_memory(sizes_res.get());
} else {
if (!shared_planned_memory_) {
auto max_res = get_max_mem_planned_buffer_sizes();
ET_CHECK_OK_OR_RETURN_ERROR(max_res.error());
shared_planned_memory_ = make_planned_memory(max_res.get());
}
method_holder.planned_memory = shared_planned_memory_;
Comment on lines +331 to +336
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does mean that methods cannot be invoked in parallel as they may overwrite each others' arena. Document that explicitly

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also if AOT memory planning accounted for it then should we not assert that the planned memory has the same mem_id or something for them to be shareable?

Copy link
Contributor Author

@lucylq lucylq Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kimishpatel

This does mean that methods cannot be invoked in parallel as they may overwrite each others' arena. Document that explicitly

Yeah that's right. This isn't a currently supported feature (unsafe to run multiple methods in separate threads within a Module, even without sharing), but sharing makes it harder to support. I'll add an extra comment in module.h

Also if AOT memory planning accounted for it then should we not assert that the planned memory has the same mem_id or something for them to be shareable?

Initially I was thinking we only share physical buffers when mem_id=2 (buffer marked as shareable AoT).

However, @JacobSzwejbka's original diff D82329513, shares both and I don't see any issues with that besides output tensors needing to be copied into permanent memory after running each method, which I think is OK (users have to do this when running the same method twice already).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i was just wondering if we should enforce that only buffers that are on the same mem_id can share memory

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I was wondering that too. It seems fine to share all of it, and we get some memory savings as well, wdyt?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think mem_id has specific meaning so I think asserting for that is good.

}
method_holder.planned_memory =
std::make_unique<runtime::HierarchicalAllocator>(runtime::Span(
method_holder.planned_spans.data(),
method_holder.planned_spans.size()));
planned_memory = method_holder.planned_memory.get();
planned_memory = method_holder.planned_memory->planned_memory.get();
}

method_holder.memory_manager = std::make_unique<runtime::MemoryManager>(
memory_allocator_.get(), planned_memory, temp_allocator_.get());
auto res_method = program_->load_method(
Expand Down
46 changes: 40 additions & 6 deletions extension/module/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,24 @@ class Module {
* @param[in] file_path The path to the ExecuTorch program file to load.
* @param[in] load_mode The loading mode to use.
* @param[in] event_tracer A EventTracer used for tracking and logging events.
* @param[in] share_memory_arenas When true, all methods loaded by this Module
* share a single set of memory-planned buffers, sized to the max across all
* methods. This is required for models exported with
* share_mutable_buffers=True, where methods access shared mutable state
* (e.g., KV cache). When enabled, outputs from one method may be invalidated
* by executing another method, since their output tensors can alias the same
* underlying buffer. Consume or copy outputs before calling execute again.
 * NOTE: It is unsafe to execute methods in parallel when true, since methods
 * may write to the same memory. Parallel execution is likely unsafe even when
 * false, as Method makes no thread-safety guarantees.
*/
explicit Module(
const std::string& file_path,
const LoadMode load_mode = LoadMode::File,
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr);
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
bool share_memory_arenas = false);

/**
* Constructs an instance by loading a program from a file with specified
Expand All @@ -75,14 +86,17 @@ class Module {
* @param[in] data_map_path The path to a .ptd file.
* @param[in] load_mode The loading mode to use.
* @param[in] event_tracer A EventTracer used for tracking and logging events.
* @param[in] share_memory_arenas When true, all methods loaded by this Module
* share a single set of memory-planned buffers.
*/
explicit Module(
const std::string& file_path,
const std::string& data_map_path,
const LoadMode load_mode = LoadMode::File,
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr);
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
bool share_memory_arenas = false);

/**
* Constructs an instance by loading a program from a file with specified
Expand All @@ -92,14 +106,17 @@ class Module {
* @param[in] data_files The path to one or more .ptd file/s.
* @param[in] load_mode The loading mode to use.
* @param[in] event_tracer A EventTracer used for tracking and logging events.
* @param[in] share_memory_arenas When true, all methods loaded by this Module
* share a single set of memory-planned buffers.
*/
explicit Module(
const std::string& file_path,
std::vector<std::string> data_files,
const LoadMode load_mode = LoadMode::File,
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr);
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
bool share_memory_arenas = false);

/**
* Constructs an instance with the provided data loader and memory allocator.
Expand All @@ -110,13 +127,16 @@ class Module {
* temporary data during kernel or delegate execution.
* @param[in] event_tracer A EventTracer used for tracking and logging events.
* @param[in] data_map_loader A DataLoader used for loading external weights.
* @param[in] share_memory_arenas When true, all methods loaded by this Module
* share a single set of memory-planned buffers.
*/
explicit Module(
std::unique_ptr<runtime::DataLoader> data_loader,
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
std::unique_ptr<runtime::DataLoader> data_map_loader = nullptr);
std::unique_ptr<runtime::DataLoader> data_map_loader = nullptr,
bool share_memory_arenas = false);

/**
* Constructs an instance using an existing shared program.
Expand All @@ -128,13 +148,16 @@ class Module {
* temporary data.
* @param[in] event_tracer A EventTracer used for tracking and logging events.
* @param[in] data_map_loader A DataLoader used for loading external weights.
* @param[in] share_memory_arenas When true, all methods loaded by this Module
* share a single set of memory-planned buffers.
*/
explicit Module(
std::shared_ptr<Program> program,
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
std::unique_ptr<runtime::DataLoader> data_map_loader = nullptr);
std::unique_ptr<runtime::DataLoader> data_map_loader = nullptr,
bool share_memory_arenas = false);

Module(const Module&) = delete;
Module& operator=(const Module&) = delete;
Expand Down Expand Up @@ -630,10 +653,19 @@ class Module {
}

private:
struct MethodHolder {
struct PlannedMemory {
std::vector<std::vector<uint8_t>> planned_buffers;
std::vector<runtime::Span<uint8_t>> planned_spans;
std::unique_ptr<runtime::HierarchicalAllocator> planned_memory;
};
std::shared_ptr<PlannedMemory> make_planned_memory(
const std::vector<size_t>& buffer_sizes);
runtime::Result<std::vector<size_t>> get_mem_planned_buffer_sizes(
const std::string& method_name);
runtime::Result<std::vector<size_t>> get_max_mem_planned_buffer_sizes();

struct MethodHolder {
std::shared_ptr<PlannedMemory> planned_memory;
std::unique_ptr<runtime::MemoryManager> memory_manager;
std::unique_ptr<Method> method;
};
Expand All @@ -649,7 +681,9 @@ class Module {
std::vector<std::unique_ptr<runtime::DataLoader>> data_map_loaders_;
std::vector<std::unique_ptr<NamedDataMap>> named_data_maps_;
std::unique_ptr<NamedDataMap> merged_data_map_;
std::shared_ptr<PlannedMemory> shared_planned_memory_;
ET_DEPRECATED std::vector<uint8_t> debug_buffer_;
bool share_memory_arenas_;

protected:
std::unordered_map<std::string, MethodHolder> methods_;
Expand Down
5 changes: 4 additions & 1 deletion extension/module/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ add_custom_command(
"${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd"
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte"
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
"${CMAKE_CURRENT_BINARY_DIR}/ModuleSharedState.pte"
COMMAND ${PYTHON_EXECUTABLE} -m test.models.export_program --modules
"ModuleAdd" --outdir "${CMAKE_CURRENT_BINARY_DIR}"
"ModuleAdd,ModuleSharedState" --outdir "${CMAKE_CURRENT_BINARY_DIR}"
COMMAND
${PYTHON_EXECUTABLE} -m test.models.export_program --modules
"ModuleAddMul,ModuleLinear" --external-constants --outdir
Expand All @@ -41,6 +42,7 @@ add_custom_target(
"${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd"
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte"
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
"${CMAKE_CURRENT_BINARY_DIR}/ModuleSharedState.pte"
)

set(test_env
Expand All @@ -49,6 +51,7 @@ set(test_env
"ET_MODULE_ADD_MUL_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd"
"ET_MODULE_LINEAR_PROGRAM_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte"
"ET_MODULE_LINEAR_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
"ET_MODULE_SHARED_STATE=${CMAKE_CURRENT_BINARY_DIR}/ModuleSharedState.pte"
)

et_cxx_test(
Expand Down
Loading
Loading