Skip to content

Commit f3e815e

Browse files
authored
Merge branch 'master' into fix_tf_invalid_output_delete
2 parents aa0024a + cdab22e commit f3e815e

File tree

15 files changed

+322
-69
lines changed

15 files changed

+322
-69
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ ENDIF()
160160
IF(BUILD_ORT)
161161
FIND_LIBRARY(ORT_LIBRARIES NAMES onnxruntime
162162
PATHS ${depsAbs}/onnxruntime/lib)
163+
ADD_SUBDIRECTORY(src/backends/onnx_allocator)
163164
MESSAGE(STATUS "Found ONNXRuntime Libraries: \"${ORT_LIBRARIES}\")")
164165
IF (NOT ORT_LIBRARIES)
165166
MESSAGE(FATAL_ERROR "Could not find ONNXRuntime")
@@ -293,6 +294,7 @@ ENDIF()
293294

294295
IF(BUILD_ORT)
295296
ADD_LIBRARY(redisai_onnxruntime SHARED $<TARGET_OBJECTS:redisai_onnxruntime_obj>)
297+
TARGET_LINK_LIBRARIES(redisai_onnxruntime onnx_allocator ${ORT_LIBRARIES})
296298
TARGET_LINK_LIBRARIES(redisai_onnxruntime ${ORT_LIBRARIES})
297299
SET_TARGET_PROPERTIES(redisai_onnxruntime PROPERTIES PREFIX "")
298300
SET_TARGET_PROPERTIES(redisai_onnxruntime PROPERTIES SUFFIX ".so")

src/backends/backends.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ int RAI_ExportFunc(const char *func_name, void **targetFuncPtr) {
5050
*targetFuncPtr = Config_GetModelExecutionTimeout;
5151
} else if (strcmp("GetThreadsCount", func_name) == 0) {
5252
*targetFuncPtr = BGWorker_GetThreadsCount;
53+
} else if (strcmp("GetBackendMemoryLimit", func_name) == 0) {
54+
*targetFuncPtr = Config_GetBackendMemoryLimit;
5355

5456
// Export RedisAI low level API functions.
5557
} else if (strcmp("RedisAI_InitError", func_name) == 0) {

src/backends/backends_api.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,20 @@ BACKENDS_API uintptr_t (*RedisAI_GetThreadsCount)(void);
3737
BACKENDS_API long long (*RedisAI_GetNumThreadsPerQueue)(void);
3838

3939
/**
40-
* @return The maximal number of milliseconds that a model run session should run
40+
* @return The maximum number of milliseconds that a model run session should run
4141
* before it is terminated forcefully (load time config).
42-
* Currently supported only fo onnxruntime backend.
42+
* Currently supported only for onnxruntime backend.
4343
*/
4444
BACKENDS_API long long (*RedisAI_GetModelExecutionTimeout)(void);
4545

46+
/**
47+
* @return The maximum number of memory (in MB) that a backend can consume
48+
* for creating and running inference sessions. When memory limit is exceeded, operation
49+
* is not permitted and an error is returned.
50+
* Currently supported only for onnxruntime backend.
51+
*/
52+
BACKENDS_API long long (*RedisAI_GetMemoryLimit)(void);
53+
4654
/**
4755
* The following functions are part of RedisAI low level API (the full low level
4856
* API is defined in redisai.h). For every function below named "RedisAI_X", its
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
add_library(onnx_allocator STATIC onnx_allocator.cpp)
2+
target_link_libraries(onnx_allocator "${ONNX_LIBRARIES}")
3+
set_property(TARGET onnx_allocator PROPERTY CXX_STANDARD 14)
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
#include "onnx_allocator.h"
2+
#include "../onnxruntime.h"
3+
#include "onnxruntime_cxx_api.h"
4+
#include <atomic>
5+
6+
struct RAIOrtAllocator : OrtAllocator {
7+
RAIOrtAllocator();
8+
~RAIOrtAllocator();
9+
RAIOrtAllocator(const RAIOrtAllocator&) = delete;
10+
RAIOrtAllocator& operator=(const RAIOrtAllocator&) = delete;
11+
12+
void* Alloc(size_t size);
13+
void Free(void* p);
14+
const OrtMemoryInfo* Info() const;
15+
unsigned long long NumAllocatorAccess() const;
16+
unsigned long long MemoryInUse() const;
17+
void SetMemoryLimit(unsigned long long max_memory);
18+
static RAIOrtAllocator *GetInstance();
19+
20+
private:
21+
std::atomic<unsigned long long> memory_inuse{0};
22+
std::atomic<unsigned long long> num_allocator_access{0};
23+
unsigned long long memory_limit = 0;
24+
OrtMemoryInfo* cpu_memory_info;
25+
static RAIOrtAllocator* allocator_instance;
26+
};
27+
28+
RAIOrtAllocator* RAIOrtAllocator::allocator_instance = nullptr;
29+
30+
RAIOrtAllocator::RAIOrtAllocator() {
31+
OrtAllocator::version = ORT_API_VERSION;
32+
OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast<RAIOrtAllocator*>(this_)->Alloc(size); };
33+
OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast<RAIOrtAllocator*>(this_)->Free(p); };
34+
OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast<const RAIOrtAllocator*>(this_)->Info(); };
35+
Ort::ThrowOnError(Ort::GetApi().CreateCpuMemoryInfo(OrtDeviceAllocator, OrtMemTypeDefault, &cpu_memory_info));
36+
RAIOrtAllocator::allocator_instance = this;
37+
}
38+
39+
RAIOrtAllocator::~RAIOrtAllocator() {
40+
Ort::GetApi().ReleaseMemoryInfo(cpu_memory_info);
41+
}
42+
43+
void* RAIOrtAllocator::Alloc(size_t size) {
44+
// Allocate an additional 63 bytes to ensure that we can return an address which is
45+
// 64-byte aligned, and an additional space in the size of a pointer to store
46+
// the address that RedisModule_Alloc returns.
47+
int offset = 63 + sizeof(void *);
48+
void *allocated_address = (void *)RedisModule_Alloc(size + offset);
49+
size_t allocated_size = RedisModule_MallocSize(allocated_address);
50+
// Update the total number of bytes that onnx is using and the number of accesses
51+
// that onnx made to the allocator.
52+
size_t cur_memory = memory_inuse.load();
53+
if (memory_limit && cur_memory + allocated_size > memory_limit) {
54+
RedisModule_Free(allocated_address);
55+
throw Ort::Exception("Onnxruntime memory limit exceeded, memory allocation failed.", ORT_RUNTIME_EXCEPTION);
56+
}
57+
memory_inuse.fetch_add(allocated_size);
58+
num_allocator_access.fetch_add(1);
59+
// This operation guarantees that "aligned_address" is the closest 64-aligned address to ("allocated_address"+size_t).
60+
void **aligned_address = (void **)(((size_t)(allocated_address) + offset) & (~63));
61+
// This stores the address "allocated_address" right before "aligned_address" (so we can retrieve it when we free).
62+
aligned_address[-1] = allocated_address;
63+
return aligned_address;
64+
}
65+
66+
void RAIOrtAllocator::Free(void* p) {
67+
if (p == nullptr) {
68+
return;
69+
}
70+
// Retrieve the address that we originally received from RedisModule_Alloc
71+
// (this is the address that we need to sent to RedisModule_Free).
72+
void *allocated_address = ((void **)p)[-1];
73+
size_t allocated_size = RedisModule_MallocSize(allocated_address);
74+
// Update the total number of bytes that onnx is using and the number of accesses
75+
// that onnx made to the allocator.
76+
memory_inuse.fetch_sub(allocated_size);
77+
num_allocator_access.fetch_add(1);
78+
RedisModule_Free(allocated_address);
79+
}
80+
81+
const OrtMemoryInfo* RAIOrtAllocator::Info() const {
82+
return cpu_memory_info;
83+
}
84+
85+
unsigned long long RAIOrtAllocator::NumAllocatorAccess() const {
86+
return num_allocator_access.load();
87+
}
88+
89+
unsigned long long RAIOrtAllocator::MemoryInUse() const {
90+
return memory_inuse.load();
91+
}
92+
93+
void RAIOrtAllocator::SetMemoryLimit(unsigned long long max_memory) {
94+
// max_memory is given in MB
95+
memory_limit = 1000000*max_memory;
96+
}
97+
98+
RAIOrtAllocator *RAIOrtAllocator::GetInstance() {
99+
return RAIOrtAllocator::allocator_instance;
100+
}
101+
102+
OrtAllocator *CreateCustomAllocator(unsigned long long max_memory) {
103+
auto *allocator = new RAIOrtAllocator();
104+
allocator->SetMemoryLimit(max_memory);
105+
return allocator;
106+
}
107+
108+
unsigned long long RAI_GetMemoryInfoORT() {
109+
return RAIOrtAllocator::GetInstance()->MemoryInUse();
110+
}
111+
112+
unsigned long long RAI_GetMemoryAccessORT() {
113+
return RAIOrtAllocator::GetInstance()->NumAllocatorAccess();
114+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#pragma once
2+
3+
#include "onnxruntime_c_api.h"
4+
5+
#ifdef __cplusplus
6+
extern "C" {
7+
#endif
8+
9+
OrtAllocator *CreateCustomAllocator(unsigned long long max_memory);
10+
11+
unsigned long long RAI_GetMemoryInfoORT();
12+
13+
unsigned long long RAI_GetMemoryAccessORT();
14+
15+
#ifdef __cplusplus
16+
}
17+
#endif

src/backends/onnxruntime.c

Lines changed: 5 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "util/arr.h"
77
#include "backends/onnxruntime.h"
88
#include "redis_ai_objects/tensor.h"
9+
#include "onnx_allocator/onnx_allocator.h"
910

1011
#include "onnxruntime_c_api.h"
1112
#include "backends_api.h"
@@ -21,63 +22,7 @@ OrtEnv *env = NULL;
2122
// For model that run on GPU, onnx will not use the custom allocator (redis allocator), but
2223
// the onnx allocator for GPU. But for the auxiliary allocations of the input and output names,
2324
// we will use the custom global allocator for models that run on GPU as well.
24-
OrtMemoryInfo *mem_info = NULL;
2525
OrtAllocator *global_allocator = NULL;
26-
unsigned long long OnnxMemory = 0;
27-
unsigned long long OnnxMemoryAccessCounter = 0;
28-
29-
const OrtMemoryInfo *AllocatorInfo(const OrtAllocator *allocator) {
30-
(void)allocator;
31-
const OrtApi *ort = OrtGetApiBase()->GetApi(1);
32-
if (mem_info != NULL) {
33-
return mem_info;
34-
}
35-
if (ort->CreateCpuMemoryInfo(OrtDeviceAllocator, OrtMemTypeDefault, &mem_info) != NULL) {
36-
return NULL;
37-
}
38-
return mem_info;
39-
}
40-
41-
// Allocate address with 64-byte alignment to cope with onnx optimizations.
42-
void *AllocatorAlloc(OrtAllocator *ptr, size_t size) {
43-
44-
(void)ptr;
45-
// Allocate an additional 63 bytes to ensure that we can return an address which is
46-
// 64-byte aligned, and an additional space in the size of a pointer to store
47-
// the address that RedisModule_Alloc returns.
48-
int offset = 63 + sizeof(void *);
49-
void *allocated_address = (void *)RedisModule_Alloc(size + offset);
50-
size_t allocated_size = RedisModule_MallocSize(allocated_address);
51-
// Update the total number of bytes that onnx is using and the number of accesses
52-
// that onnx made to the allocator.
53-
atomic_fetch_add(&OnnxMemory, allocated_size);
54-
atomic_fetch_add(&OnnxMemoryAccessCounter, 1);
55-
// This operation guarantees that p2 is the closest 64-aligned address to (p1+size_t).
56-
void **aligned_address = (void **)(((size_t)(allocated_address) + offset) & (~63));
57-
// This stores the address p1 right before p2 (so we can retrieve it when we free).
58-
aligned_address[-1] = allocated_address;
59-
return aligned_address;
60-
}
61-
62-
void AllocatorFree(OrtAllocator *ptr, void *aligned_address) {
63-
(void)ptr;
64-
if (aligned_address == NULL) {
65-
return;
66-
}
67-
// Retrieve the address that we originally received from RedisModule_Alloc
68-
// (this is the address that we need to sent to RedisModule_Free).
69-
void *allocated_address = ((void **)aligned_address)[-1];
70-
size_t allocated_size = RedisModule_MallocSize(allocated_address);
71-
// Update the total number of bytes that onnx is using and the number of accesses
72-
// that onnx made to the allocator.
73-
atomic_fetch_sub(&OnnxMemory, allocated_size);
74-
atomic_fetch_add(&OnnxMemoryAccessCounter, 1);
75-
return RedisModule_Free(allocated_address);
76-
}
77-
78-
unsigned long long RAI_GetMemoryInfoORT() { return OnnxMemory; }
79-
80-
unsigned long long RAI_GetMemoryAccessORT() { return OnnxMemoryAccessCounter; }
8126

8227
int RAI_InitBackendORT(int (*get_api_fn)(const char *, void **)) {
8328
// Export redis callbacks.
@@ -95,6 +40,7 @@ int RAI_InitBackendORT(int (*get_api_fn)(const char *, void **)) {
9540
get_api_fn("GetThreadId", ((void **)&RedisAI_GetThreadId));
9641
get_api_fn("GetNumThreadsPerQueue", ((void **)&RedisAI_GetNumThreadsPerQueue));
9742
get_api_fn("GetModelExecutionTimeout", ((void **)&RedisAI_GetModelExecutionTimeout));
43+
get_api_fn("GetBackendMemoryLimit", ((void **)&RedisAI_GetMemoryLimit));
9844
get_api_fn("GetThreadsCount", ((void **)&RedisAI_GetThreadsCount));
9945

10046
// Create a global array of onnx runSessions, with an entry for every working thread.
@@ -389,8 +335,9 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
389335
// allocating buffers when creating and running models that run on CPU, and for allocations of
390336
// models inputs and outputs names (for both models that run on CPU and GPU)
391337
if (env == NULL) {
392-
ONNX_VALIDATE_STATUS(ort->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env))
393-
ONNX_VALIDATE_STATUS(ort->GetAllocatorWithDefaultOptions(&global_allocator));
338+
ONNX_VALIDATE_STATUS(ort->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "RedisAI", &env))
339+
global_allocator = CreateCustomAllocator(RedisAI_GetMemoryLimit());
340+
ONNX_VALIDATE_STATUS(ort->RegisterAllocator(env, global_allocator))
394341
}
395342

396343
ONNX_VALIDATE_STATUS(ort->CreateSessionOptions(&session_options))

src/backends/onnxruntime.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,6 @@
55
#include "redis_ai_objects/model.h"
66
#include "execution/execution_contexts/execution_ctx.h"
77

8-
unsigned long long RAI_GetMemoryInfoORT(void);
9-
10-
unsigned long long RAI_GetMemoryAccessORT(void);
11-
128
int RAI_InitBackendORT(int (*get_api_fn)(const char *, void **));
139

1410
RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts,

src/config/config.c

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ long long ThreadPoolSizePerQueue = 1; // Number of working threads for devi
1616

1717
long long ModelExecutionTimeout = 5000; // The maximum time in milliseconds
1818
// before killing onnx run session.
19+
long long BackendMemoryLimit = 0; // The maximum amount of memory in MB
20+
// that backend is allowed to consume.
1921

2022
static int _Config_LoadTimeParamParse(RedisModuleCtx *ctx, const char *key, const char *val,
2123
RedisModuleString *rsval) {
@@ -56,6 +58,11 @@ static int _Config_LoadTimeParamParse(RedisModuleCtx *ctx, const char *key, cons
5658
if (ret == REDISMODULE_OK) {
5759
RedisModule_Log(ctx, "notice", "%s: %s", REDISAI_INFOMSG_MODEL_EXECUTION_TIMEOUT, val);
5860
}
61+
} else if (strcasecmp((key), "BACKEND_MEMORY_LIMIT") == 0) {
62+
ret = Config_SetBackendMemoryLimit(rsval);
63+
if (ret == REDISMODULE_OK) {
64+
RedisModule_Log(ctx, "notice", "%s: %s", REDISAI_INFOMSG_BACKEND_MEMORY_LIMIT, val);
65+
}
5966
} else if (strcasecmp((key), "BACKENDSPATH") == 0) {
6067
// already taken care of
6168
} else {
@@ -74,6 +81,8 @@ long long Config_GetNumThreadsPerQueue() { return ThreadPoolSizePerQueue; }
7481

7582
long long Config_GetModelExecutionTimeout() { return ModelExecutionTimeout; }
7683

84+
long long Config_GetBackendMemoryLimit() { return BackendMemoryLimit; }
85+
7786
char *Config_GetBackendsPath() { return BackendsPath; }
7887

7988
int Config_LoadBackend(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
@@ -160,6 +169,16 @@ int Config_SetModelExecutionTimeout(RedisModuleString *timeout) {
160169
return REDISMODULE_OK;
161170
}
162171

172+
int Config_SetBackendMemoryLimit(RedisModuleString *memory_limit) {
173+
long long val;
174+
int result = RedisModule_StringToLongLong(memory_limit, &val);
175+
if (result != REDISMODULE_OK || val <= 0) {
176+
return REDISMODULE_ERR;
177+
}
178+
BackendMemoryLimit = val;
179+
return REDISMODULE_OK;
180+
}
181+
163182
int Config_SetLoadTimeParams(RedisModuleCtx *ctx, RedisModuleString *const *argv, int argc) {
164183
if (argc > 0 && argc % 2 != 0) {
165184
RedisModule_Log(ctx, "warning",

src/config/config.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ typedef enum { RAI_DEVICE_CPU = 0, RAI_DEVICE_GPU = 1 } RAI_Device;
2626
#define REDISAI_INFOMSG_INTER_OP_PARALLELISM "Setting INTER_OP_PARALLELISM parameter to"
2727
#define REDISAI_INFOMSG_MODEL_CHUNK_SIZE "Setting MODEL_CHUNK_SIZE parameter to"
2828
#define REDISAI_INFOMSG_MODEL_EXECUTION_TIMEOUT "Setting MODEL_EXECUTION_TIMEOUT parameter to"
29+
#define REDISAI_INFOMSG_BACKEND_MEMORY_LIMIT "Setting BACKEND_MEMORY_LIMIT parameter to"
2930

3031
/**
3132
* Get number of threads used for parallelism between independent operations, by
@@ -56,6 +57,13 @@ long long Config_GetNumThreadsPerQueue(void);
5657
*/
5758
long long Config_GetModelExecutionTimeout(void);
5859

60+
/**
61+
* @return Memory limit in MB for backend. This is the maximum amount of memory
62+
* that can be consumed by the backend for creating and running sessions.
63+
* Currently supported only for onnxruntime backend.
64+
*/
65+
long long Config_GetBackendMemoryLimit(void);
66+
5967
/**
6068
* @return Returns the backends path string.
6169
*/
@@ -113,11 +121,19 @@ int Config_SetModelChunkSize(RedisModuleString *chunk_size_string);
113121

114122
/**
115123
* Set the maximum time in ms that onnx backend allow running a model.
116-
* @param onnx_max_runtime - string containing the max runtime (in ms)
124+
* @param timeout - string containing the max runtime (in ms)
117125
* @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed
118126
*/
119127
int Config_SetModelExecutionTimeout(RedisModuleString *timeout);
120128

129+
/**
130+
* Set the memory limit in MB for backends allocations.
131+
* @param memory_limit - maximum memory consumption by backend. If values is zero,
132+
* there will be no enforcement of any memory limit.
133+
* @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed
134+
*/
135+
int Config_SetBackendMemoryLimit(RedisModuleString *memory_limit);
136+
121137
/**
122138
* Load time configuration parser
123139
* @param ctx Context in which Redis modules operate

src/redisai.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,6 +1167,7 @@ void RAI_moduleInfoFunc(RedisModuleInfoCtx *ctx, int for_crash_report) {
11671167
Config_GetBackendsIntraOpParallelism());
11681168
RedisModule_InfoAddFieldLongLong(ctx, "model_execution_timeout",
11691169
Config_GetModelExecutionTimeout());
1170+
RedisModule_InfoAddFieldLongLong(ctx, "backend_memory_limit", Config_GetBackendMemoryLimit());
11701171
_moduleInfo_getBackendsInfo(ctx);
11711172

11721173
struct rusage self_ru, c_ru;

0 commit comments

Comments
 (0)