Fix trailing zeros for type BYTES (#2551)
* Fix trailing zeros for type BYTES

* Fix unit tests to work with np.object_ serialization

* Review edits + CI

* Move get_number_of_bytes_for_npobject to inference utils

* Fix up
Tabrizian authored Feb 26, 2021
1 parent e42cac9 commit f0bdb93
Showing 15 changed files with 187 additions and 108 deletions.
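The change in one picture (an illustrative aside, not part of the diff; the variable names below are invented): numpy's fixed-width bytes dtype (np.bytes_, 'S') strips trailing zero bytes on element access and zero-pads shorter elements in tobytes(), so BYTES values ending in \x00 were corrupted; dtype=np.object_ stores exact Python bytes objects instead.

import numpy as np

# Fixed-width bytes dtype: trailing nulls vanish on element access...
fixed = np.array([b"he\x00llo\x00"])            # dtype='|S7'
assert fixed[0] == b"he\x00llo"                 # trailing \x00 stripped

# ...and tobytes() zero-pads short elements to the common itemsize.
padded = np.array([b"ab", b"abcd"])             # dtype='|S4'
assert padded.tobytes() == b"ab\x00\x00abcd"    # 'ab' gained two zeros

# Object dtype round-trips the exact bytes; this commit switches to it.
exact = np.array([b"he\x00llo\x00"], dtype=np.object_)
assert exact[0] == b"he\x00llo\x00"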
36 changes: 20 additions & 16 deletions python/examples/ensemble_image_client.py
@@ -59,8 +59,7 @@ def parse_model_grpc(model_metadata, model_config):
     input_metadata = model_metadata.inputs[0]
     output_metadata = model_metadata.outputs
 
-    return (input_metadata.name, output_metadata,
-            model_config.max_batch_size)
+    return (input_metadata.name, output_metadata, model_config.max_batch_size)
 
 
 def parse_model_http(model_metadata, model_config):
@@ -101,7 +100,7 @@ def postprocess(results, output_names, filenames, batch_size):
         for output_name in output_names:
             print(' [{}]:'.format(output_name))
             for result in output_dict[output_name][n]:
-                if output_dict[output_name][n].dtype.type == np.bytes_:
+                if output_dict[output_name][n].dtype.type == np.object_:
                     cls = "".join(chr(x) for x in result).split(':')
                 else:
                     cls = result.split(':')
@@ -116,12 +115,13 @@ def postprocess(results, output_names, filenames, batch_size):
                         required=False,
                         default=False,
                         help='Enable verbose output')
-    parser.add_argument('-m',
-                        '--model-name',
-                        type=str,
-                        required=False,
-                        default='preprocess_inception_ensemble',
-                        help='Name of model. Default is preprocess_inception_ensemble.')
+    parser.add_argument(
+        '-m',
+        '--model-name',
+        type=str,
+        required=False,
+        default='preprocess_inception_ensemble',
+        help='Name of model. Default is preprocess_inception_ensemble.')
     parser.add_argument('-c',
                         '--classes',
                         type=int,
@@ -237,18 +237,22 @@ def postprocess(results, output_names, filenames, batch_size):
                                               "BYTES"))
     inputs[0].set_data_from_numpy(batched_image_data, binary_data=True)
 
-    output_names = [ output.name if FLAGS.protocol.lower() == "grpc"
-                     else output['name'] for output in output_metadata ]
+    output_names = [
+        output.name if FLAGS.protocol.lower() == "grpc" else output['name']
+        for output in output_metadata
+    ]
 
     outputs = []
     for output_name in output_names:
         if FLAGS.protocol.lower() == "grpc":
-            outputs.append(grpcclient.InferRequestedOutput(output_name,
-                                                           class_count=FLAGS.classes))
+            outputs.append(
+                grpcclient.InferRequestedOutput(output_name,
+                                                class_count=FLAGS.classes))
         else:
-            outputs.append(httpclient.InferRequestedOutput(output_name,
-                                                           binary_data=True,
-                                                           class_count=FLAGS.classes))
+            outputs.append(
+                httpclient.InferRequestedOutput(output_name,
+                                                binary_data=True,
+                                                class_count=FLAGS.classes))
 
     # Send request
     result = triton_client.infer(model_name, inputs, outputs=outputs)
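Aside on the decoding idiom kept above (the literal below is invented for illustration): with np.object_ outputs each classification entry arrives as raw bytes, and iterating bytes in Python 3 yields ints, so chr(x) rebuilds the text before splitting:

result = b'0.829:504:COFFEE MUG'
cls = "".join(chr(x) for x in result).split(':')
assert cls == ['0.829', '504', 'COFFEE MUG']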
2 changes: 1 addition & 1 deletion python/examples/grpc_image_client.py
@@ -76,7 +76,7 @@ def deserialize_bytes_tensor(encoded_tensor):
         sb = struct.unpack_from("<{}s".format(l), val_buf, offset)[0]
         offset += l
         strs.append(sb)
-    return (np.array(strs, dtype=bytes))
+    return (np.array(strs, dtype=np.object_))
 
 
 def parse_model(model_metadata, model_config):
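For context, deserialize_bytes_tensor above consumes a buffer of <4-byte little-endian length><payload> records. A minimal sketch of the inverse, assuming that wire format (serialize_bytes_tensor_demo is a hypothetical name, not the client's API):

import struct
import numpy as np

def serialize_bytes_tensor_demo(np_array):
    # Encode each element as a 4-byte little-endian length prefix
    # followed by the raw payload bytes.
    flattened = b""
    for obj in np_array.flatten():
        s = obj if isinstance(obj, bytes) else str(obj).encode('utf-8')
        flattened += struct.pack("<I", len(s)) + s
    return flattened

buf = serialize_bytes_tensor_demo(np.array([b"ab\x00", b"cd"], dtype=np.object_))
assert buf == b"\x03\x00\x00\x00ab\x00\x02\x00\x00\x00cd"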
2 changes: 1 addition & 1 deletion python/examples/image_client.py
@@ -269,7 +269,7 @@ def postprocess(results, output_name, batch_size, batching):
     if not batching:
         results = [results]
     for result in results:
-        if output_array.dtype.type == np.bytes_:
+        if output_array.dtype.type == np.object_:
            cls = "".join(chr(x) for x in result).split(':')
        else:
            cls = result.split(':')
6 changes: 3 additions & 3 deletions python/examples/simple_grpc_sequence_stream_infer_client.py
@@ -178,11 +178,11 @@ def async_stream_send(triton_client, values, batch_size, sequence_id,
         recv_count = recv_count + 1
 
     for i in range(len(result0_list)):
-        seq0_expected = 1 if (i == 0) else values[i-1]
-        seq1_expected = 101 if (i == 0) else values[i-1] * -1
+        seq0_expected = 1 if (i == 0) else values[i - 1]
+        seq1_expected = 101 if (i == 0) else values[i - 1] * -1
         # The dyna_sequence custom backend adds the correlation ID
         # to the last request in a sequence.
-        if FLAGS.dyna and (i != 0) and (values[i-1] == 1):
+        if FLAGS.dyna and (i != 0) and (values[i - 1] == 1):
             seq0_expected += sequence_id0
             seq1_expected += sequence_id1
 
6 changes: 3 additions & 3 deletions python/examples/simple_grpc_sequence_sync_infer_client.py
@@ -137,11 +137,11 @@ def sync_send(triton_client, result_list, values, batch_size, sequence_id,
         sys.exit(1)
 
     for i in range(len(result0_list)):
-        seq0_expected = 1 if (i == 0) else values[i-1]
-        seq1_expected = 101 if (i == 0) else values[i-1] * -1
+        seq0_expected = 1 if (i == 0) else values[i - 1]
+        seq1_expected = 101 if (i == 0) else values[i - 1] * -1
         # The dyna_sequence custom backend adds the correlation ID
         # to the last request in a sequence.
-        if FLAGS.dyna and (i != 0) and (values[i-1] == 1):
+        if FLAGS.dyna and (i != 0) and (values[i - 1] == 1):
             seq0_expected += sequence_id0
             seq1_expected += sequence_id1
 
34 changes: 20 additions & 14 deletions python/examples/simple_grpc_shm_string_client.py
@@ -76,25 +76,29 @@
     # Create the data for the two input tensors. Initialize the first
     # to unique integers and the second to all ones.
     in0 = np.arange(start=0, stop=16, dtype=np.int32)
-    in0n = np.array([str(x) for x in in0.flatten()], dtype=object)
+    in0n = np.array([str(x).encode('utf-8') for x in in0.flatten()],
+                    dtype=object)
     input0_data = in0n.reshape(in0.shape)
     in1 = np.ones(shape=16, dtype=np.int32)
-    in1n = np.array([str(x) for x in in1.flatten()], dtype=object)
+    in1n = np.array([str(x).encode('utf-8') for x in in1.flatten()],
+                    dtype=object)
     input1_data = in1n.reshape(in1.shape)
 
-    expected_sum = np.array([str(x) for x in np.add(in0, in1).flatten()],
-                            dtype=object)
-    expected_diff = np.array([str(x) for x in np.subtract(in0, in1).flatten()],
-                             dtype=object)
+    expected_sum = np.array(
+        [str(x).encode('utf-8') for x in np.add(in0, in1).flatten()],
+        dtype=object)
+    expected_diff = np.array(
+        [str(x).encode('utf-8') for x in np.subtract(in0, in1).flatten()],
+        dtype=object)
     expected_sum_serialized = utils.serialize_byte_tensor(expected_sum)
     expected_diff_serialized = utils.serialize_byte_tensor(expected_diff)
 
     input0_data_serialized = utils.serialize_byte_tensor(input0_data)
     input1_data_serialized = utils.serialize_byte_tensor(input1_data)
-    input0_byte_size = input0_data_serialized.size * input0_data_serialized.itemsize
-    input1_byte_size = input1_data_serialized.size * input1_data_serialized.itemsize
-    output0_byte_size = expected_sum_serialized.size * expected_sum_serialized.itemsize
-    output1_byte_size = expected_diff_serialized.size * expected_diff_serialized.itemsize
+    input0_byte_size = utils.serialized_byte_size(input0_data_serialized)
+    input1_byte_size = utils.serialized_byte_size(input1_data_serialized)
+    output0_byte_size = utils.serialized_byte_size(expected_sum_serialized)
+    output1_byte_size = utils.serialized_byte_size(expected_diff_serialized)
     output_byte_size = max(input0_byte_size, input1_byte_size) + 1
 
     # Create Output0 and Output1 in Shared Memory and store shared memory handles
@@ -171,10 +175,12 @@
         sys.exit(1)
 
     for i in range(16):
-        r0 = output0_data[0][i].decode("utf-8")
-        r1 = output1_data[0][i].decode("utf-8")
-        print(str(input0_data[i]) + " + " + str(input1_data[i]) + " = " + r0)
-        print(str(input0_data[i]) + " - " + str(input1_data[i]) + " = " + r1)
+        r0 = output0_data[0][i]
+        r1 = output1_data[0][i]
+        print(
+            str(input0_data[i]) + " + " + str(input1_data[i]) + " = " + str(r0))
+        print(
+            str(input0_data[i]) + " - " + str(input1_data[i]) + " = " + str(r1))
 
         if expected_sum[i] != r0:
             print("shm infer error: incorrect sum")
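The replaced size * itemsize arithmetic is meaningless once serialize_byte_tensor returns a dtype=object array: itemsize is then the size of a Python object pointer, not the payload. utils.serialized_byte_size's implementation isn't shown in this diff; a plausible equivalent, assuming the serialized tensor is an np.object_ array wrapping a single bytes buffer (serialized_byte_size_sketch is a hypothetical stand-in):

import numpy as np

def serialized_byte_size_sketch(serialized_tensor):
    # Report the length of the wrapped bytes object, not size * itemsize.
    if serialized_tensor.size == 0:
        return 0
    return len(serialized_tensor.item())

wrapped = np.asarray(b"\x02\x00\x00\x0042", dtype=np.object_)
assert serialized_byte_size_sketch(wrapped) == 6
assert wrapped.itemsize == 8  # object-pointer size on 64-bit builds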
6 changes: 4 additions & 2 deletions python/examples/simple_grpc_string_infer_client.py
@@ -68,9 +68,11 @@
     expected_sum = np.add(in0, in1)
     expected_diff = np.subtract(in0, in1)
 
-    in0n = np.array([str(x) for x in in0.reshape(in0.size)], dtype=object)
+    in0n = np.array([str(x).encode('utf-8') for x in in0.reshape(in0.size)],
+                    dtype=object)
     input0_data = in0n.reshape(in0.shape)
-    in1n = np.array([str(x) for x in in1.reshape(in1.size)], dtype=object)
+    in1n = np.array([str(x).encode('utf-8') for x in in1.reshape(in1.size)],
+                    dtype=object)
     input1_data = in1n.reshape(in1.shape)
 
     # Initialize the data
6 changes: 3 additions & 3 deletions python/examples/simple_http_sequence_sync_infer_client.py
@@ -137,11 +137,11 @@ def sync_send(triton_client, result_list, values, batch_size, sequence_id,
         sys.exit(1)
 
     for i in range(len(result0_list)):
-        seq0_expected = 1 if (i == 0) else values[i-1]
-        seq1_expected = 101 if (i == 0) else values[i-1] * -1
+        seq0_expected = 1 if (i == 0) else values[i - 1]
+        seq1_expected = 101 if (i == 0) else values[i - 1] * -1
         # The dyna_sequence custom backend adds the correlation ID
         # to the last request in a sequence.
-        if FLAGS.dyna and (i != 0) and (values[i-1] == 1):
+        if FLAGS.dyna and (i != 0) and (values[i - 1] == 1):
            seq0_expected += sequence_id0
            seq1_expected += sequence_id1
 
35 changes: 20 additions & 15 deletions python/examples/simple_http_shm_string_client.py
@@ -76,26 +76,29 @@
     # Create the data for the two input tensors. Initialize the first
     # to unique integers and the second to all ones.
     in0 = np.arange(start=0, stop=16, dtype=np.int32)
-    in0n = np.array([str(x) for x in in0.flatten()], dtype=object)
+    in0n = np.array([str(x).encode('utf-8') for x in in0.flatten()],
+                    dtype=object)
     input0_data = in0n.reshape(in0.shape)
     in1 = np.ones(shape=16, dtype=np.int32)
-    in1n = np.array([str(x) for x in in1.flatten()], dtype=object)
+    in1n = np.array([str(x).encode('utf-8') for x in in1.flatten()],
+                    dtype=object)
     input1_data = in1n.reshape(in1.shape)
 
-    expected_sum = np.array([str(x) for x in np.add(in0, in1).flatten()],
-                            dtype=object)
-    expected_diff = np.array([str(x) for x in np.subtract(in0, in1).flatten()],
-                             dtype=object)
+    expected_sum = np.array(
+        [str(x).encode('utf-8') for x in np.add(in0, in1).flatten()],
+        dtype=object)
+    expected_diff = np.array(
+        [str(x).encode('utf-8') for x in np.subtract(in0, in1).flatten()],
+        dtype=object)
     expected_sum_serialized = utils.serialize_byte_tensor(expected_sum)
     expected_diff_serialized = utils.serialize_byte_tensor(expected_diff)
 
     input0_data_serialized = utils.serialize_byte_tensor(input0_data)
     input1_data_serialized = utils.serialize_byte_tensor(input1_data)
-    input0_byte_size = input0_data_serialized.size * input0_data_serialized.itemsize
-    input1_byte_size = input1_data_serialized.size * input1_data_serialized.itemsize
-    output0_byte_size = expected_sum_serialized.size * expected_sum_serialized.itemsize
-    output1_byte_size = expected_diff_serialized.size * expected_diff_serialized.itemsize
-    output_byte_size = max(input0_byte_size, input1_byte_size) + 1
+    input0_byte_size = utils.serialized_byte_size(input0_data_serialized)
+    input1_byte_size = utils.serialized_byte_size(input1_data_serialized)
+    output0_byte_size = utils.serialized_byte_size(expected_sum_serialized)
+    output1_byte_size = utils.serialized_byte_size(expected_diff_serialized)
 
     # Create Output0 and Output1 in Shared Memory and store shared memory handles
     shm_op0_handle = shm.create_shared_memory_region("output0_data",
@@ -171,10 +174,12 @@
         sys.exit(1)
 
     for i in range(16):
-        r0 = output0_data[0][i].decode("utf-8")
-        r1 = output1_data[0][i].decode("utf-8")
-        print(str(input0_data[i]) + " + " + str(input1_data[i]) + " = " + r0)
-        print(str(input0_data[i]) + " - " + str(input1_data[i]) + " = " + r1)
+        r0 = output0_data[0][i]
+        r1 = output1_data[0][i]
+        print(
+            str(input0_data[i]) + " + " + str(input1_data[i]) + " = " + str(r0))
+        print(
+            str(input0_data[i]) + " - " + str(input1_data[i]) + " = " + str(r1))
 
         if expected_sum[i] != r0:
             print("shm infer error: incorrect sum")
28 changes: 21 additions & 7 deletions python/examples/simple_http_string_infer_client.py
@@ -31,6 +31,11 @@
 
 import tritonclient.http as httpclient
 
+# unicode() doesn't exist on python3, for how we use it the
+# corresponding function is bytes()
+if sys.version_info.major == 3:
+    unicode = bytes
+
 
 def TestIdentityInference(np_array, binary_data):
     model_name = "simple_identity"
@@ -46,14 +51,20 @@ def TestIdentityInference(np_array, binary_data):
     results = triton_client.infer(model_name=model_name,
                                   inputs=inputs,
                                   outputs=outputs)
-    if (np_array.dtype == np.object):
+    if (np_array.dtype == np.object_):
         print(results.as_numpy('OUTPUT0'))
         if binary_data:
-            if not np.array_equal(np_array,
-                                  np.char.decode(results.as_numpy('OUTPUT0'))):
+            if not np.array_equal(np_array, results.as_numpy('OUTPUT0')):
                 print(results.as_numpy('OUTPUT0'))
                 sys.exit(1)
         else:
-            if not np.array_equal(np_array, results.as_numpy('OUTPUT0')):
+            expected_array = np.array([
+                unicode(str(x), encoding='utf-8')
+                for x in results.as_numpy('OUTPUT0').flatten()
+            ],
+                                      dtype=object)
+            expected_array = expected_array.reshape([1, 16])
+            if not np.array_equal(np_array, expected_array):
                 print(results.as_numpy('OUTPUT0'))
                 sys.exit(1)
     else:
@@ -102,9 +113,11 @@ def TestIdentityInference(np_array, binary_data):
     expected_sum = np.add(in0, in1)
     expected_diff = np.subtract(in0, in1)
 
-    in0n = np.array([str(x) for x in in0.reshape(in0.size)], dtype=object)
+    in0n = np.array([str(x).encode('utf-8') for x in in0.reshape(in0.size)],
+                    dtype=object)
     input0_data = in0n.reshape(in0.shape)
-    in1n = np.array([str(x) for x in in1.reshape(in1.size)], dtype=object)
+    in1n = np.array([str(x).encode('utf-8') for x in in1.reshape(in1.size)],
+                    dtype=object)
     input1_data = in1n.reshape(in1.shape)
 
     # Initialize the data
@@ -142,7 +155,8 @@ def TestIdentityInference(np_array, binary_data):
         sys.exit(1)
 
     # Test with null character
-    null_chars_array = np.array(["he\x00llo" for i in range(16)], dtype=object)
+    null_chars_array = np.array(
+        ["he\x00llo".encode('utf-8') for i in range(16)], dtype=object)
     null_char_data = null_chars_array.reshape([1, 16])
     TestIdentityInference(null_char_data, True)  # Using binary data
     TestIdentityInference(null_char_data, False)  # Using JSON data
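A quick note on the unicode = bytes shim introduced above, with a small usage sketch: on Python 3 the call below is bytes(str(x), encoding='utf-8') and yields b'5', matching elements of a dtype=np.object_ BYTES tensor; on Python 2 the same line would yield u'5'.

import sys

if sys.version_info.major == 3:
    unicode = bytes

# Python 3: bytes('5', encoding='utf-8') == b'5'
assert unicode(str(5), encoding='utf-8') == b'5'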
6 changes: 5 additions & 1 deletion python/library/tritonclient/grpc/__init__.py
@@ -1443,7 +1443,11 @@ def set_data_from_numpy(self, input_tensor):
         self._input.parameters.pop('shared_memory_offset', None)
 
         if self._input.datatype == "BYTES":
-            self._raw_content = serialize_byte_tensor(input_tensor).tobytes()
+            serialized_output = serialize_byte_tensor(input_tensor)
+            if serialized_output.size > 0:
+                self._raw_content = serialized_output.item()
+            else:
+                self._raw_content = b''
         else:
             self._raw_content = input_tensor.tobytes()
 
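This is the heart of the trailing-zeros fix on the send path. The calls above imply that serialize_byte_tensor now returns a dtype=np.object_ array wrapping the serialized buffer (an inference from this diff; the utils change itself isn't shown here). For such an array .tobytes() dumps raw object-pointer bytes, while .item() returns the exact payload:

import numpy as np

payload = b"\x03\x00\x00\x00ab\x00"      # length-prefixed BYTES buffer
wrapped = np.asarray(payload, dtype=np.object_)

assert wrapped.item() == payload         # exact payload, nulls intact
assert wrapped.tobytes() != payload      # pointer bytes, not the data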
6 changes: 5 additions & 1 deletion python/library/tritonclient/http/__init__.py
@@ -1402,7 +1402,11 @@ def set_data_from_numpy(self, input_tensor, binary_data=True):
         else:
             self._data = None
             if self._datatype == "BYTES":
-                self._raw_data = serialize_byte_tensor(input_tensor).tobytes()
+                serialized_output = serialize_byte_tensor(input_tensor)
+                if serialized_output.size > 0:
+                    self._raw_data = serialized_output.item()
+                else:
+                    self._raw_data = b''
             else:
                 self._raw_data = input_tensor.tobytes()
             self._parameters['binary_data_size'] = len(self._raw_data)
(Diffs for the remaining changed files did not load.)
