Skip to content

Commit 3e289a4

Browse files
committed
[executorch][flat_tensor] DataMap implementation
Pull Request resolved: #7900

DataMap implementation that:
* Loads a flat_tensor file.
* Makes tensor information available via the named_data_map.h interface.

TODO: in a later diff, update the ET runtime to hold onto the FreeableBuffers returned by the NDM. Then, the NDM will not persist the segment. T214294528

ghstack-source-id: 264871691
Differential Revision: [D67064580](https://our.internmc.facebook.com/intern/diff/D67064580/)
1 parent 15c8bdf commit 3e289a4

File tree

8 files changed

+540
-2
lines changed

8 files changed

+540
-2
lines changed

extension/flat_tensor/TARGETS

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Build file for the flat_tensor extension. The actual target definitions are
# provided by define_common_targets() in the sibling targets.bzl.
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
10+
11+
#include <executorch/extension/flat_tensor/serialize/flat_tensor_header.h>
12+
#include <executorch/extension/flat_tensor/serialize/schema_generated.h>
13+
14+
#include <executorch/runtime/core/error.h>
15+
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
16+
#include <executorch/runtime/core/freeable_buffer.h>
17+
#include <executorch/runtime/core/result.h>
18+
#include <executorch/runtime/core/span.h>
19+
#include <executorch/runtime/platform/compiler.h>
20+
21+
using executorch::runtime::Error;
22+
using executorch::runtime::FreeableBuffer;
23+
using executorch::runtime::Result;
24+
using executorch::runtime::Span;
25+
26+
using executorch::aten::ScalarType;
27+
using executorch::runtime::DataLoader;
28+
using executorch::runtime::TensorLayout;
29+
30+
namespace executorch {
31+
namespace extension {
32+
33+
namespace {
34+
/**
 * Alignment, in bytes, that FlatTensor data must satisfy before it can be
 * parsed. The value must be a power of 2; max_align_t is the alignment that
 * malloc() and new guarantee.
 */
constexpr size_t kMinimumAlignment = alignof(std::max_align_t);

/**
 * Reports whether `data` lies on a kMinimumAlignment-byte boundary.
 *
 * @param[in] data Address to test; it is never dereferenced.
 * @returns true if the address is a multiple of kMinimumAlignment.
 */
bool is_aligned(const void* data) {
  const auto address = reinterpret_cast<uintptr_t>(data);
  // For a power-of-2 alignment, masking the low bits is the same test as
  // checking the remainder of a division.
  return (address & (kMinimumAlignment - 1)) == 0;
}
45+
46+
/**
 * Finds the tensor metadata entry whose fully-qualified name equals `key`.
 *
 * @param[in] key Null-terminated tensor name to look up.
 * @param[in] tensors Tensor metadata vector from the FlatTensor flatbuffer;
 *     must not be null.
 * @returns The matching entry on success. Error::InvalidExternalData if the
 *     matching entry refers to a segment other than segment 0 (the only one
 *     currently supported). Error::NotFound if no entry has that name.
 */
Result<const flat_tensor_flatbuffer::TensorMetadata*> get_flat_tensor_metadata(
    const char* key,
    const flatbuffers::Vector<
        flatbuffers::Offset<flat_tensor_flatbuffer::TensorMetadata>>* tensors) {
  // Linear search by name. The index uses the vector's own unsigned size type
  // to avoid a signed/unsigned comparison.
  const flatbuffers::uoffset_t num_tensors = tensors->size();
  for (flatbuffers::uoffset_t i = 0; i < num_tensors; ++i) {
    const flat_tensor_flatbuffer::TensorMetadata* entry = tensors->Get(i);
    const auto* name = entry->fully_qualified_name();
    if (name == nullptr) {
      // A flatbuffer accessor returns nullptr when the field is absent; an
      // unnamed entry can never match, and must not be passed to strcmp().
      continue;
    }
    if (std::strcmp(name->c_str(), key) != 0) {
      continue;
    }
    // TODO(T214294528): Support multiple segments in FlatTensor.
    if (entry->segment_index() != 0) {
      return Error::InvalidExternalData;
    }
    return entry;
  }
  return Error::NotFound;
}
63+
64+
/**
 * Builds a TensorLayout from a serialized tensor metadata entry.
 *
 * @param[in] tensor_metadata Metadata entry to convert; must not be null.
 * @returns The result of TensorLayout::create() for the entry's sizes,
 *     dim_order and scalar type, or Error::InvalidExternalData if sizes or
 *     dim_order is missing or the two have different lengths.
 */
Result<const TensorLayout> create_tensor_layout(
    const flat_tensor_flatbuffer::TensorMetadata* tensor_metadata) {
  const auto* sizes = tensor_metadata->sizes();
  const auto* dim_order = tensor_metadata->dim_order();
  ET_CHECK_OR_RETURN_ERROR(
      sizes != nullptr && dim_order != nullptr,
      InvalidExternalData,
      "Tensor metadata is missing sizes or dim_order.");
  // Both spans below are built with the same length, so the two vectors must
  // really have that many elements; otherwise the shorter one would be read
  // out of bounds.
  ET_CHECK_OR_RETURN_ERROR(
      sizes->size() == dim_order->size(),
      InvalidExternalData,
      "Tensor sizes length %u != dim_order length %u",
      static_cast<unsigned>(sizes->size()),
      static_cast<unsigned>(dim_order->size()));

  const size_t dim = sizes->size();
  const ScalarType scalar_type =
      static_cast<ScalarType>(tensor_metadata->scalar_type());
  return TensorLayout::create(
      Span<const int32_t>(sizes->data(), dim),
      Span<const uint8_t>(dim_order->data(), dim),
      scalar_type);
}
76+
77+
} // namespace
78+
79+
/**
 * Looks up the tensor named `key` and returns its layout.
 *
 * @param[in] key Null-terminated tensor name.
 * @returns The layout built from the tensor's serialized metadata, or the
 *     error produced by the name lookup or by the layout construction.
 */
ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata(
    const char* key) const {
  auto lookup = get_flat_tensor_metadata(key, flat_tensor_->tensors());
  if (lookup.ok()) {
    return create_tensor_layout(lookup.get());
  }
  // Propagate the lookup failure unchanged.
  return lookup.error();
}
88+
89+
/**
 * Returns a non-owning view of the data for the tensor named `key`.
 *
 * @param[in] key Null-terminated tensor name.
 * @returns A FreeableBuffer pointing into the loaded segment (data_ro_). It
 *     has no free function, so it does not own or release the memory.
 *     Returns the lookup/layout error if the tensor cannot be resolved, and
 *     Error::InvalidExternalData if the tensor's byte range does not fit
 *     inside the loaded segment.
 */
ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data(
    const char* key) const {
  Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata_res =
      get_flat_tensor_metadata(key, flat_tensor_->tensors());
  if (!metadata_res.ok()) {
    return metadata_res.error();
  }
  const auto* metadata = metadata_res.get();
  // Retained from the original implementation; this is a no-op if the schema
  // declares these fields as unsigned.
  if (metadata->segment_index() < 0 || metadata->offset() < 0) {
    // Invalid segment_index/offset; malformed PTD file.
    return Error::InvalidExternalData;
  }

  Result<const TensorLayout> tensor_layout_res = create_tensor_layout(metadata);
  if (!tensor_layout_res.ok()) {
    return tensor_layout_res.error();
  }

  // The tensor must lie entirely inside the loaded segment. The comparison is
  // written so that offset + nbytes can never overflow.
  const uint64_t offset = static_cast<uint64_t>(metadata->offset());
  const uint64_t nbytes = tensor_layout_res.get().nbytes();
  const uint64_t segment_size = data_ro_.size();
  if (offset > segment_size || nbytes > segment_size - offset) {
    ET_LOG(
        Error,
        "Tensor '%s' does not fit inside the loaded segment; data may be "
        "truncated or corrupt.",
        key);
    return Error::InvalidExternalData;
  }

  // This FreeableBuffer doesn't own the underlying data, and will not free it,
  // which is why the free function is a nullptr.
  // TODO(T214294528)
  return FreeableBuffer(
      static_cast<const uint8_t*>(data_ro_.data()) +
          static_cast<size_t>(offset),
      static_cast<size_t>(nbytes),
      nullptr);
}
117+
118+
/**
 * Not supported by this data map yet: all arguments are ignored and the call
 * always fails with Error::NotImplemented.
 */
ET_NODISCARD Result<size_t> FlatTensorDataMap::load_data_into(
    ET_UNUSED const char* key,
    ET_UNUSED void* buffer,
    ET_UNUSED size_t size) const {
  return Result<size_t>(Error::NotImplemented);
}
124+
125+
/**
 * @returns The number of tensor metadata entries in the FlatTensor, i.e. the
 *     number of keys this map exposes.
 */
ET_NODISCARD Result<size_t> FlatTensorDataMap::get_num_keys() const {
  const size_t num_tensors = flat_tensor_->tensors()->size();
  return num_tensors;
}
128+
129+
/**
 * Returns the name of the tensor stored at position `index`.
 *
 * @param[in] index Zero-based position in the tensor metadata vector.
 * @returns A pointer to the null-terminated name stored in the flatbuffer.
 *     Error::InvalidArgument if `index` is out of range.
 *     Error::InvalidExternalData if the entry has no name.
 */
ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
    size_t index) const {
  const auto* tensors = flat_tensor_->tensors();
  // `index` is unsigned, so only the upper bound needs checking.
  if (index >= tensors->size()) {
    return Error::InvalidArgument;
  }
  const auto* name =
      tensors->Get(static_cast<flatbuffers::uoffset_t>(index))
          ->fully_qualified_name();
  if (name == nullptr) {
    // A flatbuffer accessor returns nullptr when the field is absent.
    return Error::InvalidExternalData;
  }
  return name->c_str();
}
136+
137+
/**
 * Creates a FlatTensorDataMap from the FlatTensor file wrapped by `loader`.
 *
 * Steps: parse the FlatTensorHeader, load and verify the flatbuffer, then
 * load the single supported data segment.
 *
 * @param[in] loader DataLoader wrapping the FlatTensor file; must not be null.
 * @returns The data map on success. Any error returned by the loader or by
 *     header parsing is propagated. Error::InvalidExternalData is returned
 *     when the identifier, flatbuffer verification, tensor metadata, segment
 *     count or segment size is invalid, and Error::InvalidArgument when the
 *     loaded flatbuffer data is not suitably aligned.
 */
/* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load(
    DataLoader* loader) {
  // Load data map.
  size_t flatbuffer_offset = 0;
  size_t flatbuffer_size = 0;
  size_t segment_base_offset = 0;
  size_t segment_data_size = 0;
  {
    // Check header.
    Result<FreeableBuffer> header = loader->load(
        /*offset=*/0,
        FlatTensorHeader::kNumHeadBytes,
        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
    if (!header.ok()) {
      return header.error();
    }
    Result<FlatTensorHeader> fh =
        FlatTensorHeader::Parse(header->data(), header->size());
    if (fh.ok()) {
      // The header has the data map size.
      flatbuffer_offset = fh->flatbuffer_offset;
      flatbuffer_size = fh->flatbuffer_size;
      segment_base_offset = fh->segment_base_offset;
      segment_data_size = fh->segment_data_size;
    } else if (fh.error() == Error::NotFound) {
      // No header, throw error.
      ET_LOG(Error, "No FlatTensorHeader found.");
      return fh.error();
    } else {
      // corruption, throw error.
      ET_LOG(Error, "Flat tensor header may be corrupt.");
      return fh.error();
    }
  }

  // Load flatbuffer data as a segment.
  Result<FreeableBuffer> flat_tensor_data = loader->load(
      /*offset=*/0,
      flatbuffer_offset + flatbuffer_size,
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
  if (!flat_tensor_data.ok()) {
    return flat_tensor_data.error();
  }

  // Make sure magic matches.
  if (!flat_tensor_flatbuffer::FlatTensorBufferHasIdentifier(
          flat_tensor_data->data())) {
    ET_LOG(
        Error,
        "FlatTensor identifier '%.4s' != expected '%.4s'",
        flatbuffers::GetBufferIdentifier(flat_tensor_data->data()),
        flat_tensor_flatbuffer::FlatTensorIdentifier());
    return Error::InvalidExternalData;
  }

  // The flatbuffer data must start at an aligned address to ensure internal
  // alignment of flatbuffer fields.
  ET_CHECK_OR_RETURN_ERROR(
      is_aligned(flat_tensor_data->data()),
      InvalidArgument,
      "FlatTensor data 0x%p must be aligned to %zu",
      flat_tensor_data->data(),
      kMinimumAlignment);

  // Get pointer to root of flatbuffer table.
  const flat_tensor_flatbuffer::FlatTensor* flat_tensor =
      flat_tensor_flatbuffer::GetFlatTensor(flat_tensor_data->data());

  // Validate flatbuffer data.
  flatbuffers::Verifier verifier(
      reinterpret_cast<const uint8_t*>(flat_tensor_data->data()),
      flat_tensor_data->size());
  bool ok = flat_tensor_flatbuffer::VerifyFlatTensorBuffer(verifier);
  ET_CHECK_OR_RETURN_ERROR(
      ok,
      InvalidExternalData,
      "Verification failed; data may be truncated or corrupt");

  // Get pointer to tensor metadata.
  const auto* s_tensor_metadata = flat_tensor->tensors();
  if (s_tensor_metadata == nullptr) {
    ET_LOG(Error, "FlatTensor has no tensor metadata.");
    return Error::InvalidExternalData;
  }

  // Load constant data.
  const auto* s_data_segment = flat_tensor->segments();
  ET_CHECK_OR_RETURN_ERROR(
      s_data_segment != nullptr,
      InvalidExternalData,
      "FlatTensor has no segments.");

  // TODO(T214294528): Support multiple segments in FlatTensor.
  if (s_data_segment->size() != 1) {
    ET_LOG(
        Error,
        "FlatTensor has %u segments, only 1 supported.",
        s_data_segment->size());
    // Bail out: Get(0) below would be out of bounds for an empty vector, and
    // extra segments would be silently ignored.
    return Error::InvalidExternalData;
  }

  // First segment size should be <= the total segment data size. The values
  // are kept in 64 bits so that large segments are not truncated.
  const uint64_t segment_size = s_data_segment->Get(0)->size();
  const uint64_t segment_offset = s_data_segment->Get(0)->offset();
  if (segment_size > segment_data_size) {
    ET_LOG(
        Error,
        "FlatTensor segment size %llu > segment data size %zu",
        static_cast<unsigned long long>(segment_size),
        segment_data_size);
    return Error::InvalidExternalData;
  }

  // segment_size is now known to fit in size_t because it is bounded by
  // segment_data_size.
  Result<FreeableBuffer> data_ro = loader->load(
      /*offset=*/segment_base_offset + static_cast<size_t>(segment_offset),
      static_cast<size_t>(segment_size),
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
  if (!data_ro.ok()) {
    return data_ro.error();
  }

  return FlatTensorDataMap(
      std::move(flat_tensor_data.get()), flat_tensor, std::move(data_ro.get()));
}
254+
255+
} // namespace extension
256+
} // namespace executorch
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/core/named_data_map.h>
12+
13+
#include <executorch/runtime/core/data_loader.h>
14+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
15+
#include <executorch/runtime/core/result.h>
16+
#include <executorch/runtime/core/tensor_layout.h>
17+
#include <executorch/runtime/platform/compiler.h>
18+
19+
#include <utility>
20+
21+
// Forward declare flatbuffer types. This is a public header and must not
22+
// include the generated flatbuffer header.
23+
namespace flat_tensor_flatbuffer {
24+
struct FlatTensor;
25+
} // namespace flat_tensor_flatbuffer
26+
27+
namespace executorch {
28+
namespace extension {
29+
30+
/**
31+
* A NamedDataMap implementation for FlatTensor-serialized data.
32+
*/
33+
class FlatTensorDataMap final : public executorch::runtime::NamedDataMap {
34+
public:
35+
/**
36+
* Creates a new DataMap that wraps FlatTensor data.
37+
*
38+
* @param[in] loader The DataLoader that wraps the FlatTensor file.
39+
* Note: the loader must outlive the FlatTensorDataMap instance.
40+
*/
41+
static executorch::runtime::Result<FlatTensorDataMap> load(
42+
executorch::runtime::DataLoader* loader);
43+
44+
ET_NODISCARD
45+
executorch::runtime::Result<const executorch::runtime::TensorLayout>
46+
get_metadata(const char* key) const override;
47+
ET_NODISCARD
48+
executorch::runtime::Result<executorch::runtime::FreeableBuffer> get_data(
49+
const char* key) const override;
50+
ET_NODISCARD executorch::runtime::Result<size_t>
51+
load_data_into(const char* key, void* buffer, size_t size) const override;
52+
53+
ET_NODISCARD executorch::runtime::Result<size_t> get_num_keys()
54+
const override;
55+
ET_NODISCARD executorch::runtime::Result<const char*> get_key(
56+
size_t index) const override;
57+
58+
FlatTensorDataMap(FlatTensorDataMap&&) noexcept = default;
59+
60+
private:
61+
FlatTensorDataMap(
62+
executorch::runtime::FreeableBuffer&& flat_tensor_data,
63+
const flat_tensor_flatbuffer::FlatTensor* flat_tensor,
64+
executorch::runtime::FreeableBuffer&& data_ro)
65+
: flat_tensor_data_(std::move(flat_tensor_data)),
66+
flat_tensor_(flat_tensor),
67+
data_ro_(std::move(data_ro)) {}
68+
69+
// Not copyable or assignable.
70+
FlatTensorDataMap(const FlatTensorDataMap& rhs) = delete;
71+
FlatTensorDataMap& operator=(FlatTensorDataMap&& rhs) noexcept = delete;
72+
FlatTensorDataMap& operator=(const FlatTensorDataMap& rhs) = delete;
73+
74+
// Serialized flat_tensor flatbuffer data.
75+
executorch::runtime::FreeableBuffer flat_tensor_data_;
76+
77+
// Flatbuffer representation of the flat_tensor.
78+
const flat_tensor_flatbuffer::FlatTensor* flat_tensor_;
79+
80+
// Loaded read-only tensor data.
81+
executorch::runtime::FreeableBuffer data_ro_;
82+
};
83+
84+
} // namespace extension
85+
} // namespace executorch

extension/flat_tensor/targets.bzl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets():
    """Defines the `flat_tensor_data_map` C++ library target."""
    runtime.cxx_library(
        name = "flat_tensor_data_map",
        srcs = ["flat_tensor_data_map.cpp"],
        exported_headers = ["flat_tensor_data_map.h"],
        deps = [
            "//executorch/extension/flat_tensor/serialize:generated_headers",
            "//executorch/extension/flat_tensor/serialize:flat_tensor_header",
            "//executorch/runtime/core:core",
            "//executorch/runtime/core:evalue",
            "//executorch/runtime/core:named_data_map",
            "//executorch/runtime/core/exec_aten:lib",
            "//executorch/runtime/core/exec_aten/util:tensor_util",
        ],
        visibility = ["//executorch/..."],
    )

extension/flat_tensor/test/TARGETS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ load(":targets.bzl", "define_common_targets")
66

77
oncall("executorch")
88

9-
define_common_targets()
9+
define_common_targets(is_fbcode=True)
1010

1111
python_unittest(
1212
name = "serialize",

0 commit comments

Comments
 (0)