
[Bug] onnxruntime-gpu 1.16.3 not thread-safe with BERT onnx model in fp16 using CUDA provider #18854


Description

Describe the issue

I am working with language models, and I ran into this issue with a BERT model I exported from PyTorch (https://huggingface.co/dicta-il/dictabert).
I initialize an InferenceSession with the model and then run multiple inputs through it in parallel. With the full-precision (fp32) version of the model this works fine, but with the fp16 version (created using onnxconverter_common.float16.convert_float_to_float16) the parallel runs crash with various errors, all related to illegal memory access. Here are two example errors:

[E:onnxruntime:, sequential_executor.cc:514 onnxruntime::ExecuteKernel] Non-zero status code returned while running Attention node. Name:'Attention_0' Status Message: CUBLAS failure 14: CUBLAS_STATUS_INTERNAL_ERROR ; GPU=0 ; hostname=DESKTOP-9FIMDJR ; file=D:\a\_work\1\s\onnxruntime\contrib_ops\cuda\bert\attention.cc ; line=249 ; expr=cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast<const CudaT*>(weights->Data<T>()), n, reinterpret_cast<const CudaT*>(input->Data<T>()), k, &zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop);
[E:onnxruntime:, sequential_executor.cc:514 onnxruntime::ExecuteKernel] Non-zero status code returned while running MatMul node. Name:'/encoder/layer.0/intermediate/dense/MatMul' Status Message: CUBLAS failure 14: CUBLAS_STATUS_INTERNAL_ERROR ; GPU=0 ; hostname=DESKTOP-9FIMDJR ; file=D:\a\_work\1\s\onnxruntime\core\providers\cuda\math\matmul.cc ; line=169 ; expr=cublasGemmHelper( GetCublasHandle(ctx), transB, transA, static_cast<int>(helper.N()), static_cast<int>(helper.M()), static_cast<int>(helper.K()), &alpha, reinterpret_cast<const CudaT*>(right_X->Data<T>()), ldb, reinterpret_cast<const CudaT*>(left_X->Data<T>()), lda, &zero, reinterpret_cast<CudaT*>(Y->MutableData<T>()), ldc, device_prop);

To reproduce

You can download the ONNX model from here.

// Required usings
using System;
using System.Linq;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;

// Initialize the session with the CUDA execution provider
var options = new SessionOptions {
    ExecutionMode = ExecutionMode.ORT_SEQUENTIAL,
    LogVerbosityLevel = (int)OrtLoggingLevel.ORT_LOGGING_LEVEL_INFO,
};
options.AppendExecutionProvider_CUDA(0);
var session = new InferenceSession("model.fp16.onnx", options);

// Create sample inputs for a batch of 27 sequences of length 52
// (the model's vocabulary size is 128,000)
var inputIds = Enumerable.Range(0, 27 * 52).Select(_ => Random.Shared.NextInt64(3, 128000)).ToArray();
var mask = Enumerable.Repeat(1L, 27 * 52).ToArray();
var tokenTypeIds = Enumerable.Repeat(0L, 27 * 52).ToArray();

// Run the same inputs through the session 500 times in parallel
Enumerable.Range(0, 500).AsParallel().ForAll(_ => {
    using var output = session.Run(new[] {
        NamedOnnxValue.CreateFromTensor("input_ids", new DenseTensor<long>(inputIds, new[] { 27, 52 })),
        NamedOnnxValue.CreateFromTensor("token_type_ids", new DenseTensor<long>(tokenTypeIds, new[] { 27, 52 })),
        NamedOnnxValue.CreateFromTensor("attention_mask", new DenseTensor<long>(mask, new[] { 27, 52 })),
    });
});
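
In the meantime, the only stopgap that seems safe is to serialize the Run calls, since the crash only shows up under concurrency. Below is a minimal sketch, assuming the race is inside the shared session; the SerializedRun helper is hypothetical, not part of the ONNX Runtime API.

// Hypothetical stopgap: funnel every Run() call through a single lock so no
// two threads are ever inside the session at once. This cannot trigger a
// concurrency bug, but it also gives up the parallel throughput we want.
static readonly object RunLock = new object();

static IDisposableReadOnlyCollection<DisposableNamedOnnxValue> SerializedRun(
    InferenceSession session, IReadOnlyCollection<NamedOnnxValue> inputs)
{
    lock (RunLock)
    {
        return session.Run(inputs);
    }
}

With this in place, the parallel loop above would call SerializedRun(session, inputs) instead of session.Run(...), at the cost of effectively sequential execution.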

Urgency

We are counting on this capability for a product we are deploying in our organization. It is not blocking us yet, but being able to execute inference runs in parallel against the same model would be very helpful.
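
The fallback we would rather avoid is giving each worker thread its own InferenceSession over the same model file, sketched below under the assumption that GPU memory allows it; this sidesteps any shared-session race entirely, but loads a separate copy of the model per thread.

using System.Threading;

// Hypothetical fallback: one InferenceSession per worker thread, so no
// session object is ever shared across threads. Each session loads its
// own copy of the model (and its weights) onto the GPU.
var perThreadSession = new ThreadLocal<InferenceSession>(() =>
{
    var opts = new SessionOptions();
    opts.AppendExecutionProvider_CUDA(0);
    return new InferenceSession("model.fp16.onnx", opts);
});

Enumerable.Range(0, 500).AsParallel().ForAll(_ =>
{
    // inputIds, tokenTypeIds and mask as defined in the repro above
    using var output = perThreadSession.Value.Run(new[] {
        NamedOnnxValue.CreateFromTensor("input_ids", new DenseTensor<long>(inputIds, new[] { 27, 52 })),
        NamedOnnxValue.CreateFromTensor("token_type_ids", new DenseTensor<long>(tokenTypeIds, new[] { 27, 52 })),
        NamedOnnxValue.CreateFromTensor("attention_mask", new DenseTensor<long>(mask, new[] { 27, 52 })),
    });
});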

Platform

Windows

OS Version

19045.2965

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.16.3

ONNX Runtime API

C#

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

CUDA 11.4


Labels

ep:CUDA (issues related to the CUDA execution provider)
model:transformer (issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc.)
platform:windows (issues related to the Windows platform)
