@@ -315,7 +315,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
315
315
TFE_ContextOptions * context_opts = TFE_NewContextOptions ();
316
316
// TFE_ContextOptionsSetConfig(context_opts, proto, proto_len, status);
317
317
// TFE_ContextOptionsSetAsync(context_opts, 0);
318
- // TFE_ContextOptionsSetDevicePlacementPolicy(context_opts, TFE_DEVICE_PLACEMENT_EXPLICIT);
318
+ TFE_ContextOptionsSetDevicePlacementPolicy (context_opts , TFE_DEVICE_PLACEMENT_EXPLICIT );
319
319
320
320
TFE_Context * context = TFE_NewContext (context_opts , status );
321
321
if (TF_GetCode (status ) != TF_OK ) {
@@ -530,6 +530,11 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
530
530
RedisModule_Free (errorMessage );
531
531
return 1 ;
532
532
}
533
+ // TODO EAGER
534
+ inputTensorsHandles [i ] = TFE_TensorHandleCopyToDevice (inputTensorsHandles [i ],
535
+ mctxs [0 ]-> model -> session ,
536
+ mctxs [0 ]-> model -> devicestr ,
537
+ status );
533
538
}
534
539
535
540
TFE_Op * fn_op = TFE_NewOp (mctxs [0 ]-> model -> session , RAI_TF_FN_NAME , status );
@@ -576,6 +581,11 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
576
581
}
577
582
578
583
for (size_t i = 0 ; i < noutputs ; ++ i ) {
584
+ outputTensorsHandles [i ] = TFE_TensorHandleCopyToDevice (outputTensorsHandles [i ],
585
+ mctxs [0 ]-> model -> session ,
586
+ "CPU" ,
587
+ status );
588
+
579
589
outputTensorsValues [i ] = TFE_TensorHandleResolve (outputTensorsHandles [i ], status );
580
590
581
591
if (TF_GetCode (status ) != TF_OK ) {
0 commit comments