Skip to content

[executorch][flat_tensor] implement load into and don't hold onto the segment #8650

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 104 additions & 107 deletions extension/flat_tensor/flat_tensor_data_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,14 @@ Result<const flat_tensor_flatbuffer::TensorMetadata*> get_flat_tensor_metadata(
for (int i = 0; i < tensors->size(); i++) {
if (std::strcmp(tensors->Get(i)->fully_qualified_name()->c_str(), key) ==
0) {
// TODO(T214294528): Support multiple segments in FlatTensor.
if (tensors->Get(i)->segment_index() != 0) {
return Error::InvalidExternalData;
}
return tensors->Get(i);
const auto* metadata = tensors->Get(i);
ET_CHECK_OR_RETURN_ERROR(
metadata->segment_index() >= 0 && metadata->offset() >= 0,
InvalidExternalData,
"Invalid segment_index %d or offset %" PRIu64 "; malformed PTD file.",
metadata->segment_index(),
metadata->offset());
return metadata;
}
}
return Error::NotFound;
Expand All @@ -75,6 +78,23 @@ Result<const TensorLayout> create_tensor_layout(
scalar_type);
}

/**
 * Resolves the data segment referenced by `metadata` and returns its offset.
 *
 * @param[in] segments Segment table of the flat_tensor flatbuffer.
 * @param[in] metadata Tensor metadata whose segment_index() selects the
 *     segment.
 *
 * @return The selected segment's offset(), or Error::InvalidExternalData when
 *     `segments` is null or the index is not below segments->size().
 */
Result<int> get_and_check_segment_offset(
    const flatbuffers::Vector<
        flatbuffers::Offset<flat_tensor_flatbuffer::DataSegment>>* segments,
    const flat_tensor_flatbuffer::TensorMetadata* metadata) {
  // A file without a segment table cannot hold any tensor data.
  ET_CHECK_OR_RETURN_ERROR(
      segments != nullptr,
      InvalidExternalData,
      "No segments in external data flatbuffer.");

  // The index comes from the serialized file, so bound it before using it.
  const auto segment_index = metadata->segment_index();
  ET_CHECK_OR_RETURN_ERROR(
      segment_index < segments->size(),
      InvalidExternalData,
      "Invalid segment_index %d; malformed PTD file.",
      segment_index);

  // NOTE(review): offset() is implicitly converted to the int held by the
  // Result; confirm that segment offsets always fit.
  const auto* segment = segments->Get(segment_index);
  return segment->offset();
}

} // namespace

ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata(
Expand All @@ -89,39 +109,73 @@ ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata(

/**
 * Loads the read-only data of the tensor named `key` through the data loader.
 *
 * The bytes are read on demand starting at
 * segment_base_offset + <segment offset> + <tensor offset> and span the
 * tensor layout's nbytes(); nothing is cached in the data map itself.
 *
 * @param[in] key Name of the tensor whose data is requested.
 *
 * @return The buffer produced by the data loader, or the first error hit
 *     while resolving the tensor metadata, its layout or its segment offset.
 */
ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data(
    const char* key) const {
  // Locate the tensor's metadata entry by name.
  Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata =
      get_flat_tensor_metadata(key, flat_tensor_->tensors());
  if (!metadata.ok()) {
    return metadata.error();
  }

  // The layout supplies the number of bytes to load.
  Result<const TensorLayout> tensor_layout =
      create_tensor_layout(metadata.get());
  if (!tensor_layout.ok()) {
    return tensor_layout.error();
  }

  // Offset of the tensor's segment, relative to header_.segment_base_offset.
  Result<int> segment_offset =
      get_and_check_segment_offset(flat_tensor_->segments(), metadata.get());
  if (!segment_offset.ok()) {
    return segment_offset.error();
  }

  // Load constant data.
  // NOTE(review): this only bounds the segment offset; the tensor's own
  // offset and nbytes() are not checked against the segment data region here.
  ET_CHECK_OR_RETURN_ERROR(
      segment_offset.get() <
          header_.segment_base_offset + header_.segment_data_size,
      InvalidExternalData,
      "Invalid segment offset %d is larger than the segment_base_offset + segment_data_size %" PRIu64
      "; malformed PTD file.",
      segment_offset.get(),
      header_.segment_base_offset + header_.segment_data_size);
  return loader_->load(
      header_.segment_base_offset + segment_offset.get() +
          metadata.get()->offset(),
      tensor_layout.get().nbytes(),
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
}

/**
 * Loads the data of the tensor named `key` into a caller-provided buffer.
 *
 * @param[in] key Name of the tensor to load.
 * @param[in] buffer Destination the tensor bytes are written to.
 * @param[in] size Size of `buffer` in bytes; must be at least the tensor
 *     layout's nbytes(), which is the amount that gets written.
 *
 * @return The result of the data loader's load_into() call, or the first
 *     error hit while resolving the tensor metadata, its layout, the buffer
 *     size or the segment offset.
 */
ET_NODISCARD Result<size_t> FlatTensorDataMap::load_data_into(
    ET_UNUSED const char* key,
    ET_UNUSED void* buffer,
    ET_UNUSED size_t size) const {
  // Locate the tensor's metadata entry by name.
  Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata =
      get_flat_tensor_metadata(key, flat_tensor_->tensors());
  if (!metadata.ok()) {
    return metadata.error();
  }

  // The layout supplies the number of bytes to load.
  Result<const TensorLayout> tensor_layout =
      create_tensor_layout(metadata.get());
  if (!tensor_layout.ok()) {
    return tensor_layout.error();
  }

  // The load below writes nbytes() bytes into `buffer`, so reject any buffer
  // that is smaller than that instead of overrunning it.
  ET_CHECK_OR_RETURN_ERROR(
      size >= tensor_layout.get().nbytes(),
      InvalidArgument,
      "Buffer size %zu is smaller than tensor size %zu",
      size,
      tensor_layout.get().nbytes());

  // Offset of the tensor's segment, relative to header_.segment_base_offset.
  Result<int> segment_offset =
      get_and_check_segment_offset(flat_tensor_->segments(), metadata.get());
  if (!segment_offset.ok()) {
    return segment_offset.error();
  }

  // Load mutable data.
  DataLoader::SegmentInfo info = DataLoader::SegmentInfo(
      DataLoader::SegmentInfo::Type::Mutable, 0, nullptr);
  // NOTE(review): the loader's result is forwarded unchanged; confirm that a
  // successful load_into() converts to a Result<size_t> holding a value.
  return loader_->load_into(
      header_.segment_base_offset + segment_offset.get() +
          metadata.get()->offset(),
      tensor_layout.get().nbytes(),
      info,
      buffer);
}

ET_NODISCARD Result<size_t> FlatTensorDataMap::get_num_keys() const {
Expand All @@ -138,45 +192,34 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(

/* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load(
DataLoader* loader) {
// Load data map.
size_t flatbuffer_offset = 0;
size_t flatbuffer_size = 0;
size_t segment_base_offset = 0;
size_t segment_data_size = 0;
{
// Check header.
Result<FreeableBuffer> header = loader->load(
/*offset=*/0,
FlatTensorHeader::kNumHeadBytes,
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
if (!header.ok()) {
return header.error();
}
Result<FlatTensorHeader> fh =
FlatTensorHeader::Parse(header->data(), header->size());
if (fh.ok()) {
// The header has the data map size.
flatbuffer_offset = fh->flatbuffer_offset;
flatbuffer_size = fh->flatbuffer_size;
segment_base_offset = fh->segment_base_offset;
segment_data_size = fh->segment_data_size;
} else if (fh.error() == Error::NotFound) {
// No header, throw error.
ET_LOG(Error, "No FlatTensorHeader found.");
return fh.error();
} else {
// corruption, throw error.
ET_LOG(Error, "Flat tensor header may be corrupt.");
return fh.error();
}
// Check header.
Result<FreeableBuffer> header = loader->load(
/*offset=*/0,
FlatTensorHeader::kNumHeadBytes,
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
if (!header.ok()) {
ET_LOG(Error, "Failed to load header.");
return header.error();
}
Result<FlatTensorHeader> fh =
FlatTensorHeader::Parse(header->data(), header->size());
if (fh.error() == Error::NotFound) {
// No header, throw error.
ET_LOG(Error, "No FlatTensorHeader found.");
return fh.error();
} else if (fh.error() != Error::Ok) {
// corruption, throw error.
ET_LOG(Error, "Flat tensor header may be corrupt.");
return fh.error();
}

// Load flatbuffer data as a segment.
Result<FreeableBuffer> flat_tensor_data = loader->load(
/*offset=*/0,
flatbuffer_offset + flatbuffer_size,
fh->flatbuffer_offset + fh->flatbuffer_size,
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
if (!flat_tensor_data.ok()) {
ET_LOG(Error, "Failed to load flat_tensor data.");
return flat_tensor_data.error();
}

Expand Down Expand Up @@ -204,54 +247,8 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
const flat_tensor_flatbuffer::FlatTensor* flat_tensor =
flat_tensor_flatbuffer::GetFlatTensor(flat_tensor_data->data());

// Validate flatbuffer data.
flatbuffers::Verifier verifier(
reinterpret_cast<const uint8_t*>(flat_tensor_data->data()),
flat_tensor_data->size());
bool ok = flat_tensor_flatbuffer::VerifyFlatTensorBuffer(verifier);
ET_CHECK_OR_RETURN_ERROR(
ok,
InvalidExternalData,
"Verification failed; data may be truncated or corrupt");

// Get pointer to tensor metadata.
const auto* s_tensor_metadata = flat_tensor->tensors();
if (s_tensor_metadata == nullptr) {
ET_LOG(Error, "FlatTensor has no tensor metadata.");
return Error::InvalidExternalData;
}

// Load constant data.
const auto* s_data_segment = flat_tensor->segments();

// TODO(T214294528): Support multiple segments in FlatTensor.
if (s_data_segment->size() != 1) {
ET_LOG(
Error,
"FlatTensor has %u segments, only 1 supported.",
s_data_segment->size());
}
// First segment size should be <= the total segment data size.
int segment_size = s_data_segment->Get(0)->size();
int segment_offset = s_data_segment->Get(0)->offset();
if (segment_size > segment_data_size) {
ET_LOG(
Error,
"FlatTensor segment size %d > segment data size %zu",
segment_size,
segment_data_size);
}

Result<FreeableBuffer> data_ro = loader->load(
/*offset=*/segment_base_offset + segment_offset,
segment_size,
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
if (!data_ro.ok()) {
return data_ro.error();
}

return FlatTensorDataMap(
std::move(flat_tensor_data.get()), flat_tensor, std::move(data_ro.get()));
fh.get(), std::move(flat_tensor_data.get()), flat_tensor, loader);
}

} // namespace extension
Expand Down
50 changes: 45 additions & 5 deletions extension/flat_tensor/flat_tensor_data_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include <executorch/runtime/core/named_data_map.h>

#include <executorch/extension/flat_tensor/serialize/flat_tensor_header.h>

#include <executorch/runtime/core/data_loader.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/result.h>
Expand Down Expand Up @@ -41,17 +43,50 @@ class FlatTensorDataMap final : public executorch::runtime::NamedDataMap {
static executorch::runtime::Result<FlatTensorDataMap> load(
executorch::runtime::DataLoader* loader);

/**
* Retrieve the metadata for the specified key.
*
* @param[in] key The name of the tensor to get metadata on.
*
* @return Error::NotFound if the key is not present.
*/
ET_NODISCARD
executorch::runtime::Result<const executorch::runtime::TensorLayout>
get_metadata(const char* key) const override;

/**
* Retrieve read-only data for the specified key.
*
* @param[in] key The name of the tensor to get data on.
*
* @return error if the key is not present or data cannot be loaded.
*/
ET_NODISCARD
executorch::runtime::Result<executorch::runtime::FreeableBuffer> get_data(
const char* key) const override;

/**
 * Loads the data of the specified tensor into the provided buffer.
 *
 * @param[in] key The name of the tensor to get the data of.
 * @param[in] buffer The buffer to load data into. Must point to at least
 *     `size` bytes of memory.
 * @param[in] size The size of `buffer` in bytes; the implementation compares
 *     it against the tensor's size before loading.
 *
 * @returns a Result that holds an error when the key's metadata, layout or
 *     segment cannot be resolved, when the size check fails, or when the data
 *     loader reports a failure. NOTE(review): the success value is whatever
 *     the data loader's load_into() result converts to; confirm that it
 *     carries the number of bytes loaded.
 */
ET_NODISCARD executorch::runtime::Result<size_t>
load_data_into(const char* key, void* buffer, size_t size) const override;

/**
* @returns The number of keys in the map.
*/
ET_NODISCARD executorch::runtime::Result<size_t> get_num_keys()
const override;

/**
* @returns The key at the specified index, error if index out of bounds.
*/
ET_NODISCARD executorch::runtime::Result<const char*> get_key(
size_t index) const override;

Expand All @@ -61,26 +96,31 @@ class FlatTensorDataMap final : public executorch::runtime::NamedDataMap {

private:
FlatTensorDataMap(
const FlatTensorHeader& header,
executorch::runtime::FreeableBuffer&& flat_tensor_data,
const flat_tensor_flatbuffer::FlatTensor* flat_tensor,
executorch::runtime::FreeableBuffer&& data_ro)
: flat_tensor_data_(std::move(flat_tensor_data)),
executorch::runtime::DataLoader* loader)
: header_(header),
flat_tensor_data_(std::move(flat_tensor_data)),
flat_tensor_(flat_tensor),
data_ro_(std::move(data_ro)) {}
loader_(loader) {}

// Not copyable or assignable.
FlatTensorDataMap(const FlatTensorDataMap& rhs) = delete;
FlatTensorDataMap& operator=(FlatTensorDataMap&& rhs) noexcept = delete;
FlatTensorDataMap& operator=(const FlatTensorDataMap& rhs) = delete;

// FlatTensor header, containing segment_base_offset and segment_data_size.
const FlatTensorHeader header_;

// Serialized flat_tensor flatbuffer data.
executorch::runtime::FreeableBuffer flat_tensor_data_;

// Flatbuffer representation of the flat_tensor.
const flat_tensor_flatbuffer::FlatTensor* flat_tensor_;

// Loaded read-only tensor data.
executorch::runtime::FreeableBuffer data_ro_;
// Data loader, used to load segment data.
executorch::runtime::DataLoader* loader_;
};

} // namespace extension
Expand Down
2 changes: 1 addition & 1 deletion extension/flat_tensor/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def define_common_targets(is_fbcode=False):
}

runtime.cxx_test(
name = "flat_tensor_data_map",
name = "flat_tensor_data_map_test",
srcs = [
"flat_tensor_data_map_test.cpp",
],
Expand Down
Loading
Loading