Skip to content

Commit

Permalink
Add support for setting the default allocator and deallocator functio…
Browse files Browse the repository at this point in the history
…ns in Halide::Runtime::Buffer. (#8132)
  • Loading branch information
mcourteaux authored Mar 5, 2024
1 parent 7636c44 commit 8b3312c
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 4 deletions.
26 changes: 22 additions & 4 deletions src/runtime/HalideBuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ struct AllInts<float, Args...> : std::false_type {};
template<typename... Args>
struct AllInts<double, Args...> : std::false_type {};

// A helper to detect if there are any zeros in a container
namespace Internal {
// A helper to detect if there are any zeros in a container
template<typename Container>
bool any_zero(const Container &c) {
for (int i : c) {
Expand All @@ -153,6 +153,11 @@ bool any_zero(const Container &c) {
}
return false;
}

struct DefaultAllocatorFns {
static inline void *(*default_allocate_fn)(size_t) = nullptr;
static inline void (*default_deallocate_fn)(void *) = nullptr;
};
} // namespace Internal

/** A struct acting as a header for allocations owned by the Buffer
Expand Down Expand Up @@ -711,6 +716,13 @@ class Buffer {
}

public:
static void set_default_allocate_fn(void *(*allocate_fn)(size_t)) {
Internal::DefaultAllocatorFns::default_allocate_fn = allocate_fn;
}
static void set_default_deallocate_fn(void (*deallocate_fn)(void *)) {
Internal::DefaultAllocatorFns::default_deallocate_fn = deallocate_fn;
}

/** Determine if a Buffer<T, Dims, InClassDimStorage> can be constructed from some other Buffer type.
* If this can be determined at compile time, fail with a static assert; otherwise
* return a boolean based on runtime typing. */
Expand Down Expand Up @@ -893,7 +905,7 @@ class Buffer {

#if HALIDE_RUNTIME_BUFFER_USE_ALIGNED_ALLOC
// Only use aligned_alloc() if no custom allocators are specified.
if (!allocate_fn && !deallocate_fn) {
if (!allocate_fn && !deallocate_fn && !Internal::DefaultAllocatorFns::default_allocate_fn && !Internal::DefaultAllocatorFns::default_deallocate_fn) {
// As a practical matter, sizeof(AllocationHeader) is going to be no more than 16 bytes
// on any supported platform, so we will just overallocate by 'alignment'
// so that the user storage also starts at an aligned point. This is a bit
Expand All @@ -908,10 +920,16 @@ class Buffer {
// else fall thru
#endif
if (!allocate_fn) {
allocate_fn = malloc;
allocate_fn = Internal::DefaultAllocatorFns::default_allocate_fn;
if (!allocate_fn) {
allocate_fn = malloc;
}
}
if (!deallocate_fn) {
deallocate_fn = free;
deallocate_fn = Internal::DefaultAllocatorFns::default_deallocate_fn;
if (!deallocate_fn) {
deallocate_fn = free;
}
}

static_assert(sizeof(AllocationHeader) <= alignment);
Expand Down
33 changes: 33 additions & 0 deletions test/correctness/halide_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,22 @@

using namespace Halide::Runtime;

static void *my_malloced_addr = nullptr;
static int my_malloc_count = 0;
static void *my_freed_addr = nullptr;
static int my_free_count = 0;
void *my_malloc(size_t size) {
void *ptr = malloc(size);
my_malloced_addr = ptr;
my_malloc_count++;
return ptr;
}
void my_free(void *ptr) {
my_freed_addr = ptr;
my_free_count++;
free(ptr);
}

template<typename T1, typename T2>
void check_equal_shape(const Buffer<T1> &a, const Buffer<T2> &b) {
if (a.dimensions() != b.dimensions()) abort();
Expand Down Expand Up @@ -515,6 +531,23 @@ int main(int argc, char **argv) {
assert(b.dim(3).stride() == b2.dim(3).stride());
}

{
// Test setting default allocate and deallocate functions.
Buffer<>::set_default_allocate_fn(my_malloc);
Buffer<>::set_default_deallocate_fn(my_free);

assert(my_malloc_count == 0);
assert(my_free_count == 0);
auto b = Buffer<uint8_t, 2>(5, 4).fill(1);
assert(my_malloced_addr != nullptr && my_malloced_addr < b.data());
assert(my_malloc_count == 1);
assert(my_free_count == 0);
b.deallocate();
assert(my_malloc_count == 1);
assert(my_free_count == 1);
assert(my_malloced_addr == my_freed_addr);
}

printf("Success!\n");
return 0;
}

0 comments on commit 8b3312c

Please sign in to comment.