Skip to content

Pass one NDM to backend init #10669

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 78 additions & 5 deletions extension/flat_tensor/flat_tensor_data_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,28 @@ bool is_aligned(const void* data) {
return addr % kMinimumAlignment == 0;
}

/// Finds the NamedData entry whose key matches `key`.
///
/// @param key NUL-terminated name to look up.
/// @param named_data Flatbuffer vector of NamedData entries; may be null
///     when the PTD file contains no named data.
/// @returns Pointer to the matching entry, Error::NotFound if absent or if
///     `named_data` is null, or Error::InvalidExternalData if the entry
///     carries a negative segment_index (malformed file).
Result<const flat_tensor_flatbuffer::NamedData*> get_named_data(
    const char* key,
    const flatbuffers::Vector<
        flatbuffers::Offset<flat_tensor_flatbuffer::NamedData>>* named_data) {
  // Linear search by name; the flatbuffer vector is not sorted by key.
  if (named_data == nullptr) {
    return Error::NotFound;
  }
  // Vector::size() returns an unsigned flatbuffers::uoffset_t; use the same
  // type for the index to avoid a signed/unsigned comparison.
  for (flatbuffers::uoffset_t i = 0; i < named_data->size(); i++) {
    // Fetch the entry once and reuse it for both the key compare and the
    // validity check.
    const auto* metadata = named_data->Get(i);
    if (std::strcmp(metadata->key()->c_str(), key) == 0) {
      ET_CHECK_OR_RETURN_ERROR(
          metadata->segment_index() >= 0,
          InvalidExternalData,
          "Invalid segment_index %d; malformed PTD file.",
          metadata->segment_index());
      return metadata;
    }
  }
  return Error::NotFound;
}

Result<const flat_tensor_flatbuffer::TensorMetadata*> get_flat_tensor_metadata(
const char* key,
const flatbuffers::Vector<
Expand Down Expand Up @@ -109,6 +131,39 @@ ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata(

ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data(
const char* key) const {
// TODO(lfq): consolidate named_data and tensors.
// Check named data.
Result<const flat_tensor_flatbuffer::NamedData*> named_data =
get_named_data(key, flat_tensor_->named_data());
if (named_data.ok()) {
size_t segment_index = named_data.get()->segment_index();
ET_CHECK_OR_RETURN_ERROR(
segment_index < flat_tensor_->segments()->size(),
InvalidExternalData,
"Invalid segment_index %zu; malformed PTD file.",
segment_index);

size_t segment_offset =
flat_tensor_->segments()->Get(segment_index)->offset();
size_t segment_size = flat_tensor_->segments()->Get(segment_index)->size();
ET_CHECK_OR_RETURN_ERROR(
segment_offset <
header_.segment_base_offset + header_.segment_data_size,
InvalidExternalData,
"Invalid segment offset %zu is larger than the segment_base_offset + segment_data_size %" PRIu64
"; malformed PTD file.",
segment_offset,
header_.segment_base_offset + header_.segment_data_size);
return loader_->load(
/*offset=*/header_.segment_base_offset + segment_offset,
segment_size,
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
}
if (named_data.error() != Error::NotFound) {
return named_data.error();
}

// Check tensors, if named data is not found.
Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata =
get_flat_tensor_metadata(key, flat_tensor_->tensors());
if (!metadata.ok()) {
Expand Down Expand Up @@ -179,16 +234,34 @@ ET_NODISCARD Error FlatTensorDataMap::load_data_into(
}

/// Returns the total number of keys in this data map: all named_data
/// entries plus all tensor entries.
///
/// Fix: the stale unconditional `return flat_tensor_->tensors()->size();`
/// at the top of the function (leftover removed-diff line) made the
/// named_data accounting below unreachable; it has been removed.
ET_NODISCARD Result<size_t> FlatTensorDataMap::get_num_keys() const {
  // TODO(lfq): consolidate named_data and tensors.
  if (flat_tensor_->named_data() == nullptr) {
    return flat_tensor_->tensors()->size();
  }
  return flat_tensor_->named_data()->size() + flat_tensor_->tensors()->size();
}

/// Returns the key at `index`.
///
/// Keys are ordered with all named_data entries first, followed by all
/// tensor entries, matching get_num_keys().
///
/// Fixes: removed two stale leftover lines from the old implementation
/// (an early bounds check against only tensors()->size() and an
/// unconditional return of the tensor name) that made the named_data
/// lookup unreachable; dropped the `index >= 0` term, which is always
/// true for an unsigned size_t.
///
/// @param index Zero-based key index in [0, get_num_keys()).
/// @returns The key string, or Error::InvalidArgument if out of range.
ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
    size_t index) const {
  // TODO(lfq): consolidate named_data and tensors.
  // For now, iterate over named_data and then flat_tensor.
  size_t num_keys = get_num_keys().get();
  ET_CHECK_OR_RETURN_ERROR(
      index < num_keys,
      InvalidArgument,
      "Index %zu out of range of size %zu",
      index,
      num_keys);

  const auto* named_data = flat_tensor_->named_data();
  if (named_data != nullptr && index < named_data->size()) {
    return named_data->Get(index)->key()->c_str();
  }
  // Index falls in the tensors range; shift it past the named_data entries.
  if (named_data != nullptr) {
    index -= named_data->size();
  }
  return flat_tensor_->tensors()->Get(index)->fully_qualified_name()->c_str();
}

/* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load(
Expand Down
11 changes: 10 additions & 1 deletion runtime/executor/method.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,15 @@ Error Method::init(
pte_data_map = pte_data_map_res.get();
}

ET_CHECK_OR_RETURN_ERROR(
!(pte_data_map && named_data_map),
NotSupported,
"NamedDataMap merge not supported; both pte_data_map and named_data_map are non-empty. If you see this error please file an issue at https://github.com/pytorch/executorch/issues");

if (!named_data_map || named_data_map->get_num_keys().get() == 0) {
named_data_map = pte_data_map;
}

// n_delegate_ counts the number of successfully-initialized delegates for
// ~Method() to clean up, and is incremented at the bottom of the loop. This
// makes it safe for errors to return without updating any state.
Expand All @@ -816,7 +825,7 @@ Error Method::init(
method_allocator,
/*event_tracer=*/event_tracer_,
/*method_name=*/serialization_plan_->name()->c_str(),
/*named_data_map=*/pte_data_map);
/*named_data_map=*/named_data_map);
Error err = BackendDelegate::Init(
delegate, program_, backend_init_context, &delegates_[i]);
if (err != Error::Ok) {
Expand Down
Loading