Commit 31a6be3

Add Longformer Attention CUDA Op (#5932)
Limitation: global tokens must be at the beginning of the sequence.
1 parent e39e82b commit 31a6be3
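
Below is a minimal sketch of what this limitation means for the 'global' input (shape batch_size x sequence_length), assuming the per-token flags are simply 1 for a global-attention token and 0 otherwise; the values and the MakeGlobalFlags helper are illustrative only and are not part of this commit:

    #include <vector>

    // Build per-token global-attention flags for one sequence, packing all
    // num_global global tokens at the start of the sequence, as this op requires.
    std::vector<int> MakeGlobalFlags(int sequence_length, int num_global) {
      std::vector<int> flags(sequence_length, 0);
      for (int i = 0; i < num_global && i < sequence_length; ++i) {
        flags[i] = 1;  // supported: global tokens occupy positions 0 .. num_global - 1
      }
      return flags;
    }

    // Not supported by this kernel: global tokens scattered through the sequence,
    // e.g. {0, 0, 1, 0, 1, 0, ...}.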

10 files changed: 1552 additions & 0 deletions

longformer_attention_base.cc
Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "longformer_attention_base.h"

namespace onnxruntime {
namespace contrib {

LongformerAttentionBase::LongformerAttentionBase(const OpKernelInfo& info) {
  int64_t num_heads = 0;
  ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
  num_heads_ = static_cast<int>(num_heads);

  int64_t window = 0;
  ORT_ENFORCE(info.GetAttr("window", &window).IsOK() && window > 0);
  window_ = static_cast<int>(window);
}

Status LongformerAttentionBase::CheckInputs(const TensorShape& input_shape,
                                            const TensorShape& weights_shape,
                                            const TensorShape& bias_shape,
                                            const TensorShape& mask_shape,
                                            const TensorShape& global_weights_shape,
                                            const TensorShape& global_bias_shape,
                                            const TensorShape& global_shape) const {
  // Input shapes:
  //   input          : (batch_size, sequence_length, hidden_size)
  //   weights        : (hidden_size, 3 * hidden_size)
  //   bias           : (3 * hidden_size)
  //   mask           : (batch_size, sequence_length)
  //   global_weights : (hidden_size, 3 * hidden_size)
  //   global_bias    : (3 * hidden_size)
  //   global         : (batch_size, sequence_length)
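  // Hypothetical example of consistent shapes (values for illustration only):
  //   batch_size = 1, sequence_length = 4096, hidden_size = 768, num_heads = 12, window = 512
  //   => hidden_size / num_heads = 64 (head size), and 4096 is divisible by 2 * 512 = 1024,
  //   satisfying the checks below.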

  const auto& dims = input_shape.GetDims();
  if (dims.size() != 3) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is expected to have 3 dimensions, got ",
                           dims.size());
  }

  int batch_size = static_cast<int>(dims[0]);
  int sequence_length = static_cast<int>(dims[1]);
  int hidden_size = static_cast<int>(dims[2]);
  if (sequence_length % (2 * window_) != 0) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Input 'input' dimension 1 should be divisible by 2W, where W is the value of the window attribute.");
  }
  if (hidden_size % num_heads_ != 0) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Input 'input' dimension 2 should be divisible by the value of the num_heads attribute.");
  }

  const auto& weights_dims = weights_shape.GetDims();
  if (weights_dims.size() != 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'weights' is expected to have 2 dimensions, got ",
                           weights_dims.size());
  }
  if (weights_dims[0] != dims[2]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Input 'weights' dimension 0 should have the same length as dimension 2 of input 0");
  }
  if (weights_dims[1] != 3 * weights_dims[0]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'weights' dimension 1 should be 3 times dimension 0");
  }

  const auto& bias_dims = bias_shape.GetDims();
  if (bias_dims.size() != 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ",
                           bias_dims.size());
  }
  if (bias_dims[0] != weights_dims[1]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Input 'bias' dimension 0 should have the same length as dimension 1 of input 'weights'");
  }

  const auto& mask_dims = mask_shape.GetDims();
  if (mask_dims.size() == 2) {
    if (static_cast<int>(mask_dims[0]) != batch_size || static_cast<int>(mask_dims[1]) != sequence_length) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask' shall have shape batch_size x sequence_length");
    }
  } else {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask' is expected to have 2 dimensions, got ",
                           mask_dims.size());
  }

  const auto& global_weights_dims = global_weights_shape.GetDims();
  if (global_weights_dims.size() != 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'global_weights' is expected to have 2 dimensions, got ",
                           global_weights_dims.size());
  }
  if (global_weights_dims[0] != dims[2]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Input 'global_weights' dimension 0 should have the same length as dimension 2 of input 0");
  }
  if (global_weights_dims[1] != 3 * global_weights_dims[0]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'global_weights' dimension 1 should be 3 times dimension 0");
  }

  const auto& global_bias_dims = global_bias_shape.GetDims();
  if (global_bias_dims.size() != 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'global_bias' is expected to have 1 dimension, got ",
                           global_bias_dims.size());
  }
  if (global_bias_dims[0] != global_weights_dims[1]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Input 'global_bias' dimension 0 should have the same length as dimension 1 of input 'global_weights'");
  }

  const auto& global_dims = global_shape.GetDims();
  if (global_dims.size() == 2) {
    if (static_cast<int>(global_dims[0]) != batch_size || static_cast<int>(global_dims[1]) != sequence_length) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'global' shall have shape batch_size x sequence_length");
    }
  } else {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'global' is expected to have 2 dimensions, got ",
                           global_dims.size());
  }

  return Status::OK();
}

}  // namespace contrib
}  // namespace onnxruntime
longformer_attention_base.h
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/framework/op_kernel.h"

namespace onnxruntime {
namespace contrib {

class LongformerAttentionBase {
 protected:
  LongformerAttentionBase(const OpKernelInfo& info);

  Status CheckInputs(const TensorShape& input_shape,
                     const TensorShape& weights_shape,
                     const TensorShape& bias_shape,
                     const TensorShape& mask_shape,
                     const TensorShape& global_weights_shape,
                     const TensorShape& global_bias_shape,
                     const TensorShape& global_shape) const;

  int num_heads_;  // Number of attention heads
  int window_;     // One-sided attention window length (W); half of the total window size.
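  // Example (values for illustration only): window_ = 512 means each token attends to
  // up to 512 tokens on each side, for a total local attention span of 1024 tokens.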
};

}  // namespace contrib
}  // namespace onnxruntime
longformer_attention.cc (CUDA execution provider)
Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "longformer_attention.h"
#include "core/framework/tensorprotoutils.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
#include "longformer_attention_impl.h"

using namespace onnxruntime::cuda;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;

namespace onnxruntime {
namespace contrib {
namespace cuda {

#define REGISTER_KERNEL_TYPED(T)                                  \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                  \
      LongformerAttention,                                        \
      kMSDomain,                                                  \
      1,                                                          \
      T,                                                          \
      kCudaExecutionProvider,                                     \
      KernelDefBuilder()                                          \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
      LongformerAttention<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)

template <typename T>
LongformerAttention<T>::LongformerAttention(const OpKernelInfo& info) : CudaKernel(info), LongformerAttentionBase(info) {}

template <typename T>
Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
  const Tensor* input = context->Input<Tensor>(0);
  const Tensor* weights = context->Input<Tensor>(1);
  const Tensor* bias = context->Input<Tensor>(2);
  const Tensor* mask = context->Input<Tensor>(3);
  const Tensor* global_weights = context->Input<Tensor>(4);
  const Tensor* global_bias = context->Input<Tensor>(5);
  const Tensor* global_attention = context->Input<Tensor>(6);
  ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask->Shape(),
                                  global_weights->Shape(), global_bias->Shape(), global_attention->Shape()));

  // Input and output shapes:
  //   Input 0  - input  : (batch_size, sequence_length, hidden_size)
  //   Output 0 - output : (batch_size, sequence_length, hidden_size)
  const auto& shape = input->Shape();
  int batch_size = static_cast<int>(shape[0]);
  int sequence_length = static_cast<int>(shape[1]);
  int hidden_size = static_cast<int>(shape[2]);
  int head_size = hidden_size / num_heads_;

  Tensor* output = context->Output(0, shape);

  cublasHandle_t cublas = CublasHandle();
  constexpr size_t element_size = sizeof(T);

  // Use GEMM for the fully connected (QKV) projection.
  int m = batch_size * sequence_length;
  int n = 3 * hidden_size;
  int k = hidden_size;

  size_t qkv_size = batch_size * sequence_length * 3 * hidden_size * element_size;
  auto gemm_buffer = GetScratchBuffer<T>(qkv_size);

  typedef typename ToCudaType<T>::MappedType CudaT;
  CudaT one = ToCudaType<T>::FromFloat(1.0f);
  CudaT zero = ToCudaType<T>::FromFloat(0.0f);

  // Bias shape is (N), broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B.
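  // Illustrative note: with m = batch_size * sequence_length and n = 3 * hidden_size, the rank-1
  // product bias(n, 1) x ones(1, m) yields an n x m matrix whose every column equals bias, so the
  // same bias vector ends up added to every row of the input x weights product computed next.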
  auto& device_prop = GetDeviceProp();
  CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
      cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one,
      reinterpret_cast<const CudaT*>(bias->template Data<T>()), n,
      GetConstOnes<CudaT>(m), 1,
      &zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop));

  // GEMM; note that cuBLAS assumes column-major layout, so result(N, M) = 1 * weights x input + 1 x B.
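  // Illustrative note: in row-major terms the call below computes gemm_buffer(m, n) =
  // input(m, k) x weights(k, n) plus the broadcast bias, i.e. the packed Q/K/V projection
  // for every token in a single GEMM.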
  CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
      cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
      reinterpret_cast<const CudaT*>(weights->template Data<T>()), n,
      reinterpret_cast<const CudaT*>(input->template Data<T>()), k,
      &one, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop));

  // TODO: calculate the exact value from global flags.
  int max_num_global = sequence_length;
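  // Illustrative note: sequence_length is a conservative upper bound here; the exact count of
  // global tokens could instead be derived from the 'global' flags (see the TODO above).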

  // Fully connected projection for global attention.
  // Note that Q only needs to handle global query tokens if we split the GEMM into global Q/K/V separately.
  // When there is no global token, there is no need to run the global GEMM.
  auto global_gemm_buffer = GetScratchBuffer<T>(max_num_global > 0 ? qkv_size : 0);

  if (max_num_global > 0) {
    CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
        cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one,
        reinterpret_cast<const CudaT*>(global_bias->template Data<T>()), n,
        GetConstOnes<CudaT>(m), 1,
        &zero, reinterpret_cast<CudaT*>(global_gemm_buffer.get()), n, device_prop));

    CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
        cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
        reinterpret_cast<const CudaT*>(global_weights->template Data<T>()), n,
        reinterpret_cast<const CudaT*>(input->template Data<T>()), k,
        &one, reinterpret_cast<CudaT*>(global_gemm_buffer.get()), n, device_prop));
  }

  size_t workSpaceSize = GetLongformerAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size, sequence_length, max_num_global);
  auto workspace_buffer = GetScratchBuffer<void>(workSpaceSize);
  if (!LaunchLongformerAttentionKernel(
          device_prop,
          reinterpret_cast<const CudaT*>(gemm_buffer.get()),
          reinterpret_cast<const CudaT*>(mask->template Data<T>()),
          reinterpret_cast<const CudaT*>(global_gemm_buffer.get()),
          global_attention->template Data<int>(),
          output->template MutableData<T>(),
          batch_size,
          sequence_length,
          num_heads_,
          head_size,
          window_,
          max_num_global,
          workspace_buffer.get(),
          cublas,
          element_size)) {
    // Get the last error to reset it to cudaSuccess.
    CUDA_CALL(cudaGetLastError());
    return Status(common::ONNXRUNTIME, common::FAIL);
  }

  return Status::OK();
}

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
longformer_attention.h (CUDA execution provider)
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cpu/bert/longformer_attention_base.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

using namespace onnxruntime::cuda;

template <typename T>
class LongformerAttention final : public CudaKernel, public LongformerAttentionBase {
 public:
  LongformerAttention(const OpKernelInfo& info);
  Status ComputeInternal(OpKernelContext* context) const override;
};

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
