Skip to content

Commit baf15d4

Browse files
committed
Add TFE configuration
1 parent 1e2cb52 commit baf15d4

File tree

1 file changed

+37
-42
lines changed

1 file changed

+37
-42
lines changed

src/backends/tensorflow.c

Lines changed: 37 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -398,30 +398,6 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
398398
goto cleanup;
399399
}
400400

401-
TFE_ContextOptions *context_opts = TFE_NewContextOptions();
402-
uint8_t config[4] = {0x32, 0x02, 0x20, 0x01};
403-
TFE_ContextOptionsSetConfig(context_opts, (void *)config, 4, status);
404-
405-
TFE_ContextOptionsSetAsync(context_opts, 0);
406-
TFE_ContextOptionsSetDevicePlacementPolicy(context_opts, TFE_DEVICE_PLACEMENT_EXPLICIT);
407-
408-
TFE_Context *context = TFE_NewContext(context_opts, status);
409-
if (TF_GetCode(status) != TF_OK) {
410-
RAI_SetError(error, RAI_EMODELCONFIGURE, RedisModule_Strdup(TF_Message(status)));
411-
goto cleanup;
412-
}
413-
414-
TFE_ContextAddFunction(context, function, status);
415-
if (TF_GetCode(status) != TF_OK) {
416-
RAI_SetError(error, RAI_EMODELCONFIGURE, RedisModule_Strdup(TF_Message(status)));
417-
goto cleanup;
418-
}
419-
420-
TFE_DeleteContextOptions(context_opts);
421-
422-
TF_DeleteStatus(status);
423-
424-
#if 0
425401
// For setting config options in session from the C API see:
426402
// https://github.com/tensorflow/tensorflow/issues/13853
427403
// import tensorflow as tf
@@ -430,33 +406,35 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
430406
// result = list(map(hex, serialized))
431407
// print(result)
432408

409+
TFE_ContextOptions *context_opts = TFE_NewContextOptions();
410+
433411
if (device == RAI_DEVICE_CPU) {
434412
// Set number of GPU to 0 with
435413
// config.device_count = {'GPU': 0}
436414
uint8_t config[] = {0x0a, 0x07, 0x0a, 0x03, 0x47, 0x50, 0x55, 0x10, 0x00};
437-
TF_SetConfig(sessionOptions, (void *)config, sizeof(config), optionsStatus);
415+
TFE_ContextOptionsSetConfig(context_opts, (void *)config, sizeof(config), status);
438416

439-
if (TF_GetCode(optionsStatus) != TF_OK) {
440-
RAI_SetError(error, RAI_EMODELCONFIGURE, RedisModule_Strdup(TF_Message(optionsStatus)));
417+
if (TF_GetCode(status) != TF_OK) {
418+
RAI_SetError(error, RAI_EMODELCONFIGURE, RedisModule_Strdup(TF_Message(status)));
441419
goto cleanup;
442420
}
443421

444422
if (opts.backends_intra_op_parallelism > 0) {
445423
uint8_t proto[] = {0x10, (uint8_t)opts.backends_intra_op_parallelism};
446-
TF_SetConfig(sessionOptions, proto, sizeof(proto), optionsStatus);
447-
if (TF_GetCode(optionsStatus) != TF_OK) {
424+
TFE_ContextOptionsSetConfig(context_opts, proto, sizeof(proto), status);
425+
if (TF_GetCode(status) != TF_OK) {
448426
RAI_SetError(error, RAI_EMODELCONFIGURE,
449-
RedisModule_Strdup(TF_Message(optionsStatus)));
427+
RedisModule_Strdup(TF_Message(status)));
450428
goto cleanup;
451429
}
452430
}
453431

454432
if (opts.backends_inter_op_parallelism > 0) {
455433
uint8_t proto1[] = {0x28, (uint8_t)opts.backends_inter_op_parallelism};
456-
TF_SetConfig(sessionOptions, proto1, sizeof(proto1), optionsStatus);
457-
if (TF_GetCode(optionsStatus) != TF_OK) {
434+
TFE_ContextOptionsSetConfig(context_opts, proto1, sizeof(proto1), status);
435+
if (TF_GetCode(status) != TF_OK) {
458436
RAI_SetError(error, RAI_EMODELCONFIGURE,
459-
RedisModule_Strdup(TF_Message(optionsStatus)));
437+
RedisModule_Strdup(TF_Message(status)));
460438
goto cleanup;
461439
}
462440
}
@@ -465,23 +443,39 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
465443
// Set
466444
// config.gpu_options.allow_growth = True
467445
uint8_t config[4] = {0x32, 0x02, 0x20, 0x01};
468-
TF_SetConfig(sessionOptions, (void *)config, 4, optionsStatus);
446+
TFE_ContextOptionsSetConfig(context_opts, (void *)config, 4, status);
469447
} else {
470448
// Set
471449
// config.gpu_options.allow_growth = True
472450
// config.gpu_options.visible_device_list = '<deviceid>'
473451
uint8_t config[7] = {0x32, 0x05, 0x20, 0x01, 0x2a, 0x01, 0x30};
474452
config[6] += (uint8_t)deviceid;
475-
TF_SetConfig(sessionOptions, (void *)config, 7, optionsStatus);
453+
TFE_ContextOptionsSetConfig(context_opts, (void *)config, 7, status);
476454
}
477455
}
478456

479-
TF_Status *deviceListStatus = TF_NewStatus();
480-
TF_DeviceList *deviceList = TF_SessionListDevices(session, deviceListStatus);
457+
TFE_ContextOptionsSetAsync(context_opts, 0);
458+
TFE_ContextOptionsSetDevicePlacementPolicy(context_opts, TFE_DEVICE_PLACEMENT_EXPLICIT);
459+
460+
TFE_Context *context = TFE_NewContext(context_opts, status);
461+
if (TF_GetCode(status) != TF_OK) {
462+
RAI_SetError(error, RAI_EMODELCONFIGURE, RedisModule_Strdup(TF_Message(status)));
463+
goto cleanup;
464+
}
465+
466+
TFE_ContextAddFunction(context, function, status);
467+
if (TF_GetCode(status) != TF_OK) {
468+
RAI_SetError(error, RAI_EMODELCONFIGURE, RedisModule_Strdup(TF_Message(status)));
469+
goto cleanup;
470+
}
471+
472+
TFE_DeleteContextOptions(context_opts);
473+
474+
TF_DeviceList *deviceList = TFE_ContextListDevices(context, status);
481475
const int num_devices = TF_DeviceListCount(deviceList);
482476
int foundNoGPU = 1;
483477
for (int i = 0; i < num_devices; ++i) {
484-
const char *device_type = TF_DeviceListType(deviceList, i, deviceListStatus);
478+
const char *device_type = TF_DeviceListType(deviceList, i, status);
485479
int cmp = strcmp(device_type, "GPU");
486480
if (cmp == 0) {
487481
foundNoGPU = 0;
@@ -491,17 +485,18 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
491485
if (foundNoGPU == 1 && device == RAI_DEVICE_GPU) {
492486
RAI_SetError(error, RAI_EMODELCREATE, "ERR GPU requested but TF couldn't find CUDA");
493487
TF_DeleteDeviceList(deviceList);
494-
TF_DeleteStatus(deviceListStatus);
488+
TF_DeleteStatus(status);
495489
goto cleanup;
496490
}
497491
TF_DeleteDeviceList(deviceList);
498-
TF_DeleteStatus(deviceListStatus);
499492

500-
if (TF_GetCode(sessionStatus) != TF_OK) {
493+
if (TF_GetCode(status) != TF_OK) {
501494
RAI_SetError(error, RAI_EMODELCREATE, RedisModule_Strdup(TF_Message(status)));
502495
goto cleanup;
503496
}
504-
#endif
497+
498+
TF_DeleteStatus(status);
499+
505500
char **inputs_ = array_new(char *, ninputs);
506501
for (long long i = 0; i < ninputs; i++) {
507502
inputs_ = array_append(inputs_, RedisModule_Strdup(inputs[i]));

0 commit comments

Comments
 (0)