Skip to content

Tensor Allocation Order Bug in TensorRT Int8 #65019

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 21, 2024

Conversation

eee4017
Copy link
Contributor

@eee4017 eee4017 commented Jun 11, 2024

PR Category

Inference

PR Types

Bug fixes

Description

The provided C++ code snippet contains a bug related to tensor allocation:

for (const auto& it : buffers) {
    phi::DenseTensor temp_tensor;
    temp_tensor.Resize(data_shape);
    data_tensors_.push_back(temp_tensor);
    data_buffers_[input_name] = std::pair<void*, size_t>(
        static_cast<void*>(temp_tensor.mutable_data<int16_t>(place)),
        data_size);
}
  • When temp_tensor is pushed into data_tensors_ using data_tensors_.push_back(temp_tensor);, the tensor’s memory is not yet allocated. The Resize function only changes the tensor's shape without allocating memory.
  • Memory allocation occurs only when mutable_data is called (see PaddlePaddle's source code).
  • The buffer data_buffers_ is set to point to temp_tensor.mutable_data<int16_t>(place), but since the tensor in data_tensors_ has no allocated memory, it leads to an invalid memory reference.
  • At the end of the scope, temp_tensor deallocates, leaving data_tensors_ with an invalid pointer.

With a stream-safe allocator, this bug is not immediately apparent because the allocator does not trigger cudaFree immediately. However, with an async allocator, the bug is detected due to stricter memory management.

To fix this bug, reorder the operations to allocate memory before pushing the tensor into data_tensors_

Copy link

paddle-bot bot commented Jun 11, 2024

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Jun 11, 2024
@jeng1220
Copy link
Collaborator

The CI failed but it was NOT related to this PR.

test_linalg_cholesky_inverse .....***Failed

C++ Traceback (most recent call last):
--------------------------------------
paddle::pybind::eager_api_inverse(_object*, _object*, _object*)
inverse_ad_func(paddle::Tensor const&)
paddle::experimental::inverse(paddle::Tensor const&)
phi::KernelImpl<void (*)(phi::GPUContext const&, phi::DenseTensor const&, phi::DenseTensor*), &(void phi::InverseKernel<double, phi::GPUContext>(phi::GPUContext const&, phi::DenseTensor const&, phi::DenseTensor*))>::VariadicCompute(phi::DeviceContext const&, phi::DenseTensor const&, phi::DenseTensor*)
phi::funcs::MatrixInverseFunctor<phi::GPUContext, double>::operator()(phi::GPUContext const&, phi::DenseTensor const&, phi::DenseTensor*)
phi::funcs::MapMatrixInverseFunctor<phi::GPUContext, double>::operator()(phi::GPUContext const&, double const*, double*, int, int)
Eigen::PartialPivLU<Eigen::Matrix<double, -1, -1, 1, -1, -1> >::~PartialPivLU()

FatalError: `Segmentation fault` is detected by the operating system.

Copy link

paddle-ci-bot bot commented Jun 19, 2024

Sorry to inform you that 2218e4a's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@onecatcn onecatcn requested a review from Wangzheee June 21, 2024 08:41
@Wangzheee Wangzheee merged commit 2515789 into PaddlePaddle:develop Jun 21, 2024
co63oc pushed a commit to co63oc/Paddle that referenced this pull request Jun 25, 2024
Co-authored-by: lawrence910426 <lawu@nvidia.com>
co63oc pushed a commit to co63oc/Paddle that referenced this pull request Jun 25, 2024
Co-authored-by: lawrence910426 <lawu@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers NVIDIA
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants