Improve TensorRT GetCapability to Enable More Models #1012
@@ -69,9 +69,8 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
       if (it != fused_outputs.end()) {
         fused_outputs.erase(it);
         erased.insert(input);
-      }
-      //only when input is neither in output list nor erased list, add the input to input list
-      else if (erased.find(input) == erased.end()) {
+      } else if (erased.find(input) == erased.end()) {
+        //only when input is neither in output list nor erased list, add the input to input list
         fused_inputs[input] = input_order++;
       }
     }
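For context on the hunk above: the change only moves the comment inside the "else if" branch; the bookkeeping that decides which tensors become inputs of the fused subgraph is unchanged. A standalone sketch of that logic, with simplified names and types (not the actual GetSubGraph implementation), looks like this:

    #include <string>
    #include <unordered_map>
    #include <unordered_set>
    #include <vector>

    // Given the inputs of a node being merged into a fused subgraph, erase any
    // input that is produced inside the subgraph from the pending-output map,
    // and record only truly external inputs (neither a pending output nor
    // already erased) as inputs of the fused node.
    void CollectFusedInputs(const std::vector<std::string>& node_inputs,
                            std::unordered_map<std::string, int>& fused_outputs,
                            std::unordered_map<std::string, int>& fused_inputs,
                            std::unordered_set<std::string>& erased,
                            int& input_order) {
      for (const auto& input : node_inputs) {
        auto it = fused_outputs.find(input);
        if (it != fused_outputs.end()) {
          fused_outputs.erase(it);   // consumed inside the subgraph, not an external output
          erased.insert(input);
        } else if (erased.find(input) == erased.end()) {
          // only when input is neither in the output list nor the erased list,
          // add the input to the input list
          fused_inputs[input] = input_order++;
        }
      }
    }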
@@ -462,22 +461,44 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
       const OrtValue* input_tensor = ort.KernelContext_GetInput(context, input_indexes[i]);
       auto tensor_info = ort.GetTensorTypeAndShape(input_tensor);
       const auto& tensor_shape = ort.GetTensorShape(tensor_info);
+      auto tensor_type = ort.GetTensorElementType(tensor_info);
       ort.ReleaseTensorTypeAndShapeInfo(tensor_info);
-      const float* input = ort.GetTensorData<float>(input_tensor);

       const int input_batch_size = tensor_shape[0];
       if (i > 0 && batch_size != input_batch_size) {
         ORT_THROW("Input batch size is inconsistent");
       }
       batch_size = input_batch_size;

-      CHECK_CUDA(cudaMalloc(&buffers[i], input_batch_size * input_dim_sizes[i] * sizeof(float)));
-      CHECK_CUDA(cudaMemcpy(buffers[i], input, input_batch_size * input_dim_sizes[i] * sizeof(float), cudaMemcpyHostToDevice));
+      int input_size = batch_size * input_dim_sizes[i];
+      if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
+        const float* input = const_cast<float*>(ort.GetTensorData<float>(input_tensor));
+        CHECK_CUDA(cudaMalloc(&buffers[i], input_size * sizeof(float)));
+        CHECK_CUDA(cudaMemcpy(buffers[i], input, input_size * sizeof(float), cudaMemcpyHostToDevice));
+      } else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
+        const int8_t* input = const_cast<int8_t*>(ort.GetTensorData<int8_t>(input_tensor));
+        CHECK_CUDA(cudaMalloc(&buffers[i], input_size * sizeof(int8_t)));
+        CHECK_CUDA(cudaMemcpy(buffers[i], input, input_size * sizeof(int8_t), cudaMemcpyHostToDevice));
+      } else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
+        const int32_t* input = const_cast<int32_t*>(ort.GetTensorData<int32_t>(input_tensor));
+        CHECK_CUDA(cudaMalloc(&buffers[i], input_size * sizeof(int32_t)));
+        CHECK_CUDA(cudaMemcpy(buffers[i], input, input_size * sizeof(int32_t), cudaMemcpyHostToDevice));
+      } else {
+        Status(common::ONNXRUNTIME, common::FAIL, "Input tensor type " + std::to_string(tensor_type) + " is not supported.");

Review comments on the Status(...) line above:
- Why allocate a Status() here? It doesn't seem to be used anywhere. Same comment for the other Status() allocations.
- It looks like this compute_func() always returns 0, which doesn't seem correct.

+      }
     }

     // Allocate cuda memory for outputs
     for (int i = 0, end = num_binding_outputs; i < end; ++i) {
-      CHECK_CUDA(cudaMalloc(&buffers[i + num_binding_inputs], batch_size * output_dim_sizes[i] * sizeof(float)));
+      if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
+        CHECK_CUDA(cudaMalloc(&buffers[i + num_binding_inputs], batch_size * output_dim_sizes[i] * sizeof(float)));
+      } else if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
+        CHECK_CUDA(cudaMalloc(&buffers[i + num_binding_inputs], batch_size * output_dim_sizes[i] * sizeof(int8_t)));
+      } else if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 || output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
+        CHECK_CUDA(cudaMalloc(&buffers[i + num_binding_inputs], batch_size * output_dim_sizes[i] * sizeof(int32_t)));
+      } else {
+        Status(common::ONNXRUNTIME, common::FAIL, "Output tensor type " + std::to_string(output_types[i]) + " is not supported.");

Review comment on the Status(...) line above: same comment as above for the Status() allocation.

+      }
     }

     // Run TRT inference
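The review comments can be made concrete: a Status object that is constructed but never returned or checked has no effect, so the unsupported-type branches above currently fall through silently. One possible shape for a fix, sketched with a hypothetical GetElementSize helper (the helper name and signature are illustrative and are not part of this PR or the onnxruntime API):

    #include <cstddef>
    #include <cstdint>
    #include <onnxruntime_c_api.h>  // defines ONNXTensorElementDataType

    // Hypothetical helper: map a tensor element type to its byte size once, so
    // the cudaMalloc/cudaMemcpy pair does not have to be repeated per type and
    // unsupported types can be rejected explicitly by the caller.
    bool GetElementSize(ONNXTensorElementDataType type, size_t* size) {
      switch (type) {
        case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
          *size = sizeof(float);
          return true;
        case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
          *size = sizeof(int8_t);
          return true;
        case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
          *size = sizeof(int32_t);
          return true;
        default:
          return false;  // caller should stop and report, not fall through
      }
    }

With such a helper, the unsupported-type case has a natural place to stop: when GetElementSize returns false, the compute function can return an error code or a failed Status (depending on its signature) instead of constructing a Status that is never read.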
@@ -491,18 +512,22 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:

       int output_size = batch_size * output_dim_sizes[i];
       OrtValue* output_tensor = ort.KernelContext_GetOutput(context, output_index, output_shapes[i].data(), output_shapes[i].size());
-      // If output tensor type is INT64, TensorRT processes data as INT32 and the output will be converted to INT64.
-      if (output_types[i] == TensorProto::FLOAT) {
+      if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
         CHECK_CUDA(cudaMemcpy(ort.GetTensorMutableData<float>(output_tensor), buffers[i + num_binding_inputs], output_size * sizeof(float), cudaMemcpyDeviceToHost));
-      } else if (output_types[i] == TensorProto::INT64) {
-        int* output = new int[output_size];
-        CHECK_CUDA(cudaMemcpy(output, buffers[i + num_binding_inputs], output_size * sizeof(int), cudaMemcpyDeviceToHost));
+      } else if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
+        CHECK_CUDA(cudaMemcpy(ort.GetTensorMutableData<int8_t>(output_tensor), buffers[i + num_binding_inputs], output_size * sizeof(int8_t), cudaMemcpyDeviceToHost));
+      } else if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
+        CHECK_CUDA(cudaMemcpy(ort.GetTensorMutableData<int32_t>(output_tensor), buffers[i + num_binding_inputs], output_size * sizeof(int32_t), cudaMemcpyDeviceToHost));
+      } else if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
+        // If output tensor type is INT64, TensorRT processes data as INT32 and the output will be converted to INT64.
+        int* output = new int32_t[output_size];
+        CHECK_CUDA(cudaMemcpy(output, buffers[i + num_binding_inputs], output_size * sizeof(int32_t), cudaMemcpyDeviceToHost));
         for (int j = 0; j < output_size; ++j) {
           ort.GetTensorMutableData<int64_t>(output_tensor)[j] = output[j];
         }
         delete[] output;
       } else {
-        Status(common::ONNXRUNTIME, common::FAIL, "Output type is not supported by TensorRT");
+        Status(common::ONNXRUNTIME, common::FAIL, "Output tensor type " + std::to_string(output_types[i]) + " is not supported.");

Review comment on the Status(...) line above: same Status() comment as before.

       }
     }

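Because TensorRT runs the engine with INT32 bindings when the ONNX output type is INT64, the hunk above copies the result back as INT32 and widens it element by element. A small sketch of just that step, using std::vector instead of raw new/delete (device_buffer, output_tensor_data, and output_size stand in for the variables in the diff; this is illustrative, not the PR's code):

    #include <cstdint>
    #include <vector>
    #include <cuda_runtime.h>

    // Copy an INT32 device buffer back to host and widen it into the INT64
    // output tensor the ONNX model expects. Assumes device_buffer holds
    // output_size int32 values.
    cudaError_t CopyInt32OutputAsInt64(const void* device_buffer,
                                       int64_t* output_tensor_data,
                                       int output_size) {
      std::vector<int32_t> staging(output_size);
      cudaError_t status = cudaMemcpy(staging.data(), device_buffer,
                                      output_size * sizeof(int32_t),
                                      cudaMemcpyDeviceToHost);
      if (status != cudaSuccess) {
        return status;
      }
      for (int j = 0; j < output_size; ++j) {
        output_tensor_data[j] = staging[j];  // int32 -> int64 widening
      }
      return cudaSuccess;
    }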
Review comments:
- I thought we should avoid throwing exceptions in the compute function?
- I see other places in the Compile() function where it uses ORT_ENFORCE instead of returning a Status. Please revisit.
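To make the trade-off in these comments concrete, here is a minimal sketch of the two error-handling styles being discussed, with a hypothetical SimpleStatus type standing in for onnxruntime's Status (illustrative only, not the actual ORT macros or types):

    #include <stdexcept>
    #include <string>

    // Hypothetical stand-in for onnxruntime's Status, for illustration only.
    struct SimpleStatus {
      bool ok;
      std::string message;
    };

    // Enforce-style check: fails by throwing, which unwinds out of the compute
    // function (this is the behavior ORT_ENFORCE / ORT_THROW give you).
    void CheckBatchSizeOrThrow(int expected, int actual) {
      if (expected != actual) {
        throw std::runtime_error("Input batch size is inconsistent");
      }
    }

    // Status-returning check: fails by handing an error back to the caller, so
    // nothing is thrown from inside the compute function and the caller decides
    // how to surface the failure.
    SimpleStatus CheckBatchSize(int expected, int actual) {
      if (expected != actual) {
        return {false, "Input batch size is inconsistent"};
      }
      return {true, ""};
    }

Either style would at least surface the unsupported-type and inconsistent-batch cases; the reviewers' request is that the PR pick one consistently rather than constructing Status objects that are discarded.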