@@ -398,30 +398,6 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
398
398
goto cleanup ;
399
399
}
400
400
401
- TFE_ContextOptions * context_opts = TFE_NewContextOptions ();
402
- uint8_t config [4 ] = {0x32 , 0x02 , 0x20 , 0x01 };
403
- TFE_ContextOptionsSetConfig (context_opts , (void * )config , 4 , status );
404
-
405
- TFE_ContextOptionsSetAsync (context_opts , 0 );
406
- TFE_ContextOptionsSetDevicePlacementPolicy (context_opts , TFE_DEVICE_PLACEMENT_EXPLICIT );
407
-
408
- TFE_Context * context = TFE_NewContext (context_opts , status );
409
- if (TF_GetCode (status ) != TF_OK ) {
410
- RAI_SetError (error , RAI_EMODELCONFIGURE , RedisModule_Strdup (TF_Message (status )));
411
- goto cleanup ;
412
- }
413
-
414
- TFE_ContextAddFunction (context , function , status );
415
- if (TF_GetCode (status ) != TF_OK ) {
416
- RAI_SetError (error , RAI_EMODELCONFIGURE , RedisModule_Strdup (TF_Message (status )));
417
- goto cleanup ;
418
- }
419
-
420
- TFE_DeleteContextOptions (context_opts );
421
-
422
- TF_DeleteStatus (status );
423
-
424
- #if 0
425
401
// For setting config options in session from the C API see:
426
402
// https://github.com/tensorflow/tensorflow/issues/13853
427
403
// import tensorflow as tf
@@ -430,33 +406,35 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
430
406
// result = list(map(hex, serialized))
431
407
// print(result)
432
408
409
+ TFE_ContextOptions * context_opts = TFE_NewContextOptions ();
410
+
433
411
if (device == RAI_DEVICE_CPU ) {
434
412
// Set number of GPU to 0 with
435
413
// config.device_count = {'GPU': 0}
436
414
uint8_t config [] = {0x0a , 0x07 , 0x0a , 0x03 , 0x47 , 0x50 , 0x55 , 0x10 , 0x00 };
437
- TF_SetConfig ( sessionOptions , (void * )config , sizeof (config ), optionsStatus );
415
+ TFE_ContextOptionsSetConfig ( context_opts , (void * )config , sizeof (config ), status );
438
416
439
- if (TF_GetCode (optionsStatus ) != TF_OK ) {
440
- RAI_SetError (error , RAI_EMODELCONFIGURE , RedisModule_Strdup (TF_Message (optionsStatus )));
417
+ if (TF_GetCode (status ) != TF_OK ) {
418
+ RAI_SetError (error , RAI_EMODELCONFIGURE , RedisModule_Strdup (TF_Message (status )));
441
419
goto cleanup ;
442
420
}
443
421
444
422
if (opts .backends_intra_op_parallelism > 0 ) {
445
423
uint8_t proto [] = {0x10 , (uint8_t )opts .backends_intra_op_parallelism };
446
- TF_SetConfig ( sessionOptions , proto , sizeof (proto ), optionsStatus );
447
- if (TF_GetCode (optionsStatus ) != TF_OK ) {
424
+ TFE_ContextOptionsSetConfig ( context_opts , proto , sizeof (proto ), status );
425
+ if (TF_GetCode (status ) != TF_OK ) {
448
426
RAI_SetError (error , RAI_EMODELCONFIGURE ,
449
- RedisModule_Strdup (TF_Message (optionsStatus )));
427
+ RedisModule_Strdup (TF_Message (status )));
450
428
goto cleanup ;
451
429
}
452
430
}
453
431
454
432
if (opts .backends_inter_op_parallelism > 0 ) {
455
433
uint8_t proto1 [] = {0x28 , (uint8_t )opts .backends_inter_op_parallelism };
456
- TF_SetConfig ( sessionOptions , proto1 , sizeof (proto1 ), optionsStatus );
457
- if (TF_GetCode (optionsStatus ) != TF_OK ) {
434
+ TFE_ContextOptionsSetConfig ( context_opts , proto1 , sizeof (proto1 ), status );
435
+ if (TF_GetCode (status ) != TF_OK ) {
458
436
RAI_SetError (error , RAI_EMODELCONFIGURE ,
459
- RedisModule_Strdup (TF_Message (optionsStatus )));
437
+ RedisModule_Strdup (TF_Message (status )));
460
438
goto cleanup ;
461
439
}
462
440
}
@@ -465,23 +443,39 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
465
443
// Set
466
444
// config.gpu_options.allow_growth = True
467
445
uint8_t config [4 ] = {0x32 , 0x02 , 0x20 , 0x01 };
468
- TF_SetConfig ( sessionOptions , (void * )config , 4 , optionsStatus );
446
+ TFE_ContextOptionsSetConfig ( context_opts , (void * )config , 4 , status );
469
447
} else {
470
448
// Set
471
449
// config.gpu_options.allow_growth = True
472
450
// config.gpu_options.visible_device_list = '<deviceid>'
473
451
uint8_t config [7 ] = {0x32 , 0x05 , 0x20 , 0x01 , 0x2a , 0x01 , 0x30 };
474
452
config [6 ] += (uint8_t )deviceid ;
475
- TF_SetConfig ( sessionOptions , (void * )config , 7 , optionsStatus );
453
+ TFE_ContextOptionsSetConfig ( context_opts , (void * )config , 7 , status );
476
454
}
477
455
}
478
456
479
- TF_Status * deviceListStatus = TF_NewStatus ();
480
- TF_DeviceList * deviceList = TF_SessionListDevices (session , deviceListStatus );
457
+ TFE_ContextOptionsSetAsync (context_opts , 0 );
458
+ TFE_ContextOptionsSetDevicePlacementPolicy (context_opts , TFE_DEVICE_PLACEMENT_EXPLICIT );
459
+
460
+ TFE_Context * context = TFE_NewContext (context_opts , status );
461
+ if (TF_GetCode (status ) != TF_OK ) {
462
+ RAI_SetError (error , RAI_EMODELCONFIGURE , RedisModule_Strdup (TF_Message (status )));
463
+ goto cleanup ;
464
+ }
465
+
466
+ TFE_ContextAddFunction (context , function , status );
467
+ if (TF_GetCode (status ) != TF_OK ) {
468
+ RAI_SetError (error , RAI_EMODELCONFIGURE , RedisModule_Strdup (TF_Message (status )));
469
+ goto cleanup ;
470
+ }
471
+
472
+ TFE_DeleteContextOptions (context_opts );
473
+
474
+ TF_DeviceList * deviceList = TFE_ContextListDevices (context , status );
481
475
const int num_devices = TF_DeviceListCount (deviceList );
482
476
int foundNoGPU = 1 ;
483
477
for (int i = 0 ; i < num_devices ; ++ i ) {
484
- const char * device_type = TF_DeviceListType (deviceList , i , deviceListStatus );
478
+ const char * device_type = TF_DeviceListType (deviceList , i , status );
485
479
int cmp = strcmp (device_type , "GPU" );
486
480
if (cmp == 0 ) {
487
481
foundNoGPU = 0 ;
@@ -491,17 +485,18 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
491
485
if (foundNoGPU == 1 && device == RAI_DEVICE_GPU ) {
492
486
RAI_SetError (error , RAI_EMODELCREATE , "ERR GPU requested but TF couldn't find CUDA" );
493
487
TF_DeleteDeviceList (deviceList );
494
- TF_DeleteStatus (deviceListStatus );
488
+ TF_DeleteStatus (status );
495
489
goto cleanup ;
496
490
}
497
491
TF_DeleteDeviceList (deviceList );
498
- TF_DeleteStatus (deviceListStatus );
499
492
500
- if (TF_GetCode (sessionStatus ) != TF_OK ) {
493
+ if (TF_GetCode (status ) != TF_OK ) {
501
494
RAI_SetError (error , RAI_EMODELCREATE , RedisModule_Strdup (TF_Message (status )));
502
495
goto cleanup ;
503
496
}
504
- #endif
497
+
498
+ TF_DeleteStatus (status );
499
+
505
500
char * * inputs_ = array_new (char * , ninputs );
506
501
for (long long i = 0 ; i < ninputs ; i ++ ) {
507
502
inputs_ = array_append (inputs_ , RedisModule_Strdup (inputs [i ]));
0 commit comments