99#pragma once
1010
1111#include < stdio.h>
12- #include < cinttypes>
1312#include < cstdint>
1413
1514#include < c10/util/safe_numerics.h>
@@ -59,9 +58,20 @@ class MemoryAllocator {
5958 */
6059 MemoryAllocator (uint32_t size, uint8_t * base_address)
6160 : begin_(base_address),
62- end_ (base_address + size),
61+ end_ (base_address ?
62+ (UINTPTR_MAX - reinterpret_cast <uintptr_t >(base_address) >= size ?
63+ base_address + size : nullptr) : nullptr),
6364 cur_(base_address),
64- size_(size) {}
65+ size_(size) {
66+ ET_CHECK_MSG (base_address || size == 0 , " Base address is null but size=%u" , size);
67+ ET_CHECK_MSG (!base_address || size == 0 ||
68+ (UINTPTR_MAX - reinterpret_cast <uintptr_t >(base_address) >= size),
69+ " Address space overflow in allocator" );
70+ }
71+
72+ MemoryAllocator (const MemoryAllocator&) = delete;
73+ MemoryAllocator& operator =(const MemoryAllocator&) = delete ;
74+ virtual ~MemoryAllocator () = default ;
6575
6676 /* *
6777 * Allocates `size` bytes of memory.
@@ -74,6 +84,10 @@ class MemoryAllocator {
7484 * @retval nullptr Not enough memory, or `alignment` was not a power of 2.
7585 */
7686 virtual void * allocate (size_t size, size_t alignment = kDefaultAlignment ) {
87+ if (ET_UNLIKELY (!begin_ || !end_)) {
88+ ET_LOG (Error, " allocate() on zero-capacity allocator" );
89+ return nullptr ;
90+ }
7791 if (!isPowerOf2 (alignment)) {
7892 ET_LOG (Error, " Alignment %zu is not a power of 2" , alignment);
7993 return nullptr ;
@@ -82,18 +96,17 @@ class MemoryAllocator {
8296 // The allocation will occupy [start, end), where the start is the next
8397 // position that's a multiple of alignment.
8498 uint8_t * start = alignPointer (cur_, alignment);
85- uint8_t * end = start + size;
86-
87- // If the end of this allocation exceeds the end of this allocator, print
88- // error messages and return nullptr
89- if (end > end_ || end < start) {
99+ size_t padding = static_cast <size_t >(start - cur_);
100+ size_t available = static_cast <size_t >(end_ - cur_);
101+ if (ET_UNLIKELY (padding > available || size > available - padding)) {
90102 ET_LOG (
91103 Error,
92104 " Memory allocation failed: %zuB requested (adjusted for alignment), %zuB available" ,
93- static_cast < size_t >(end - cur_) ,
94- static_cast < size_t >(end_ - cur_) );
105+ padding + size ,
106+ available );
95107 return nullptr ;
96108 }
109+ uint8_t * end = start + size;
97110
98111 // Otherwise, record how many bytes were used, advance cur_ to the new end,
99112 // and then return start. Note that the number of bytes used is (end - cur_)
@@ -144,8 +157,9 @@ class MemoryAllocator {
144157 if (overflow) {
145158 ET_LOG (
146159 Error,
147- " Failed to allocate list of type %zu: size * sizeof(T) overflowed" ,
148- size);
160+ " Failed to allocate list: size(%zu) * sizeof(T)(%zu) overflowed" ,
161+ size,
162+ sizeof (T));
149163 return nullptr ;
150164 }
151165 return static_cast <T*>(this ->allocate (bytes_size, alignment));
@@ -171,8 +185,6 @@ class MemoryAllocator {
171185 prof_id_ = EXECUTORCH_TRACK_ALLOCATOR (name);
172186 }
173187
174- virtual ~MemoryAllocator () {}
175-
176188 protected:
177189 /* *
178190 * Returns the profiler ID for this allocator.
@@ -184,21 +196,18 @@ class MemoryAllocator {
184196 /* *
185197 * Returns true if the value is an integer power of 2.
186198 */
187- static bool isPowerOf2 (size_t value) {
188- return value > 0 && (value & ~ (value - 1 )) == value ;
199+ static constexpr bool isPowerOf2 (size_t value) {
200+ return value && ! (value & (value - 1 ));
189201 }
190202
191203 /* *
192204 * Returns the next alignment for a given pointer.
193205 */
194- static uint8_t * alignPointer (void * ptr, size_t alignment) {
195- intptr_t addr = reinterpret_cast <intptr_t >(ptr);
196- if ((addr & (alignment - 1 )) == 0 ) {
197- // Already aligned.
198- return reinterpret_cast <uint8_t *>(ptr);
199- }
200- addr = (addr | (alignment - 1 )) + 1 ;
201- return reinterpret_cast <uint8_t *>(addr);
206+ static inline uint8_t * alignPointer (void * ptr, size_t alignment) {
207+ uintptr_t address = reinterpret_cast <uintptr_t >(ptr);
208+ uintptr_t mask = static_cast <uintptr_t >(alignment - 1 );
209+ address = (address + mask) & ~mask;
210+ return reinterpret_cast <uint8_t *>(address);
202211 }
203212
204213 private:
0 commit comments