forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTHCAllocator.c
42 lines (28 loc) · 1.01 KB
/
THCAllocator.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include "THCAllocator.h"
/*
 * malloc hook of the CUDA host allocator.
 *
 * Obtains `size` bytes of host memory from cudaMallocHost (page-locked memory
 * according to the CUDA runtime documentation).
 *
 *   ctx  - allocator context; not used by this allocator.
 *   size - number of bytes requested. A negative value is reported through
 *          THError. Zero yields NULL without calling into CUDA.
 *
 * Returns the new buffer; release it with THCudaHostAllocator_free.
 * CUDA failures are reported through THCudaCheck.
 */
static void *THCudaHostAllocator_alloc(void* ctx, long size) {
  void *mem = NULL;
  (void)ctx;

  if (size < 0) {
    THError("Invalid memory size: %ld", size);
  }
  if (size == 0) {
    return NULL;
  }

  THCudaCheck(cudaMallocHost(&mem, size));
  return mem;
}
static void THCudaHostAllocator_free(void* ctx, void* ptr) {
if (!ptr) return;
THCudaCheck(cudaFreeHost(ptr));
}
/*
 * realloc hook of the CUDA host allocator.
 *
 * Returns a fresh buffer of `size` bytes from cudaMallocHost and releases the
 * old buffer `ptr` (which may be NULL).
 *
 * NOTE: the previous contents are NOT copied into the new buffer. This
 * allocator does not track the old allocation's size, so there is no safe
 * length to copy. Callers that need the data must copy it themselves before
 * calling realloc.
 *
 *   ctx  - allocator context; forwarded to THCudaHostAllocator_free.
 *   ptr  - buffer to release; ownership passes to this function on success.
 *   size - new size in bytes. A negative value is reported through THError.
 *          Zero releases `ptr` and returns NULL.
 *
 * The new buffer is allocated BEFORE the old one is released. If
 * cudaMallocHost fails, THCudaCheck reports the error while `ptr` is still
 * valid and owned by the caller, rather than leaving the caller with a
 * pointer to already-freed memory. The trade-off is that both buffers are
 * briefly alive at once.
 * (Assumes THCudaCheck/THError do not return on failure — confirm.)
 */
static void *THCudaHostAllocator_realloc(void* ctx, void* ptr, long size) {
  void *fresh = NULL;

  if (size < 0) {
    THError("Invalid memory size: %ld", size);
  }

  if (size > 0) {
    THCudaCheck(cudaMallocHost(&fresh, size));
  }

  /* Only release the old buffer once the replacement is secured. */
  THCudaHostAllocator_free(ctx, ptr);
  return fresh;
}
/*
 * Creates the CUDA host allocator and stores it in state->cudaHostAllocator.
 *
 * The THAllocator function table is heap-allocated with malloc and is freed
 * by THCAllocator_shutdown. Its malloc/realloc/free entries point at the
 * THCudaHostAllocator_* hooks defined in this file.
 *
 * If the table itself cannot be allocated, the error is reported through
 * THError instead of dereferencing a NULL pointer.
 */
void THCAllocator_init(THCState *state) {
  THAllocator *allocator = malloc(sizeof *allocator);
  if (allocator == NULL) {
    state->cudaHostAllocator = NULL;
    THError("THCAllocator_init: out of memory allocating the CUDA host allocator");
    return; /* only reached if the installed error handler returns */
  }

  allocator->malloc = &THCudaHostAllocator_alloc;
  allocator->realloc = &THCudaHostAllocator_realloc;
  allocator->free = &THCudaHostAllocator_free;

  state->cudaHostAllocator = allocator;
}
/*
 * Releases the allocator table stored in state->cudaHostAllocator.
 *
 * The field is reset to NULL afterwards, so the state never holds a dangling
 * pointer. A repeated shutdown then becomes free(NULL), a no-op, instead of a
 * double free.
 */
void THCAllocator_shutdown(THCState *state) {
  free(state->cudaHostAllocator);
  state->cudaHostAllocator = NULL;
}