Skip to content

Commit 728c255

Browse files
kirklandsign and lucylq authored
[executorch][flat_tensor] implement load into and don't hold onto the segment (#8650)
* [flat_tensor] Persist FreeableBuffers of external constants in method Pull Request resolved: #8437 ## Problem Currently, the FlatTensorDataMap persists tensors, and returns a FreeableBuffer with an empty free function. The NamedDataMap should not persist data, as most cases (e.g. delegate) will want it to be freed. Ownership should be on the caller; `get_data` returns a FreeableBuffer that 'owns' the data. The FreeableBuffer in turn is owned by the caller. NOTE: this doesn't support the case where we want to share plain tensors between methods/pte files at runtime. A custom NDM could support that use-case. ## This diff: 1. Introduces a 'NamedData' struct to method.h. This holds a key and a FreeableBuffer. 2. Iterate over all the flatbuffer tensors to count the constants tagged with EXTERNAL. NOTE: this will increase load time for all users. Potentially allocate chunks of 16 and use a linked list to store external constants, or store this number in PTE file (see D69618283). 3. Allocate space for num_external_constants using the method allocator. 4. Iterate over all flatbuffer tensors and use the named_data_map to resolve EXTERNAL tensors into the array of NamedData. 5. Pass the resolved external constants to tensor_parser, along with NDM (used for mutable external tensors). 6. Resolved external tensors are stored inside method. They are freed when the method is destructed. Some notes: https://docs.google.com/document/d/1_PBi4JgODuClUPD4PCUWrKNjyUH54zOUHGUJ3QHDNes/edit?tab=t.0#heading=h.blsvwraxss7g ghstack-source-id: 267364187 TODO: add test case when two fqns point to the same data buffer. Differential Revision: [D69477027](https://our.internmc.facebook.com/intern/diff/D69477027/) * [executorch][flat_tensor] implement load into and don't hold onto the segment Pull Request resolved: #8447 1. Implement load_into in FlatTensorDataMap 2. Do not persist 'data_ro' in the FlatTensorDataMap. From `get_data`, return the FreeableBuffer given by the data loader. 
TODO: add test for load_into. ghstack-source-id: 267467148 Differential Revision: [D69148652](https://our.internmc.facebook.com/intern/diff/D69148652/) --------- Co-authored-by: lucylq <lfq@meta.com>
1 parent 9c51e58 commit 728c255

File tree

3 files changed

+150
-113
lines changed

3 files changed

+150
-113
lines changed

extension/flat_tensor/flat_tensor_data_map.cpp

+104-107
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,14 @@ Result<const flat_tensor_flatbuffer::TensorMetadata*> get_flat_tensor_metadata(
5252
for (int i = 0; i < tensors->size(); i++) {
5353
if (std::strcmp(tensors->Get(i)->fully_qualified_name()->c_str(), key) ==
5454
0) {
55-
// TODO(T214294528): Support multiple segments in FlatTensor.
56-
if (tensors->Get(i)->segment_index() != 0) {
57-
return Error::InvalidExternalData;
58-
}
59-
return tensors->Get(i);
55+
const auto* metadata = tensors->Get(i);
56+
ET_CHECK_OR_RETURN_ERROR(
57+
metadata->segment_index() >= 0 && metadata->offset() >= 0,
58+
InvalidExternalData,
59+
"Invalid segment_index %d or offset %" PRIu64 "; malformed PTD file.",
60+
metadata->segment_index(),
61+
metadata->offset());
62+
return metadata;
6063
}
6164
}
6265
return Error::NotFound;
@@ -75,6 +78,23 @@ Result<const TensorLayout> create_tensor_layout(
7578
scalar_type);
7679
}
7780

81+
/**
 * Resolves the data segment that `metadata` refers to and returns that
 * segment's offset as recorded in the flat_tensor flatbuffer.
 *
 * @param[in] segments The DataSegment vector from the flatbuffer; may be
 *     nullptr, in which case an error is returned.
 * @param[in] metadata The tensor metadata whose segment_index is looked up.
 *     Must not be nullptr.
 *
 * @return The segment's offset on success; Error::InvalidExternalData if
 *     there is no segment vector or segment_index is out of range.
 */
Result<int> get_and_check_segment_offset(
    const flatbuffers::Vector<
        flatbuffers::Offset<flat_tensor_flatbuffer::DataSegment>>* segments,
    const flat_tensor_flatbuffer::TensorMetadata* metadata) {
  ET_CHECK_OR_RETURN_ERROR(
      segments != nullptr,
      InvalidExternalData,
      "No segments in external data flatbuffer.");

  // Reject negative indices explicitly, then compare as unsigned values so the
  // bounds check does not depend on an implicit signed-to-unsigned conversion.
  const auto segment_index = metadata->segment_index();
  ET_CHECK_OR_RETURN_ERROR(
      segment_index >= 0 &&
          static_cast<size_t>(segment_index) < segments->size(),
      InvalidExternalData,
      "Invalid segment_index %d; malformed PTD file.",
      segment_index);

  // NOTE(review): the return type is int, so a very large segment offset would
  // be narrowed here -- confirm the width of DataSegment::offset().
  return segments->Get(segment_index)->offset();
}
97+
7898
} // namespace
7999

80100
ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata(
@@ -89,39 +109,73 @@ ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata(
89109

90110
/**
 * Loads the read-only data of the tensor named `key` through the data loader.
 *
 * The returned FreeableBuffer is whatever the loader hands back; this map does
 * not keep a copy of the tensor data.
 *
 * @param[in] key The fully qualified name of the tensor to load.
 *
 * @return The loaded buffer, or the error reported by the metadata lookup,
 *     the layout creation, the segment lookup, the segment-offset check, or
 *     the loader.
 */
ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data(
    const char* key) const {
  // Find the tensor's metadata entry by name.
  Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata_res =
      get_flat_tensor_metadata(key, flat_tensor_->tensors());
  if (!metadata_res.ok()) {
    return metadata_res.error();
  }
  const auto* tensor_meta = metadata_res.get();

  // The layout provides the number of bytes to load.
  Result<const TensorLayout> layout_res = create_tensor_layout(tensor_meta);
  if (!layout_res.ok()) {
    return layout_res.error();
  }

  // Find the offset of the segment holding this tensor.
  Result<int> seg_offset_res =
      get_and_check_segment_offset(flat_tensor_->segments(), tensor_meta);
  if (!seg_offset_res.ok()) {
    return seg_offset_res.error();
  }
  const int seg_offset = seg_offset_res.get();

  // Load constant data.
  const auto segment_limit =
      header_.segment_base_offset + header_.segment_data_size;
  ET_CHECK_OR_RETURN_ERROR(
      seg_offset < segment_limit,
      InvalidExternalData,
      "Invalid segment offset %d is larger than the segment_base_offset + segment_data_size %" PRIu64
      "; malformed PTD file.",
      seg_offset,
      segment_limit);

  const auto load_offset =
      header_.segment_base_offset + seg_offset + tensor_meta->offset();
  return loader_->load(
      load_offset,
      layout_res.get().nbytes(),
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
}
119143

120144
/**
 * Loads the data of the tensor named `key` into a caller-provided buffer.
 *
 * @param[in] key The fully qualified name of the tensor to load.
 * @param[in] buffer Destination buffer; must point to at least `size` bytes.
 * @param[in] size Capacity of `buffer` in bytes. Must be at least the tensor's
 *     size in bytes.
 *
 * @return The number of bytes loaded (the tensor's nbytes) on success;
 *     Error::InvalidArgument if `buffer` is too small; otherwise the error
 *     reported by the metadata lookup, the layout creation, the segment
 *     lookup, or the loader.
 */
ET_NODISCARD Result<size_t> FlatTensorDataMap::load_data_into(
    const char* key,
    void* buffer,
    size_t size) const {
  Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata =
      get_flat_tensor_metadata(key, flat_tensor_->tensors());
  if (!metadata.ok()) {
    return metadata.error();
  }
  Result<const TensorLayout> tensor_layout =
      create_tensor_layout(metadata.get());
  if (!tensor_layout.ok()) {
    return tensor_layout.error();
  }
  const size_t tensor_nbytes = tensor_layout.get().nbytes();

  // ET_CHECK_OR_RETURN_ERROR returns when the condition is FALSE, so the
  // condition must state the valid case: the buffer can hold the whole tensor.
  ET_CHECK_OR_RETURN_ERROR(
      size >= tensor_nbytes,
      InvalidArgument,
      "Buffer size %zu is smaller than tensor size %zu",
      size,
      tensor_nbytes);

  Result<int> segment_offset =
      get_and_check_segment_offset(flat_tensor_->segments(), metadata.get());
  if (!segment_offset.ok()) {
    return segment_offset.error();
  }

  // Load mutable data.
  DataLoader::SegmentInfo info = DataLoader::SegmentInfo(
      DataLoader::SegmentInfo::Type::Mutable, 0, nullptr);
  // NOTE(review): assumes DataLoader::load_into returns an Error. Returning
  // that Error directly would build the Result<size_t> from an Error even on
  // success, so the caller would never get a value; report the byte count
  // explicitly instead.
  const Error load_status = loader_->load_into(
      header_.segment_base_offset + segment_offset.get() +
          metadata.get()->offset(),
      tensor_nbytes,
      info,
      buffer);
  if (load_status != Error::Ok) {
    return load_status;
  }
  return tensor_nbytes;
}
126180

127181
ET_NODISCARD Result<size_t> FlatTensorDataMap::get_num_keys() const {
@@ -138,45 +192,34 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
138192

139193
/* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load(
140194
DataLoader* loader) {
141-
// Load data map.
142-
size_t flatbuffer_offset = 0;
143-
size_t flatbuffer_size = 0;
144-
size_t segment_base_offset = 0;
145-
size_t segment_data_size = 0;
146-
{
147-
// Check header.
148-
Result<FreeableBuffer> header = loader->load(
149-
/*offset=*/0,
150-
FlatTensorHeader::kNumHeadBytes,
151-
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
152-
if (!header.ok()) {
153-
return header.error();
154-
}
155-
Result<FlatTensorHeader> fh =
156-
FlatTensorHeader::Parse(header->data(), header->size());
157-
if (fh.ok()) {
158-
// The header has the data map size.
159-
flatbuffer_offset = fh->flatbuffer_offset;
160-
flatbuffer_size = fh->flatbuffer_size;
161-
segment_base_offset = fh->segment_base_offset;
162-
segment_data_size = fh->segment_data_size;
163-
} else if (fh.error() == Error::NotFound) {
164-
// No header, throw error.
165-
ET_LOG(Error, "No FlatTensorHeader found.");
166-
return fh.error();
167-
} else {
168-
// corruption, throw error.
169-
ET_LOG(Error, "Flat tensor header may be corrupt.");
170-
return fh.error();
171-
}
195+
// Check header.
196+
Result<FreeableBuffer> header = loader->load(
197+
/*offset=*/0,
198+
FlatTensorHeader::kNumHeadBytes,
199+
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
200+
if (!header.ok()) {
201+
ET_LOG(Error, "Failed to load header.");
202+
return header.error();
203+
}
204+
Result<FlatTensorHeader> fh =
205+
FlatTensorHeader::Parse(header->data(), header->size());
206+
if (fh.error() == Error::NotFound) {
207+
// No header, throw error.
208+
ET_LOG(Error, "No FlatTensorHeader found.");
209+
return fh.error();
210+
} else if (fh.error() != Error::Ok) {
211+
// corruption, throw error.
212+
ET_LOG(Error, "Flat tensor header may be corrupt.");
213+
return fh.error();
172214
}
173215

174216
// Load flatbuffer data as a segment.
175217
Result<FreeableBuffer> flat_tensor_data = loader->load(
176218
/*offset=*/0,
177-
flatbuffer_offset + flatbuffer_size,
219+
fh->flatbuffer_offset + fh->flatbuffer_size,
178220
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
179221
if (!flat_tensor_data.ok()) {
222+
ET_LOG(Error, "Failed to load flat_tensor data.");
180223
return flat_tensor_data.error();
181224
}
182225

@@ -204,54 +247,8 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
204247
const flat_tensor_flatbuffer::FlatTensor* flat_tensor =
205248
flat_tensor_flatbuffer::GetFlatTensor(flat_tensor_data->data());
206249

207-
// Validate flatbuffer data.
208-
flatbuffers::Verifier verifier(
209-
reinterpret_cast<const uint8_t*>(flat_tensor_data->data()),
210-
flat_tensor_data->size());
211-
bool ok = flat_tensor_flatbuffer::VerifyFlatTensorBuffer(verifier);
212-
ET_CHECK_OR_RETURN_ERROR(
213-
ok,
214-
InvalidExternalData,
215-
"Verification failed; data may be truncated or corrupt");
216-
217-
// Get pointer to tensor metadata.
218-
const auto* s_tensor_metadata = flat_tensor->tensors();
219-
if (s_tensor_metadata == nullptr) {
220-
ET_LOG(Error, "FlatTensor has no tensor metadata.");
221-
return Error::InvalidExternalData;
222-
}
223-
224-
// Load constant data.
225-
const auto* s_data_segment = flat_tensor->segments();
226-
227-
// TODO(T214294528): Support multiple segments in FlatTensor.
228-
if (s_data_segment->size() != 1) {
229-
ET_LOG(
230-
Error,
231-
"FlatTensor has %u segments, only 1 supported.",
232-
s_data_segment->size());
233-
}
234-
// First segment size should be <= the total segment data size.
235-
int segment_size = s_data_segment->Get(0)->size();
236-
int segment_offset = s_data_segment->Get(0)->offset();
237-
if (segment_size > segment_data_size) {
238-
ET_LOG(
239-
Error,
240-
"FlatTensor segment size %d > segment data size %zu",
241-
segment_size,
242-
segment_data_size);
243-
}
244-
245-
Result<FreeableBuffer> data_ro = loader->load(
246-
/*offset=*/segment_base_offset + segment_offset,
247-
segment_size,
248-
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
249-
if (!data_ro.ok()) {
250-
return data_ro.error();
251-
}
252-
253250
return FlatTensorDataMap(
254-
std::move(flat_tensor_data.get()), flat_tensor, std::move(data_ro.get()));
251+
fh.get(), std::move(flat_tensor_data.get()), flat_tensor, loader);
255252
}
256253

257254
} // namespace extension

extension/flat_tensor/flat_tensor_data_map.h

+45-5
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
#include <executorch/runtime/core/named_data_map.h>
1212

13+
#include <executorch/extension/flat_tensor/serialize/flat_tensor_header.h>
14+
1315
#include <executorch/runtime/core/data_loader.h>
1416
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1517
#include <executorch/runtime/core/result.h>
@@ -41,17 +43,50 @@ class FlatTensorDataMap final : public executorch::runtime::NamedDataMap {
4143
static executorch::runtime::Result<FlatTensorDataMap> load(
4244
executorch::runtime::DataLoader* loader);
4345

46+
/**
47+
* Retrieve the metadata for the specified key.
48+
*
49+
* @param[in] key The name of the tensor to get metadata on.
50+
*
51+
* @return Error::NotFound if the key is not present.
52+
*/
4453
ET_NODISCARD
4554
executorch::runtime::Result<const executorch::runtime::TensorLayout>
4655
get_metadata(const char* key) const override;
56+
57+
/**
58+
* Retrieve read-only data for the specified key.
59+
*
60+
* @param[in] key The name of the tensor to get data on.
61+
*
62+
* @return error if the key is not present or data cannot be loaded.
63+
*/
4764
ET_NODISCARD
4865
executorch::runtime::Result<executorch::runtime::FreeableBuffer> get_data(
4966
const char* key) const override;
67+
68+
/**
69+
* Loads the data of the specified tensor into the provided buffer.
70+
*
71+
* @param[in] key The name of the tensor to get the data of.
72+
* @param[in] buffer The buffer to load data into. Must point to at least
73+
* `size` bytes of memory.
74+
* @param[in] size The number of bytes to load.
75+
*
76+
* @returns an Error indicating if the load was successful.
77+
*/
5078
ET_NODISCARD executorch::runtime::Result<size_t>
5179
load_data_into(const char* key, void* buffer, size_t size) const override;
5280

81+
/**
82+
* @returns The number of keys in the map.
83+
*/
5384
ET_NODISCARD executorch::runtime::Result<size_t> get_num_keys()
5485
const override;
86+
87+
/**
88+
* @returns The key at the specified index, error if index out of bounds.
89+
*/
5590
ET_NODISCARD executorch::runtime::Result<const char*> get_key(
5691
size_t index) const override;
5792

@@ -61,26 +96,31 @@ class FlatTensorDataMap final : public executorch::runtime::NamedDataMap {
6196

6297
private:
6398
FlatTensorDataMap(
99+
const FlatTensorHeader& header,
64100
executorch::runtime::FreeableBuffer&& flat_tensor_data,
65101
const flat_tensor_flatbuffer::FlatTensor* flat_tensor,
66-
executorch::runtime::FreeableBuffer&& data_ro)
67-
: flat_tensor_data_(std::move(flat_tensor_data)),
102+
executorch::runtime::DataLoader* loader)
103+
: header_(header),
104+
flat_tensor_data_(std::move(flat_tensor_data)),
68105
flat_tensor_(flat_tensor),
69-
data_ro_(std::move(data_ro)) {}
106+
loader_(loader) {}
70107

71108
// Not copyable or assignable.
72109
FlatTensorDataMap(const FlatTensorDataMap& rhs) = delete;
73110
FlatTensorDataMap& operator=(FlatTensorDataMap&& rhs) noexcept = delete;
74111
FlatTensorDataMap& operator=(const FlatTensorDataMap& rhs) = delete;
75112

113+
// FlatTensor header, containing segment_base_offset and segment_data_size.
114+
const FlatTensorHeader header_;
115+
76116
// Serialized flat_tensor flatbuffer data.
77117
executorch::runtime::FreeableBuffer flat_tensor_data_;
78118

79119
// Flatbuffer representation of the flat_tensor.
80120
const flat_tensor_flatbuffer::FlatTensor* flat_tensor_;
81121

82-
// Loaded read-only tensor data.
83-
executorch::runtime::FreeableBuffer data_ro_;
122+
// Data loader, used to load segment data.
123+
executorch::runtime::DataLoader* loader_;
84124
};
85125

86126
} // namespace extension

extension/flat_tensor/test/targets.bzl

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def define_common_targets(is_fbcode=False):
4040
}
4141

4242
runtime.cxx_test(
43-
name = "flat_tensor_data_map",
43+
name = "flat_tensor_data_map_test",
4444
srcs = [
4545
"flat_tensor_data_map_test.cpp",
4646
],

0 commit comments

Comments
 (0)