-
Notifications
You must be signed in to change notification settings - Fork 1.8k
test: Add L0_backend_onnxruntime test for enabling bfloat16 dtype in ONNXRuntime backend #8660
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
yinggeh
merged 19 commits into
main
from
yinggeh/tgh-26-onnx-backend-does-not-support-bfloat16-inputs
Mar 5, 2026
Merged
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
126fd7e
Add BF16 tests for ORT backend
yinggeh d92678d
Update related tests and model generations
yinggeh c952ede
Merge branch 'main' into yinggeh/tgh-26-onnx-backend-does-not-support…
yinggeh 570bc48
Fix test
yinggeh d238b8c
Merge branch 'main' of github.com:triton-inference-server/server into…
yinggeh c7d325c
Move bf16 and float32 to different PR
yinggeh a359237
Fix copyrights
yinggeh f400fe5
Add model generate script
yinggeh ea19f29
Update test
yinggeh a5ae694
Update test
yinggeh 7e0da9d
Fix precommit check
yinggeh 4824162
Update test
yinggeh 7c9cafe
Update test
yinggeh a555d14
Update tests
yinggeh 27084f6
Merge branch 'main' into yinggeh/tgh-26-onnx-backend-does-not-support…
yinggeh b9d6e30
Fix tests
yinggeh d6f4118
Merge branch 'yinggeh/tgh-26-onnx-backend-does-not-support-bfloat16-i…
yinggeh abdb260
exit test
yinggeh 356006c
Revert model generation
yinggeh File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,100 @@ | ||
| #!/usr/bin/env python3 | ||
| # Copyright 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # Redistribution and use in source and binary forms, with or without | ||
| # modification, are permitted provided that the following conditions | ||
| # are met: | ||
| # * Redistributions of source code must retain the above copyright | ||
| # notice, this list of conditions and the following disclaimer. | ||
| # * Redistributions in binary form must reproduce the above copyright | ||
| # notice, this list of conditions and the following disclaimer in the | ||
| # documentation and/or other materials provided with the distribution. | ||
| # * Neither the name of NVIDIA CORPORATION nor the names of its | ||
| # contributors may be used to endorse or promote products derived | ||
| # from this software without specific prior written permission. | ||
| # | ||
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY | ||
| # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR | ||
| # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR | ||
| # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, | ||
| # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, | ||
| # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR | ||
| # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | ||
| # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
| # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
| # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
|
|
||
|
|
||
| # Generates the add_bf16 ONNX model and Triton config. | ||
| # Model: element-wise Add in BFLOAT16 (INPUT0 + INPUT1 = OUTPUT), ONNX Runtime backend. | ||
| import os | ||
|
|
||
| import onnx | ||
|
|
||
|
|
||
| def generate_bf16_add_model(models_dir): | ||
| """Generate a simple BFLOAT16 Add model (INPUT0 + INPUT1 = OUTPUT).""" | ||
| model_name = "add_bf16" | ||
| shape = [1] | ||
| onnx_dtype = onnx.TensorProto.BFLOAT16 | ||
|
|
||
| add = onnx.helper.make_node("Add", ["INPUT0", "INPUT1"], ["OUTPUT"]) | ||
|
|
||
| input0 = onnx.helper.make_tensor_value_info("INPUT0", onnx_dtype, shape) | ||
| input1 = onnx.helper.make_tensor_value_info("INPUT1", onnx_dtype, shape) | ||
| output = onnx.helper.make_tensor_value_info("OUTPUT", onnx_dtype, shape) | ||
|
|
||
| graph_proto = onnx.helper.make_graph( | ||
| [add], | ||
| model_name, | ||
| [input0, input1], | ||
| [output], | ||
| ) | ||
| model_def = onnx.helper.make_model(graph_proto, producer_name="triton") | ||
| # Cap IR version for older ONNX Runtime (e.g. max supported 11) | ||
| model_def.ir_version = min(model_def.ir_version, 11) | ||
| # BFLOAT16 support requires opset 13+ | ||
| model_def.opset_import[0].version = 13 | ||
|
|
||
| model_dir = os.path.join(models_dir, model_name, "1") | ||
| os.makedirs(model_dir, exist_ok=True) | ||
| onnx.save(model_def, os.path.join(model_dir, "model.onnx")) | ||
|
|
||
| # Write config.pbtxt | ||
| config = """platform: "onnxruntime_onnx" | ||
| max_batch_size: 0 | ||
| input [ | ||
| {{ | ||
| name: "INPUT0" | ||
| data_type: TYPE_BF16 | ||
| dims: {shape} | ||
| }}, | ||
| {{ | ||
| name: "INPUT1" | ||
| data_type: TYPE_BF16 | ||
| dims: {shape} | ||
| }} | ||
| ] | ||
| output [ | ||
| {{ | ||
| name: "OUTPUT" | ||
| data_type: TYPE_BF16 | ||
| dims: {shape} | ||
| }} | ||
| ] | ||
| """.format( | ||
| shape=shape | ||
| ) | ||
|
|
||
| config_path = os.path.join(models_dir, model_name, "config.pbtxt") | ||
| with open(config_path, "w") as f: | ||
| f.write(config) | ||
|
|
||
| print(f"Generated model '{model_name}' in {models_dir}") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| models_dir = os.path.join(os.getcwd(), "models") | ||
| os.makedirs(models_dir, exist_ok=True) | ||
| generate_bf16_add_model(models_dir) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| #!/usr/bin/env python3 | ||
| # Copyright 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # Redistribution and use in source and binary forms, with or without | ||
| # modification, are permitted provided that the following conditions | ||
| # are met: | ||
| # * Redistributions of source code must retain the above copyright | ||
| # notice, this list of conditions and the following disclaimer. | ||
| # * Redistributions in binary form must reproduce the above copyright | ||
| # notice, this list of conditions and the following disclaimer in the | ||
| # documentation and/or other materials provided with the distribution. | ||
| # * Neither the name of NVIDIA CORPORATION nor the names of its | ||
| # contributors may be used to endorse or promote products derived | ||
| # from this software without specific prior written permission. | ||
| # | ||
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY | ||
| # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR | ||
| # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR | ||
| # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, | ||
| # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, | ||
| # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR | ||
| # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | ||
| # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
| # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
| # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
|
|
||
| import os | ||
| import unittest | ||
|
|
||
| import numpy as np | ||
| import tritonclient.grpc as grpcclient | ||
| import tritonclient.http as httpclient | ||
|
|
||
|
|
||
| class BFloat16Test(unittest.TestCase): | ||
| def setUp(self): | ||
| self.protocol = os.environ.get("CLIENT_TYPE", "http") | ||
| if self.protocol == "http": | ||
| self.client_ = httpclient.InferenceServerClient("localhost:8000") | ||
| else: | ||
| self.client_ = grpcclient.InferenceServerClient("localhost:8001") | ||
| self.model_name_ = "add_bf16" | ||
| self.shape_ = [1] | ||
|
|
||
| def _infer_bf16(self, input0_data, input1_data): | ||
| """Helper to run BF16 inference and return the output numpy array.""" | ||
| if self.protocol == "http": | ||
| input0 = httpclient.InferInput("INPUT0", self.shape_, "BF16") | ||
| input1 = httpclient.InferInput("INPUT1", self.shape_, "BF16") | ||
| else: | ||
| input0 = grpcclient.InferInput("INPUT0", self.shape_, "BF16") | ||
| input1 = grpcclient.InferInput("INPUT1", self.shape_, "BF16") | ||
| input0.set_data_from_numpy(input0_data) | ||
| input1.set_data_from_numpy(input1_data) | ||
|
|
||
| results = self.client_.infer(self.model_name_, [input0, input1]) | ||
| return results.as_numpy("OUTPUT") | ||
|
|
||
| def test_bf16_add_variants(self): | ||
| """Run BF16 add across multiple cases: zeros, negatives, large, small, cancellation, and identical.""" | ||
| for input0_val, input1_val, expected_val in [ | ||
| (0.0, 0.0, 0.0), # zeros | ||
| (-1.5, 3.5, 2.0), # negatives / mixed | ||
| (100.0, 200.0, 300.0), # large | ||
| (1e-2, 1e-2, 2e-2), # small (near underflow) | ||
| (1.0, -1.0, 0.0), # cancellation | ||
| (2.0, 2.0, 4.0), # identical inputs | ||
| ]: | ||
| output = self._infer_bf16( | ||
| np.full(self.shape_, input0_val, dtype=np.float32), | ||
| np.full(self.shape_, input1_val, dtype=np.float32), | ||
| ) | ||
| self.assertEqual(output.dtype, np.float32) | ||
| # TODO: BF16 to FP32 conversion loses precision. Remove rtol and atol in TRI-801. | ||
| # BF16 has ~3 decimal digits; use relaxed tol for computed values | ||
| np.testing.assert_allclose(output, expected_val, rtol=1e-2, atol=1e-3) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,92 @@ | ||
| #!/bin/bash | ||
| # Copyright 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # Redistribution and use in source and binary forms, with or without | ||
| # modification, are permitted provided that the following conditions | ||
| # are met: | ||
| # * Redistributions of source code must retain the above copyright | ||
| # notice, this list of conditions and the following disclaimer. | ||
| # * Redistributions in binary form must reproduce the above copyright | ||
| # notice, this list of conditions and the following disclaimer in the | ||
| # documentation and/or other materials provided with the distribution. | ||
| # * Neither the name of NVIDIA CORPORATION nor the names of its | ||
| # contributors may be used to endorse or promote products derived | ||
| # from this software without specific prior written permission. | ||
| # | ||
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY | ||
| # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR | ||
| # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR | ||
| # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, | ||
| # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, | ||
| # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR | ||
| # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | ||
| # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
| # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
| # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
|
|
||
| export CUDA_VISIBLE_DEVICES=0 | ||
|
|
||
| SERVER=/opt/tritonserver/bin/tritonserver | ||
| SERVER_LOG="./inference_server.log" | ||
| CLIENT_LOG="./test.log" | ||
| source ../common/util.sh | ||
|
|
||
| rm -f *.log | ||
| rm -rf models | ||
|
|
||
| RET=0 | ||
|
|
||
| # BFLOAT16 test | ||
| # Generate the model | ||
| mkdir -p models/add_bf16/1 | ||
| set +e | ||
|
|
||
| pip install onnx==1.20.1 | ||
| if [ $? -ne 0 ]; then | ||
| echo -e "\n***\n*** Failed to install onnx dependency\n***" | ||
| exit 1 | ||
| fi | ||
|
|
||
| python gen_add_bf16_onnx_model.py | ||
| if [ $? -ne 0 ]; then | ||
| echo -e "\n***\n*** Failed to generate BFLOAT16 ONNX model\n***" | ||
| exit 1 | ||
| fi | ||
|
|
||
| set -e | ||
|
yinggeh marked this conversation as resolved.
|
||
|
|
||
| SERVER_ARGS="--model-repository=`pwd`/models" | ||
| run_server | ||
| if [ "$SERVER_PID" == "0" ]; then | ||
| echo -e "\n***\n*** Failed to start $SERVER\n***" | ||
| cat $SERVER_LOG | ||
| exit 1 | ||
| fi | ||
|
|
||
| set +e | ||
|
|
||
| for client_type in http grpc; do | ||
| export CLIENT_TYPE=$client_type | ||
| CLIENT_LOG="./test_${client_type}.log" | ||
| python test.py >>$CLIENT_LOG 2>&1 | ||
| if [ $? -ne 0 ]; then | ||
| cat $CLIENT_LOG | ||
| echo -e "\n***\n*** Test Failed ($client_type)\n***" | ||
| RET=1 | ||
| fi | ||
| done | ||
| unset CLIENT_TYPE | ||
|
|
||
| set -e | ||
|
|
||
| kill $SERVER_PID | ||
| wait $SERVER_PID | ||
|
yinggeh marked this conversation as resolved.
|
||
|
|
||
| if [ $RET -eq 0 ]; then | ||
| echo -e "\n***\n*** Test Passed\n***" | ||
| else | ||
| echo -e "\n***\n*** Test FAILED\n***" | ||
| fi | ||
|
|
||
| exit $RET | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.