@@ -515,6 +515,18 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
515
515
}
516
516
}
517
517
518
+ char tf_devicestr [256 ];
519
+ int devicestr_len = strlen (mctxs [0 ]-> model -> devicestr );
520
+ if (strncasecmp (mctxs [0 ]-> model -> devicestr , "CPU" , 3 ) == 0 ) {
521
+ sprintf (tf_devicestr , "/device:CPU:0" );
522
+ }
523
+ else if (devicestr_len == 3 ) {
524
+ sprintf (tf_devicestr , "/device:%s:0" , mctxs [0 ]-> model -> devicestr );
525
+ }
526
+ else {
527
+ sprintf (tf_devicestr , "/device:%s" , mctxs [0 ]-> model -> devicestr );
528
+ }
529
+
518
530
for (size_t i = 0 ; i < ninputs ; ++ i ) {
519
531
RAI_Tensor * batched_input_tensors [nbatches ];
520
532
@@ -530,11 +542,19 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
530
542
RedisModule_Free (errorMessage );
531
543
return 1 ;
532
544
}
533
- // TODO EAGER
545
+
534
546
inputTensorsHandles [i ] = TFE_TensorHandleCopyToDevice (inputTensorsHandles [i ],
535
547
mctxs [0 ]-> model -> session ,
536
- mctxs [ 0 ] -> model -> devicestr ,
548
+ tf_devicestr ,
537
549
status );
550
+
551
+ if (TF_GetCode (status ) != TF_OK ) {
552
+ char * errorMessage = RedisModule_Strdup (TF_Message (status ));
553
+ RAI_SetError (error , RAI_EMODELRUN , errorMessage );
554
+ TF_DeleteStatus (status );
555
+ RedisModule_Free (errorMessage );
556
+ return 1 ;
557
+ }
538
558
}
539
559
540
560
TFE_Op * fn_op = TFE_NewOp (mctxs [0 ]-> model -> session , RAI_TF_FN_NAME , status );
@@ -555,8 +575,6 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
555
575
return 1 ;
556
576
}
557
577
558
- // TODO EAGER: send tensors to device (as long as we keep device allocation EXPLICIT)
559
-
560
578
int noutputs_ = noutputs ;
561
579
TFE_Execute (fn_op , outputTensorsHandles , & noutputs_ , status );
562
580
if (TF_GetCode (status ) != TF_OK ) {
@@ -583,7 +601,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
583
601
for (size_t i = 0 ; i < noutputs ; ++ i ) {
584
602
outputTensorsHandles [i ] = TFE_TensorHandleCopyToDevice (outputTensorsHandles [i ],
585
603
mctxs [0 ]-> model -> session ,
586
- "CPU" ,
604
+ "/device: CPU:0 " ,
587
605
status );
588
606
589
607
outputTensorsValues [i ] = TFE_TensorHandleResolve (outputTensorsHandles [i ], status );
0 commit comments