88
99#pragma once
1010
11- #include < stdio.h>
1211#include < cinttypes>
13- #include < cstdint>
1412
1513#include < c10/util/safe_numerics.h>
1614
1715#include < executorch/runtime/core/error.h>
1816#include < executorch/runtime/platform/assert.h>
19- #include < executorch/runtime/platform/compiler.h>
20- #include < executorch/runtime/platform/log.h>
2117#include < executorch/runtime/platform/profiler.h>
2218
2319namespace executorch {
@@ -59,9 +55,26 @@ class MemoryAllocator {
5955 */
6056 MemoryAllocator (uint32_t size, uint8_t * base_address)
6157 : begin_(base_address),
62- end_ (base_address + size),
58+ end_ (
59+ base_address
60+ ? (UINTPTR_MAX - reinterpret_cast <uintptr_t >(base_address) >=
61+ size
62+ ? base_address + size
63+ : nullptr)
64+ : nullptr),
6365 cur_(base_address),
64- size_(size) {}
66+ size_(size) {
67+ ET_CHECK_MSG (
68+ base_address || size == 0 ,
69+ " Base address is null but size=%" PRIu32,
70+ size);
71+ ET_CHECK_MSG (
72+ !base_address || size == 0 ||
73+ (UINTPTR_MAX - reinterpret_cast <uintptr_t >(base_address) >= size),
74+ " Address space overflow in allocator" );
75+ }
76+
77+ virtual ~MemoryAllocator () = default ;
6578
6679 /* *
6780 * Allocates `size` bytes of memory.
@@ -74,6 +87,10 @@ class MemoryAllocator {
7487 * @retval nullptr Not enough memory, or `alignment` was not a power of 2.
7588 */
7689 virtual void * allocate (size_t size, size_t alignment = kDefaultAlignment ) {
90+ if ET_UNLIKELY (!begin_ || !end_) {
91+ ET_LOG (Error, " allocate() on zero-capacity allocator" );
92+ return nullptr ;
93+ }
7794 if (!isPowerOf2 (alignment)) {
7895 ET_LOG (Error, " Alignment %zu is not a power of 2" , alignment);
7996 return nullptr ;
@@ -82,18 +99,17 @@ class MemoryAllocator {
8299 // The allocation will occupy [start, end), where the start is the next
83100 // position that's a multiple of alignment.
84101 uint8_t * start = alignPointer (cur_, alignment);
85- uint8_t * end = start + size;
86-
87- // If the end of this allocation exceeds the end of this allocator, print
88- // error messages and return nullptr
89- if (end > end_ || end < start) {
102+ size_t padding = static_cast <size_t >(start - cur_);
103+ size_t available = static_cast <size_t >(end_ - cur_);
104+ if ET_UNLIKELY (padding > available || size > available - padding) {
90105 ET_LOG (
91106 Error,
92107 " Memory allocation failed: %zuB requested (adjusted for alignment), %zuB available" ,
93- static_cast < size_t >(end - cur_) ,
94- static_cast < size_t >(end_ - cur_) );
108+ padding + size ,
109+ available );
95110 return nullptr ;
96111 }
112+ uint8_t * end = start + size;
97113
98114 // Otherwise, record how many bytes were used, advance cur_ to the new end,
99115 // and then return start. Note that the number of bytes used is (end - cur_)
@@ -144,8 +160,9 @@ class MemoryAllocator {
144160 if (overflow) {
145161 ET_LOG (
146162 Error,
147- " Failed to allocate list of type %zu: size * sizeof(T) overflowed" ,
148- size);
163+ " Failed to allocate list: size(%zu) * sizeof(T)(%zu) overflowed" ,
164+ size,
165+ sizeof (T));
149166 return nullptr ;
150167 }
151168 return static_cast <T*>(this ->allocate (bytes_size, alignment));
@@ -171,8 +188,6 @@ class MemoryAllocator {
171188 prof_id_ = EXECUTORCH_TRACK_ALLOCATOR (name);
172189 }
173190
174- virtual ~MemoryAllocator () {}
175-
176191 protected:
177192 /* *
178193 * Returns the profiler ID for this allocator.
@@ -184,21 +199,18 @@ class MemoryAllocator {
184199 /* *
185200 * Returns true if the value is an integer power of 2.
186201 */
187- static bool isPowerOf2 (size_t value) {
188- return value > 0 && (value & ~ (value - 1 )) == value ;
202+ static constexpr bool isPowerOf2 (size_t value) {
203+ return value && ! (value & (value - 1 ));
189204 }
190205
191206 /* *
192207 * Returns the next alignment for a given pointer.
193208 */
194- static uint8_t * alignPointer (void * ptr, size_t alignment) {
195- intptr_t addr = reinterpret_cast <intptr_t >(ptr);
196- if ((addr & (alignment - 1 )) == 0 ) {
197- // Already aligned.
198- return reinterpret_cast <uint8_t *>(ptr);
199- }
200- addr = (addr | (alignment - 1 )) + 1 ;
201- return reinterpret_cast <uint8_t *>(addr);
209+ static inline uint8_t * alignPointer (void * ptr, size_t alignment) {
210+ uintptr_t address = reinterpret_cast <uintptr_t >(ptr);
211+ uintptr_t mask = static_cast <uintptr_t >(alignment - 1 );
212+ address = (address + mask) & ~mask;
213+ return reinterpret_cast <uint8_t *>(address);
202214 }
203215
204216 private:
0 commit comments