Skip to content

Commit 1e2cb52

Browse files
committed
Fix test
1 parent 0f86380 commit 1e2cb52

File tree

2 files changed

+3
-14
lines changed

2 files changed

+3
-14
lines changed

src/backends/tensorflow.c

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ typedef struct TFDLManagedTensorCtx TFDLManagedTensorCtx;
3939

4040
TFDLManagedTensorCtx *TFDLManagedTensorCtx_Create(TFE_TensorHandle *h, TF_Status *status) {
4141
TFDLManagedTensorCtx *ctx = RedisModule_Alloc(sizeof(TFDLManagedTensorCtx));
42+
ctx->reference = h;
4243
ctx->ndim = TFE_TensorHandleNumDims(h, status);
4344
ctx->shape = RedisModule_Calloc(ctx->ndim, sizeof(int64_t));
4445
ctx->strides = RedisModule_Calloc(ctx->ndim, sizeof(int64_t));
@@ -257,15 +258,9 @@ void *TFE_HandleToDLPack(TFE_TensorHandle *h, TF_Status *status) {
257258
dlm_tensor->dl_tensor.data = tf_dlm_data;
258259
dlm_tensor->dl_tensor.dtype = tf_dlm_type;
259260
dlm_tensor->dl_tensor.shape = tf_dlm_tensor_ctx->shape;
260-
// There are two ways to represent compact row-major data
261-
// 1) nullptr indicates tensor is compact and row-majored.
262-
// 2) fill in the strides array as the real case for compact row-major data.
263-
// Here we choose option 2, since some frameworks didn't handle the strides
264-
// argument properly.
265261
dlm_tensor->dl_tensor.strides = tf_dlm_tensor_ctx->strides;
266-
267-
// TF doesn't handle the strides and byte_offsets here
268262
dlm_tensor->dl_tensor.byte_offset = 0;
263+
269264
return (void *)dlm_tensor;
270265
}
271266

tests/flow/tests_dag.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -818,11 +818,8 @@ def test_dag_modelrun_financialNet_no_writes(env):
818818
'INPUTS', 'transaction', 'reference', 'OUTPUTS', 'output', 'BLOB', model_pb)
819819
env.assertEqual(ret, b'OK')
820820

821-
MAX_TRANSACTIONS = 2
822-
823821
for tensor_number in range(1, MAX_TRANSACTIONS):
824-
# for repetition in range(1, 10):
825-
for repetition in range(1, 2):
822+
for repetition in range(1, 10):
826823
reference_tensor = creditcard_referencedata[tensor_number]
827824
transaction_tensor = creditcard_transactions[tensor_number]
828825
result_tensor_keyname = 'resultTensor{{hhh}}{}'.format(tensor_number)
@@ -836,9 +833,6 @@ def test_dag_modelrun_financialNet_no_writes(env):
836833
ret = con.execute_command("EXISTS {}".format(reference_tensor_keyname))
837834
env.assertEqual(ret, 1)
838835

839-
# print(reference_tensor)
840-
print(transaction_tensor)
841-
842836
ret = con.execute_command(
843837
'AI.DAGRUN', 'LOAD', '1', reference_tensor_keyname, '|>',
844838
'AI.TENSORSET', transaction_tensor_keyname, 'FLOAT', 1, 30,'BLOB', transaction_tensor.tobytes(), '|>',

0 commit comments

Comments
 (0)