Skip to content

Commit 00de68c

Browse files
committed
Explicitly copy to device
1 parent b64dd6b commit 00de68c

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

src/backends/tensorflow.c

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
315315
TFE_ContextOptions *context_opts = TFE_NewContextOptions();
316316
// TFE_ContextOptionsSetConfig(context_opts, proto, proto_len, status);
317317
// TFE_ContextOptionsSetAsync(context_opts, 0);
318-
// TFE_ContextOptionsSetDevicePlacementPolicy(context_opts, TFE_DEVICE_PLACEMENT_EXPLICIT);
318+
TFE_ContextOptionsSetDevicePlacementPolicy(context_opts, TFE_DEVICE_PLACEMENT_EXPLICIT);
319319

320320
TFE_Context *context = TFE_NewContext(context_opts, status);
321321
if (TF_GetCode(status) != TF_OK) {
@@ -530,6 +530,11 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
530530
RedisModule_Free(errorMessage);
531531
return 1;
532532
}
533+
// TODO EAGER
534+
inputTensorsHandles[i] = TFE_TensorHandleCopyToDevice(inputTensorsHandles[i],
535+
mctxs[0]->model->session,
536+
mctxs[0]->model->devicestr,
537+
status);
533538
}
534539

535540
TFE_Op *fn_op = TFE_NewOp(mctxs[0]->model->session, RAI_TF_FN_NAME, status);
@@ -576,6 +581,11 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
576581
}
577582

578583
for (size_t i = 0; i < noutputs; ++i) {
584+
outputTensorsHandles[i] = TFE_TensorHandleCopyToDevice(outputTensorsHandles[i],
585+
mctxs[0]->model->session,
586+
"CPU",
587+
status);
588+
579589
outputTensorsValues[i] = TFE_TensorHandleResolve(outputTensorsHandles[i], status);
580590

581591
if (TF_GetCode(status) != TF_OK) {

0 commit comments

Comments
 (0)