88
99#pragma once
1010
11- #include < stdio.h>
1211#include < cinttypes>
13- #include < cstdint>
1412
1513#include < c10/util/safe_numerics.h>
1614
1715#include < executorch/runtime/core/error.h>
1816#include < executorch/runtime/platform/assert.h>
19- #include < executorch/runtime/platform/compiler.h>
20- #include < executorch/runtime/platform/log.h>
2117#include < executorch/runtime/platform/profiler.h>
2218
2319namespace executorch {
@@ -59,9 +55,30 @@ class MemoryAllocator {
5955 */
6056 MemoryAllocator (uint32_t size, uint8_t * base_address)
6157 : begin_(base_address),
62- end_ (base_address + size),
58+ end_ (
59+ base_address
60+ ? (UINTPTR_MAX - reinterpret_cast <uintptr_t >(base_address) >=
61+ size
62+ ? base_address + size
63+ : nullptr)
64+ : nullptr),
6365 cur_(base_address),
64- size_(size) {}
66+ size_(size) {
67+ ET_CHECK_MSG (
68+ base_address || size == 0 ,
69+ " Base address is null but size=%" PRIu32,
70+ size);
71+ ET_CHECK_MSG (
72+ !base_address || size == 0 ||
73+ (UINTPTR_MAX - reinterpret_cast <uintptr_t >(base_address) >= size),
74+ " Address space overflow in allocator" );
75+ }
76+
77+ MemoryAllocator (const MemoryAllocator&) = delete;
78+ MemoryAllocator& operator =(const MemoryAllocator&) = delete ;
79+ MemoryAllocator (MemoryAllocator&&) = delete;
80+ MemoryAllocator& operator =(MemoryAllocator&&) = delete ;
81+ virtual ~MemoryAllocator () = default ;
6582
6683 /* *
6784 * Allocates `size` bytes of memory.
@@ -74,6 +91,10 @@ class MemoryAllocator {
7491 * @retval nullptr Not enough memory, or `alignment` was not a power of 2.
7592 */
7693 virtual void * allocate (size_t size, size_t alignment = kDefaultAlignment ) {
94+ if ET_UNLIKELY (!begin_ || !end_) {
95+ ET_LOG (Error, " allocate() on zero-capacity allocator" );
96+ return nullptr ;
97+ }
7798 if (!isPowerOf2 (alignment)) {
7899 ET_LOG (Error, " Alignment %zu is not a power of 2" , alignment);
79100 return nullptr ;
@@ -82,18 +103,17 @@ class MemoryAllocator {
82103 // The allocation will occupy [start, end), where the start is the next
83104 // position that's a multiple of alignment.
84105 uint8_t * start = alignPointer (cur_, alignment);
85- uint8_t * end = start + size;
86-
87- // If the end of this allocation exceeds the end of this allocator, print
88- // error messages and return nullptr
89- if (end > end_ || end < start) {
106+ size_t padding = static_cast <size_t >(start - cur_);
107+ size_t available = static_cast <size_t >(end_ - cur_);
108+ if ET_UNLIKELY (padding > available || size > available - padding) {
90109 ET_LOG (
91110 Error,
92111 " Memory allocation failed: %zuB requested (adjusted for alignment), %zuB available" ,
93- static_cast < size_t >(end - cur_) ,
94- static_cast < size_t >(end_ - cur_) );
112+ padding + size ,
113+ available );
95114 return nullptr ;
96115 }
116+ uint8_t * end = start + size;
97117
98118 // Otherwise, record how many bytes were used, advance cur_ to the new end,
99119 // and then return start. Note that the number of bytes used is (end - cur_)
@@ -144,8 +164,9 @@ class MemoryAllocator {
144164 if (overflow) {
145165 ET_LOG (
146166 Error,
147- " Failed to allocate list of type %zu: size * sizeof(T) overflowed" ,
148- size);
167+ " Failed to allocate list: size(%zu) * sizeof(T)(%zu) overflowed" ,
168+ size,
169+ sizeof (T));
149170 return nullptr ;
150171 }
151172 return static_cast <T*>(this ->allocate (bytes_size, alignment));
@@ -171,8 +192,6 @@ class MemoryAllocator {
171192 prof_id_ = EXECUTORCH_TRACK_ALLOCATOR (name);
172193 }
173194
174- virtual ~MemoryAllocator () {}
175-
176195 protected:
177196 /* *
178197 * Returns the profiler ID for this allocator.
@@ -184,21 +203,18 @@ class MemoryAllocator {
184203 /* *
185204 * Returns true if the value is an integer power of 2.
186205 */
187- static bool isPowerOf2 (size_t value) {
188- return value > 0 && (value & ~ (value - 1 )) == value ;
206+ static constexpr bool isPowerOf2 (size_t value) {
207+ return value && ! (value & (value - 1 ));
189208 }
190209
191210 /* *
192211 * Returns the next alignment for a given pointer.
193212 */
194- static uint8_t * alignPointer (void * ptr, size_t alignment) {
195- intptr_t addr = reinterpret_cast <intptr_t >(ptr);
196- if ((addr & (alignment - 1 )) == 0 ) {
197- // Already aligned.
198- return reinterpret_cast <uint8_t *>(ptr);
199- }
200- addr = (addr | (alignment - 1 )) + 1 ;
201- return reinterpret_cast <uint8_t *>(addr);
213+ static inline uint8_t * alignPointer (void * ptr, size_t alignment) {
214+ uintptr_t address = reinterpret_cast <uintptr_t >(ptr);
215+ uintptr_t mask = static_cast <uintptr_t >(alignment - 1 );
216+ address = (address + mask) & ~mask;
217+ return reinterpret_cast <uint8_t *>(address);
202218 }
203219
204220 private:
0 commit comments