Skip to content

Commit 1378480

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
More overflow checks in memory allocator. (#15581)
Summary: - Safer overflow math for high alignments - Only over-allocate up to alignment - 1 bytes when a higher-than-malloc alignment is requested - Move EXECUTORCH_TRACK_ALLOCATION to after a successful allocation and use the final size Differential Revision: D86228757
1 parent 1fcba99 commit 1378480

File tree

2 files changed

+43
-36
lines changed

2 files changed

+43
-36
lines changed

extension/memory_allocator/malloc_memory_allocator.h

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <cstddef>
1212
#include <cstdint>
13+
#include <cstdlib>
1314
#include <vector>
1415

1516
#include <executorch/runtime/core/memory_allocator.h>
@@ -45,8 +46,6 @@ class MallocMemoryAllocator : public executorch::runtime::MemoryAllocator {
4546
* memory alignment size.
4647
*/
4748
void* allocate(size_t size, size_t alignment = kDefaultAlignment) override {
48-
EXECUTORCH_TRACK_ALLOCATION(prof_id(), size);
49-
5049
if (!isPowerOf2(alignment)) {
5150
ET_LOG(Error, "Alignment %zu is not a power of 2", alignment);
5251
return nullptr;
@@ -56,30 +55,29 @@ class MallocMemoryAllocator : public executorch::runtime::MemoryAllocator {
5655
static constexpr size_t kMallocAlignment = alignof(std::max_align_t);
5756
if (alignment > kMallocAlignment) {
5857
// To get higher alignments, allocate extra and then align the returned
59-
// pointer. This will waste an extra `alignment` bytes every time, but
58+
// pointer. This will waste an extra `alignment - 1` bytes every time, but
6059
// this is the only portable way to get aligned memory from the heap.
61-
62-
// Check for overflow before adding alignment to size
63-
if (size > SIZE_MAX - alignment) {
64-
ET_LOG(
65-
Error, "Size %zu + alignment %zu would overflow", size, alignment);
60+
const size_t extra = alignment - 1;
61+
if (ET_UNLIKELY(extra > SIZE_MAX - size)) {
62+
ET_LOG(Error, "Malloc size overflow: size=%zu + extra=%zu", size, extra);
6663
return nullptr;
6764
}
68-
size += alignment;
65+
size += extra;
6966
}
7067
void* mem_ptr = std::malloc(size);
71-
if (mem_ptr == nullptr) {
72-
ET_LOG(Error, "Failed to allocate %zu bytes", size);
68+
if (!mem_ptr) {
69+
ET_LOG(Error, "Malloc failed to allocate %zu bytes", size);
7370
return nullptr;
7471
}
7572
mem_ptrs_.emplace_back(mem_ptr);
73+
EXECUTORCH_TRACK_ALLOCATION(prof_id(), size);
7674
return alignPointer(mem_ptrs_.back(), alignment);
7775
}
7876

7977
// Free up each hosted memory pointer. The memory was created via malloc.
8078
void reset() override {
8179
for (auto mem_ptr : mem_ptrs_) {
82-
free(mem_ptr);
80+
std::free(mem_ptr);
8381
}
8482
mem_ptrs_.clear();
8583
}

runtime/core/memory_allocator.h

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#pragma once
1010

1111
#include <stdio.h>
12-
#include <cinttypes>
1312
#include <cstdint>
1413

1514
#include <c10/util/safe_numerics.h>
@@ -59,9 +58,20 @@ class MemoryAllocator {
5958
*/
6059
MemoryAllocator(uint32_t size, uint8_t* base_address)
6160
: begin_(base_address),
62-
end_(base_address + size),
61+
end_(base_address ?
62+
(UINTPTR_MAX - reinterpret_cast<uintptr_t>(base_address) >= size ?
63+
base_address + size : nullptr) : nullptr),
6364
cur_(base_address),
64-
size_(size) {}
65+
size_(size) {
66+
ET_CHECK_MSG(base_address || size == 0, "Base address is null but size=%u", size);
67+
ET_CHECK_MSG(!base_address || size == 0 ||
68+
(UINTPTR_MAX - reinterpret_cast<uintptr_t>(base_address) >= size),
69+
"Address space overflow in allocator");
70+
}
71+
72+
MemoryAllocator(const MemoryAllocator&) = delete;
73+
MemoryAllocator& operator=(const MemoryAllocator&) = delete;
74+
virtual ~MemoryAllocator() = default;
6575

6676
/**
6777
* Allocates `size` bytes of memory.
@@ -74,6 +84,10 @@ class MemoryAllocator {
7484
* @retval nullptr Not enough memory, or `alignment` was not a power of 2.
7585
*/
7686
virtual void* allocate(size_t size, size_t alignment = kDefaultAlignment) {
87+
if (ET_UNLIKELY(!begin_ || !end_)) {
88+
ET_LOG(Error, "allocate() on zero-capacity allocator");
89+
return nullptr;
90+
}
7791
if (!isPowerOf2(alignment)) {
7892
ET_LOG(Error, "Alignment %zu is not a power of 2", alignment);
7993
return nullptr;
@@ -82,18 +96,17 @@ class MemoryAllocator {
8296
// The allocation will occupy [start, end), where the start is the next
8397
// position that's a multiple of alignment.
8498
uint8_t* start = alignPointer(cur_, alignment);
85-
uint8_t* end = start + size;
86-
87-
// If the end of this allocation exceeds the end of this allocator, print
88-
// error messages and return nullptr
89-
if (end > end_ || end < start) {
99+
size_t padding = static_cast<size_t>(start - cur_);
100+
size_t available = static_cast<size_t>(end_ - cur_);
101+
if (ET_UNLIKELY(padding > available || size > available || size > available - padding)) {
90102
ET_LOG(
91103
Error,
92104
"Memory allocation failed: %zuB requested (adjusted for alignment), %zuB available",
93-
static_cast<size_t>(end - cur_),
94-
static_cast<size_t>(end_ - cur_));
105+
padding + size,
106+
available);
95107
return nullptr;
96108
}
109+
uint8_t* end = start + size;
97110

98111
// Otherwise, record how many bytes were used, advance cur_ to the new end,
99112
// and then return start. Note that the number of bytes used is (end - cur_)
@@ -144,8 +157,9 @@ class MemoryAllocator {
144157
if (overflow) {
145158
ET_LOG(
146159
Error,
147-
"Failed to allocate list of type %zu: size * sizeof(T) overflowed",
148-
size);
160+
"Failed to allocate list: size(%zu) * sizeof(T)(%zu) overflowed",
161+
size,
162+
sizeof(T));
149163
return nullptr;
150164
}
151165
return static_cast<T*>(this->allocate(bytes_size, alignment));
@@ -171,8 +185,6 @@ class MemoryAllocator {
171185
prof_id_ = EXECUTORCH_TRACK_ALLOCATOR(name);
172186
}
173187

174-
virtual ~MemoryAllocator() {}
175-
176188
protected:
177189
/**
178190
* Returns the profiler ID for this allocator.
@@ -184,21 +196,18 @@ class MemoryAllocator {
184196
/**
185197
* Returns true if the value is an integer power of 2.
186198
*/
187-
static bool isPowerOf2(size_t value) {
188-
return value > 0 && (value & ~(value - 1)) == value;
199+
static constexpr bool isPowerOf2(size_t value) {
200+
return value && !(value & (value - 1));
189201
}
190202

191203
/**
192204
* Returns the next alignment for a given pointer.
193205
*/
194-
static uint8_t* alignPointer(void* ptr, size_t alignment) {
195-
intptr_t addr = reinterpret_cast<intptr_t>(ptr);
196-
if ((addr & (alignment - 1)) == 0) {
197-
// Already aligned.
198-
return reinterpret_cast<uint8_t*>(ptr);
199-
}
200-
addr = (addr | (alignment - 1)) + 1;
201-
return reinterpret_cast<uint8_t*>(addr);
206+
static inline uint8_t* alignPointer(void* ptr, size_t alignment) {
207+
uintptr_t address = reinterpret_cast<uintptr_t>(ptr);
208+
uintptr_t mask = static_cast<uintptr_t>(alignment - 1);
209+
address = (address + mask) & ~mask;
210+
return reinterpret_cast<uint8_t*>(address);
202211
}
203212

204213
private:

0 commit comments

Comments
 (0)