-
Notifications
You must be signed in to change notification settings - Fork 255
feat: Enable bfloat16 input/output tensor dtype in Python client #880
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| #!/usr/bin/env python3 | ||
|
|
||
| # Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # Copyright 2023-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # Redistribution and use in source and binary forms, with or without | ||
| # modification, are permitted provided that the following conditions | ||
|
|
@@ -26,13 +26,7 @@ | |
| # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
| # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
| import numpy as np | ||
| from tritonclient.utils import ( | ||
| np_to_triton_dtype, | ||
| raise_error, | ||
| serialize_bf16_tensor, | ||
| serialize_byte_tensor, | ||
| triton_to_np_dtype, | ||
| ) | ||
| from tritonclient.utils import np_to_triton_dtype, raise_error, serialize_byte_tensor | ||
|
|
||
|
|
||
| class InferInput: | ||
|
|
@@ -129,22 +123,13 @@ def set_data_from_numpy(self, input_tensor, binary_data=True): | |
| """ | ||
| if not isinstance(input_tensor, (np.ndarray,)): | ||
| raise_error("input_tensor must be a numpy array") | ||
| # DLIS-3986: Special handling for bfloat16 until Numpy officially supports it | ||
| if self._datatype == "BF16": | ||
| if input_tensor.dtype != triton_to_np_dtype(self._datatype): | ||
| raise_error( | ||
| "got unexpected datatype {} from numpy array, expected {} for BF16 type".format( | ||
| input_tensor.dtype, triton_to_np_dtype(self._datatype) | ||
| ) | ||
| ) | ||
| else: | ||
| dtype = np_to_triton_dtype(input_tensor.dtype) | ||
| if self._datatype != dtype: | ||
| raise_error( | ||
| "got unexpected datatype {} from numpy array, expected {}".format( | ||
| dtype, self._datatype | ||
| ) | ||
| dtype = np_to_triton_dtype(input_tensor.dtype) | ||
| if self._datatype != dtype: | ||
| raise_error( | ||
| "got unexpected datatype {} from numpy array, expected {}".format( | ||
| dtype, self._datatype | ||
| ) | ||
| ) | ||
|
Comment on lines
+126
to
+132
|
||
| valid_shape = True | ||
| if len(self._shape) != len(input_tensor.shape): | ||
| valid_shape = False | ||
|
|
@@ -202,12 +187,6 @@ def set_data_from_numpy(self, input_tensor, binary_data=True): | |
| self._raw_data = serialized_output.item() | ||
| else: | ||
| self._raw_data = b"" | ||
| elif self._datatype == "BF16": | ||
| serialized_output = serialize_bf16_tensor(input_tensor) | ||
| if serialized_output.size > 0: | ||
| self._raw_data = serialized_output.item() | ||
| else: | ||
| self._raw_data = b"" | ||
| else: | ||
| self._raw_data = input_tensor.tobytes() | ||
| self._parameters["binary_data_size"] = len(self._raw_data) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change makes BF16 inputs require an actual BF16 numpy dtype (via
ml_dtypes.bfloat16mapping). Previously, BF16 inputs could be provided asfloat32(sincetriton_to_np_dtype("BF16")mapped tonp.float32) and the client handled BF16 serialization. If existing users rely on the float32 workaround, consider accepting both float32 and bfloat16 for BF16 inputs (converting float32 to BF16 bytes) or clearly documenting the breaking behavior change.