Skip to content

Commit aa6cd31

Browse files
committed
Fixes to device allocation to make CPU tests pass
1 parent 00de68c commit aa6cd31

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

src/backends/tensorflow.c

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,18 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
515515
}
516516
}
517517

518+
char tf_devicestr[256];
519+
int devicestr_len = strlen(mctxs[0]->model->devicestr);
520+
if (strncasecmp(mctxs[0]->model->devicestr, "CPU", 3) == 0) {
521+
sprintf(tf_devicestr, "/device:CPU:0");
522+
}
523+
else if (devicestr_len == 3) {
524+
sprintf(tf_devicestr, "/device:%s:0", mctxs[0]->model->devicestr);
525+
}
526+
else {
527+
sprintf(tf_devicestr, "/device:%s", mctxs[0]->model->devicestr);
528+
}
529+
518530
for (size_t i = 0; i < ninputs; ++i) {
519531
RAI_Tensor *batched_input_tensors[nbatches];
520532

@@ -530,11 +542,19 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
530542
RedisModule_Free(errorMessage);
531543
return 1;
532544
}
533-
// TODO EAGER
545+
534546
inputTensorsHandles[i] = TFE_TensorHandleCopyToDevice(inputTensorsHandles[i],
535547
mctxs[0]->model->session,
536-
mctxs[0]->model->devicestr,
548+
tf_devicestr,
537549
status);
550+
551+
if (TF_GetCode(status) != TF_OK) {
552+
char *errorMessage = RedisModule_Strdup(TF_Message(status));
553+
RAI_SetError(error, RAI_EMODELRUN, errorMessage);
554+
TF_DeleteStatus(status);
555+
RedisModule_Free(errorMessage);
556+
return 1;
557+
}
538558
}
539559

540560
TFE_Op *fn_op = TFE_NewOp(mctxs[0]->model->session, RAI_TF_FN_NAME, status);
@@ -555,8 +575,6 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
555575
return 1;
556576
}
557577

558-
// TODO EAGER: send tensors to device (as long as we keep device allocation EXPLICIT)
559-
560578
int noutputs_ = noutputs;
561579
TFE_Execute(fn_op, outputTensorsHandles, &noutputs_, status);
562580
if (TF_GetCode(status) != TF_OK) {
@@ -583,7 +601,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
583601
for (size_t i = 0; i < noutputs; ++i) {
584602
outputTensorsHandles[i] = TFE_TensorHandleCopyToDevice(outputTensorsHandles[i],
585603
mctxs[0]->model->session,
586-
"CPU",
604+
"/device:CPU:0",
587605
status);
588606

589607
outputTensorsValues[i] = TFE_TensorHandleResolve(outputTensorsHandles[i], status);

tests/flow/test_serializations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_v0_torch_model(self):
8181
self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, inputs, outputs], [b"TORCH", b"CPU", b"PT_MINIMAL_V0", 0, 0, [b'a', b'b'], [b'']])
8282
torch_model_run(self.env, key_name)
8383

84-
def test_v0_troch_script(self):
84+
def test_v0_torch_script(self):
8585
key_name = "torch_script{1}"
8686
con = self.env.getConnection()
8787
script_rdb = b'\x07\x81\x00\x8f\xd2\t\x12\x0fL\x00\x05\x04CPU\x00\x05\x10TORCH_SCRIPT_V0\x00\x05\xc3@W@i\x0fdef bar(a, b):\n \x00\x0ereturn a + b\n\nd\x80 \x08_variadic@)\x12args : List[Tensor]\xe0\x06; \x02[0] A`\t\x031]\n\x00\x00\t\x00\x0b\xee\x04\xe7\x11\xaez\x91'

0 commit comments

Comments
 (0)