
Improve TensorRT GetCapability to Enable More Models #1012

Merged 12 commits on May 24, 2019

Conversation

stevenlix (Contributor)

No description provided.

@stevenlix requested a review from jywu-msft May 11, 2019 01:29
@stevenlix requested a review from a team as a code owner May 11, 2019 01:29
@jywu-msft (Member)

Seems like the Windows TensorRT build fails.

@jywu-msft (Member)

If GetCapability() is now reliable, doesn't this mean we can re-enable many of the disabled TRT unit tests?
The ultimate validation would be whether the TensorRT execution provider exclusions in #802 can be removed (or kept to an absolute minimum).

@stevenlix (Contributor, Author)

Yes, I think most of the disabled unit tests could be re-enabled. But keeping those tests disabled lets people know which tests can't run on TensorRT; otherwise we lose that tracking, since those tests simply fall back to other execution providers.

@jywu-msft (Member)

We do not want to leave tests disabled; that was only a temporary workaround until GetCapability was fixed.
The purpose of the unit tests is to ensure correctness, and falling back via GetCapability is still a correct, valid result. If we leave the tests disabled, there is no validation that the GetCapability() changes in this PR work as expected, and we lose the ability to catch future regressions in GetCapability().
Having all unit tests pass (without disabling any of them) gives us confidence that the TRT provider can handle real models without crashing.

}


std::unique_ptr<IndexedSubGraph> GetSubGraph(SubGraph_t graph_nodes_index, int& kernels_index,
@jywu-msft (Member), May 13, 2019

It would be good to have explicit unit tests for these new methods (GetSubGraph, GetSupportedList).

std::unique_ptr<IndexedSubGraph> GetSubGraph(SubGraph_t graph_nodes_index, int& kernels_index,
const onnxruntime::GraphViewer& graph) const;

SubGraphCollection_t GetSupportedList(SubGraphCollection_t supported_nodes_list, int iterations, const int& max_iterations,
Member

why is max_iterations a reference?
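For context: a small built-in type such as int is at least as cheap to pass by value as by const reference, so the reference buys nothing here. A minimal sketch of the alternative declaration (the trailing parameters are elided exactly as they are in the diff above):

SubGraphCollection_t GetSupportedList(SubGraphCollection_t supported_nodes_list,
                                      int iterations, int max_iterations
                                      /* ...remaining parameters as in the diff... */);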


CHECK_CUDA(cudaMemcpy(output_tensors[output_index].data, buffers[i + num_binding_inputs], batch_size * output_dim_sizes[i] * sizeof(float), cudaMemcpyDeviceToHost));
int output_size = batch_size * output_dim_sizes[i];
if (output_types[i] == TensorProto::FLOAT) {
Member

add comments for what we're doing here.
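For reference, here is a commented version of the copy-back step above. This is a sketch only; the buffer and index names are taken from the diff and the surrounding declarations are assumed:

// Number of elements TensorRT produced for this output binding.
int output_size = batch_size * output_dim_sizes[i];
if (output_types[i] == TensorProto::FLOAT) {
  // TensorRT wrote float results into the device-side buffer for this output
  // binding; copy them device-to-host into the ORT output tensor's memory.
  CHECK_CUDA(cudaMemcpy(output_tensors[output_index].data,
                        buffers[i + num_binding_inputs],
                        output_size * sizeof(float),
                        cudaMemcpyDeviceToHost));
}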

@jywu-msft (Member)

I see the TRT provider exclusions were removed in gemm_test.cc.
What about the other tests? There were quite a few in #802.
Will we be able to remove all the exclusions, or are there still some failures in GetCapability?

@@ -173,7 +174,7 @@ TEST(UpsampleOpTest, UpsampleOpNearest222XTest) {
};

test.AddOutput<float>("Y", {N*2, C, (int64_t)(H * scales[2]), (int64_t)(W * scales[3])}, Y);
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});//TensorRT parser: Assertion failed: scales[0] == 1 && scales[1] == 1
test.Run();//TensorRT parser: Assertion failed: scales[0] == 1 && scales[1] == 1
Member

should these assertion failure comments be removed if we are re-enabling the TensorrtExecutionProvider for these tests?

Member

same comment for other tests.

stevenlix (Contributor, Author)

removed the comments

.gitmodules (conversation resolved)
OrtValue* output_tensor = ort.KernelContext_GetOutput(context, output_index, output_shapes[i].data(), output_shapes[i].size());
if (output_types[i] == TensorProto::FLOAT) {
CHECK_CUDA(cudaMemcpy(ort.GetTensorMutableData<float>(output_tensor), buffers[i + num_binding_inputs], batch_size * output_dim_sizes[i] * sizeof(float), cudaMemcpyDeviceToHost));
// If output tensor type is INT64, TensorRT processes data as INT32 and the output will be converted to INT64.
Member

move the comment to where the applicable code is.
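A sketch of what that could look like, with the comment sitting next to the conversion it describes. The INT32 staging buffer and the widening loop mirror code visible elsewhere in this diff; the exact variable names are assumptions:

} else if (output_types[i] == TensorProto::INT64) {
  // TensorRT has no INT64 binding type, so the engine produced INT32 data.
  // Copy the INT32 results to the host, then widen them element by element
  // into the INT64 ORT output tensor.
  int output_size = batch_size * output_dim_sizes[i];
  int32_t* output = new int32_t[output_size];
  CHECK_CUDA(cudaMemcpy(output, buffers[i + num_binding_inputs],
                        output_size * sizeof(int32_t), cudaMemcpyDeviceToHost));
  for (int j = 0; j < output_size; ++j) {
    ort.GetTensorMutableData<int64_t>(output_tensor)[j] = output[j];
  }
  delete[] output;
}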

ort.ReleaseTensorTypeAndShapeInfo(tensor_info);
const float* input = ort.GetTensorData<float>(input_tensor);

const int input_batch_size = tensor_shape[0];
if (i > 0 && batch_size != input_batch_size) {
ORT_THROW("Input batch size is inconsistent");
Member

I thought we should avoid throwing exceptions in the compute function?

Member

I see other places in the Compile() function that use ORT_ENFORCE instead of returning a Status; please revisit.
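One way to address both comments, sketched under the assumption that compute_func reports failure through its integer return value (as the rest of this PR's compute path does):

const int input_batch_size = tensor_shape[0];
if (i > 0 && batch_size != input_batch_size) {
  // Don't throw (or ORT_ENFORCE) inside the compute function; report the
  // failure through the return value so the runtime can surface it.
  return 1;
}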

CHECK_CUDA(cudaMalloc(&buffers[i], input_size * sizeof(int32_t)));
CHECK_CUDA(cudaMemcpy(buffers[i], input, input_size * sizeof(int32_t), cudaMemcpyHostToDevice));
} else {
Status(common::ONNXRUNTIME, common::FAIL, "Input tensor type " + std::to_string(tensor_type) + " is not supported.");
Member

Why allocate a Status()? It doesn't seem to be used anywhere.

Member

same comment for other Status() allocations.

Member

Looks like this compute_func() always returns 0, which doesn't seem correct.
If there are errors, we need to return 1.
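The Status constructed in the branch above is a temporary that is dropped immediately, so the error is silently lost. A sketch of the intended behavior, again assuming the integer return value is how compute_func signals failure (LOGS_DEFAULT is the onnxruntime default-logger macro):

} else {
  // A discarded Status reports nothing; log the problem and fail the call.
  LOGS_DEFAULT(ERROR) << "Input tensor type " + std::to_string(tensor_type) + " is not supported.";
  return 1;
}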

} else if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 || output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
CHECK_CUDA(cudaMalloc(&buffers[i + num_binding_inputs], batch_size * output_dim_sizes[i] * sizeof(int32_t)));
} else {
Status(common::ONNXRUNTIME, common::FAIL, "Output tensor type " + std::to_string(output_types[i]) + " is not supported.");
Member

same comment as above for Status()

for (int j = 0; j < output_size; ++j) {
ort.GetTensorMutableData<int64_t>(output_tensor)[j] = output[j];
}
delete[] output;
} else {
Status(common::ONNXRUNTIME, common::FAIL, "Output type is not supported by TensorRT");
Status(common::ONNXRUNTIME, common::FAIL, "Output tensor type " + std::to_string(output_types[i]) + " is not supported.");
Member

same Status() comment

@jywu-msft requested a review from HectorSVC May 23, 2019 03:25
}
delete[] output;
} else {
return 1;
@jywu-msft (Member), May 23, 2019

Better to return a specific status code instead of 1 (same for the other places):
https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/common/status.h#L33
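A sketch of what that could look like, using a named code from the linked StatusCode enum rather than a bare 1; whether the caller interprets compute_func's int return value as a StatusCode is an assumption here:

} else {
  // Distinguish "type not supported" from a generic failure by returning a
  // named code from core/common/status.h instead of the literal 1.
  return common::NOT_IMPLEMENTED;
}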

@jywu-msft (Member)

Looks like you'll need to merge master to pick up #1097 to resolve the nocontribops CI failure.
