Improve TensorRT GetCapability to Enable More Models #1012

Merged 12 commits on May 24, 2019
51 changes: 38 additions & 13 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -69,9 +69,8 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
if (it != fused_outputs.end()) {
fused_outputs.erase(it);
erased.insert(input);
}
//only when input is neither in output list nor erased list, add the input to input list
else if (erased.find(input) == erased.end()) {
} else if (erased.find(input) == erased.end()) {
//only when input is neither in output list nor erased list, add the input to input list
fused_inputs[input] = input_order++;
}
}
@@ -462,22 +461,44 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
const OrtValue* input_tensor = ort.KernelContext_GetInput(context, input_indexes[i]);
auto tensor_info = ort.GetTensorTypeAndShape(input_tensor);
const auto& tensor_shape = ort.GetTensorShape(tensor_info);
auto tensor_type = ort.GetTensorElementType(tensor_info);
ort.ReleaseTensorTypeAndShapeInfo(tensor_info);
const float* input = ort.GetTensorData<float>(input_tensor);

const int input_batch_size = tensor_shape[0];
if (i > 0 && batch_size != input_batch_size) {
ORT_THROW("Input batch size is inconsistent");
Review comment (Member):
Thought we should avoid throwing an exception in the compute function?

Review comment (Member):
I see other places in the Compile() function where it uses ORT_ENFORCE instead of returning a Status. Please revisit.
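For illustration only, a minimal, hypothetical sketch of the pattern these comments suggest: validate the batch dimension and report a mismatch through a returned status (or the compute function's integer return value) rather than throwing. ComputeStatus and CheckBatchSizes are invented names for this sketch and are not part of the PR or the onnxruntime API.

#include <cstdint>
#include <string>
#include <vector>

// Illustrative stand-in for a status object: 0 means success, non-zero means failure.
struct ComputeStatus {
  int code = 0;
  std::string message;
};

// Check that every input shares the same batch dimension and report a mismatch
// to the caller instead of unwinding through an exception.
ComputeStatus CheckBatchSizes(const std::vector<int64_t>& batch_sizes) {
  for (size_t i = 1; i < batch_sizes.size(); ++i) {
    if (batch_sizes[i] != batch_sizes[0]) {
      return {1, "Input batch size is inconsistent at input " + std::to_string(i)};
    }
  }
  return {};
}

The caller can then map a non-zero code onto whatever error reporting the compute function uses, without an exception crossing the provider boundary.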

}
batch_size = input_batch_size;

CHECK_CUDA(cudaMalloc(&buffers[i], input_batch_size * input_dim_sizes[i] * sizeof(float)));
CHECK_CUDA(cudaMemcpy(buffers[i], input, input_batch_size * input_dim_sizes[i] * sizeof(float), cudaMemcpyHostToDevice));
int input_size = batch_size * input_dim_sizes[i];
if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
const float* input = const_cast<float*>(ort.GetTensorData<float>(input_tensor));
CHECK_CUDA(cudaMalloc(&buffers[i], input_size * sizeof(float)));
CHECK_CUDA(cudaMemcpy(buffers[i], input, input_size * sizeof(float), cudaMemcpyHostToDevice));
} else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
const int8_t* input = const_cast<int8_t*>(ort.GetTensorData<int8_t>(input_tensor));
CHECK_CUDA(cudaMalloc(&buffers[i], input_size * sizeof(int8_t)));
CHECK_CUDA(cudaMemcpy(buffers[i], input, input_size * sizeof(int8_t), cudaMemcpyHostToDevice));
} else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
const int32_t* input = const_cast<int32_t*>(ort.GetTensorData<int32_t>(input_tensor));
CHECK_CUDA(cudaMalloc(&buffers[i], input_size * sizeof(int32_t)));
CHECK_CUDA(cudaMemcpy(buffers[i], input, input_size * sizeof(int32_t), cudaMemcpyHostToDevice));
} else {
Status(common::ONNXRUNTIME, common::FAIL, "Input tensor type " + std::to_string(tensor_type) + " is not supported.");
Review comment (Member):
Why allocate a Status()? It doesn't seem to be used anywhere.

Review comment (Member):
Same comment for the other Status() allocations.

Review comment (Member):
Looks like this compute_func() always returns 0, which doesn't seem correct. If there are errors, we need to return 1.
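For illustration only, a minimal, hypothetical sketch of what these comments ask for: an unsupported element type should surface as a non-zero return from the compute function instead of a Status temporary that is constructed and immediately discarded. ElemType and CopyInputToDevice are invented names for this sketch, and the CUDA allocation and copy are elided; this is not the code the PR landed.

#include <iostream>

// Illustrative stand-in for the ONNX_TENSOR_ELEMENT_DATA_TYPE_* constants in the diff.
enum class ElemType { kFloat, kInt8, kInt32, kUnsupported };

// Stage one input on the device; return 0 on success and 1 on failure so the
// error actually reaches the caller of the compute function.
int CopyInputToDevice(ElemType type) {
  switch (type) {
    case ElemType::kFloat:
    case ElemType::kInt8:
    case ElemType::kInt32:
      // ... cudaMalloc / cudaMemcpy with the matching element width would go here ...
      return 0;
    default:
      std::cerr << "Input tensor type is not supported by the TensorRT execution provider\n";
      return 1;
  }
}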

}
}

// Allocate cuda memory for outputs
for (int i = 0, end = num_binding_outputs; i < end; ++i) {
CHECK_CUDA(cudaMalloc(&buffers[i + num_binding_inputs], batch_size * output_dim_sizes[i] * sizeof(float)));
if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
CHECK_CUDA(cudaMalloc(&buffers[i + num_binding_inputs], batch_size * output_dim_sizes[i] * sizeof(float)));
} else if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
CHECK_CUDA(cudaMalloc(&buffers[i + num_binding_inputs], batch_size * output_dim_sizes[i] * sizeof(int8_t)));
} else if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 || output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
CHECK_CUDA(cudaMalloc(&buffers[i + num_binding_inputs], batch_size * output_dim_sizes[i] * sizeof(int32_t)));
} else {
Status(common::ONNXRUNTIME, common::FAIL, "Output tensor type " + std::to_string(output_types[i]) + " is not supported.");
Review comment (Member):
Same comment as above for Status().

}
}

// Run TRT inference
@@ -491,18 +512,22 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:

int output_size = batch_size * output_dim_sizes[i];
OrtValue* output_tensor = ort.KernelContext_GetOutput(context, output_index, output_shapes[i].data(), output_shapes[i].size());
// If output tensor type is INT64, TensorRT processes data as INT32 and the output will be converted to INT64.
if (output_types[i] == TensorProto::FLOAT) {
if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
CHECK_CUDA(cudaMemcpy(ort.GetTensorMutableData<float>(output_tensor), buffers[i + num_binding_inputs], output_size * sizeof(float), cudaMemcpyDeviceToHost));
} else if (output_types[i] == TensorProto::INT64) {
int* output = new int[output_size];
CHECK_CUDA(cudaMemcpy(output, buffers[i + num_binding_inputs], output_size * sizeof(int), cudaMemcpyDeviceToHost));
} else if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
CHECK_CUDA(cudaMemcpy(ort.GetTensorMutableData<int8_t>(output_tensor), buffers[i + num_binding_inputs], output_size * sizeof(int8_t), cudaMemcpyDeviceToHost));
} else if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
CHECK_CUDA(cudaMemcpy(ort.GetTensorMutableData<int32_t>(output_tensor), buffers[i + num_binding_inputs], output_size * sizeof(int32_t), cudaMemcpyDeviceToHost));
} else if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
// If output tensor type is INT64, TensorRT processes data as INT32 and the output will be converted to INT64.
int* output = new int32_t[output_size];
CHECK_CUDA(cudaMemcpy(output, buffers[i + num_binding_inputs], output_size * sizeof(int32_t), cudaMemcpyDeviceToHost));
for (int j = 0; j < output_size; ++j) {
ort.GetTensorMutableData<int64_t>(output_tensor)[j] = output[j];
}
delete[] output;
} else {
Status(common::ONNXRUNTIME, common::FAIL, "Output type is not supported by TensorRT");
Status(common::ONNXRUNTIME, common::FAIL, "Output tensor type " + std::to_string(output_types[i]) + " is not supported.");
Review comment (Member):
Same Status() comment.

}
}
