|
| 1 | +// Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +// Licensed under the MIT License. |
| 3 | + |
| 4 | +#include "longformer_attention.h" |
| 5 | +#include "core/framework/tensorprotoutils.h" |
| 6 | +#include "core/providers/cuda/cuda_common.h" |
| 7 | +#include "core/providers/cuda/shared_inc/fpgeneric.h" |
| 8 | +#include "longformer_attention_impl.h" |
| 9 | + |
| 10 | +using namespace onnxruntime::cuda; |
| 11 | +using namespace ::onnxruntime::common; |
| 12 | +using namespace ONNX_NAMESPACE; |
| 13 | + |
| 14 | +namespace onnxruntime { |
| 15 | +namespace contrib { |
| 16 | +namespace cuda { |
| 17 | + |
| 18 | +#define REGISTER_KERNEL_TYPED(T) \ |
| 19 | + ONNX_OPERATOR_TYPED_KERNEL_EX( \ |
| 20 | + LongformerAttention, \ |
| 21 | + kMSDomain, \ |
| 22 | + 1, \ |
| 23 | + T, \ |
| 24 | + kCudaExecutionProvider, \ |
| 25 | + KernelDefBuilder() \ |
| 26 | + .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \ |
| 27 | + LongformerAttention<T>); |
| 28 | + |
| 29 | +REGISTER_KERNEL_TYPED(float) |
| 30 | +REGISTER_KERNEL_TYPED(MLFloat16) |
| 31 | + |
| 32 | +template <typename T> |
| 33 | +LongformerAttention<T>::LongformerAttention(const OpKernelInfo& info) : CudaKernel(info), LongformerAttentionBase(info) {} |
| 34 | + |
| 35 | +template <typename T> |
| 36 | +Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const { |
| 37 | + const Tensor* input = context->Input<Tensor>(0); |
| 38 | + const Tensor* weights = context->Input<Tensor>(1); |
| 39 | + const Tensor* bias = context->Input<Tensor>(2); |
| 40 | + const Tensor* mask = context->Input<Tensor>(3); |
| 41 | + const Tensor* global_weights = context->Input<Tensor>(4); |
| 42 | + const Tensor* global_bias = context->Input<Tensor>(5); |
| 43 | + const Tensor* global_attention = context->Input<Tensor>(6); |
| 44 | + ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask->Shape(), |
| 45 | + global_weights->Shape(), global_bias->Shape(), global_attention->Shape())); |
| 46 | + |
| 47 | + // Input and output shapes: |
| 48 | + // Input 0 - input : (batch_size, sequence_length, hidden_size) |
| 49 | + // Output 0 - output : (batch_size, sequence_length, hidden_size) |
| 50 | + const auto& shape = input->Shape(); |
| 51 | + int batch_size = static_cast<int>(shape[0]); |
| 52 | + int sequence_length = static_cast<int>(shape[1]); |
| 53 | + int hidden_size = static_cast<int>(shape[2]); |
| 54 | + int head_size = hidden_size / num_heads_; |
| 55 | + |
| 56 | + Tensor* output = context->Output(0, shape); |
| 57 | + |
| 58 | + cublasHandle_t cublas = CublasHandle(); |
| 59 | + constexpr size_t element_size = sizeof(T); |
| 60 | + |
| 61 | + // Use GEMM for fully connection. |
| 62 | + int m = batch_size * sequence_length; |
| 63 | + int n = 3 * hidden_size; |
| 64 | + int k = hidden_size; |
| 65 | + |
| 66 | + size_t qkv_size = batch_size * sequence_length * 3 * hidden_size * element_size; |
| 67 | + auto gemm_buffer = GetScratchBuffer<T>(qkv_size); |
| 68 | + |
| 69 | + typedef typename ToCudaType<T>::MappedType CudaT; |
| 70 | + CudaT one = ToCudaType<T>::FromFloat(1.0f); |
| 71 | + CudaT zero = ToCudaType<T>::FromFloat(0.0f); |
| 72 | + |
| 73 | + // Bias shape is (N), broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B. |
| 74 | + auto& device_prop = GetDeviceProp(); |
| 75 | + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( |
| 76 | + cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, |
| 77 | + reinterpret_cast<const CudaT*>(bias->template Data<T>()), n, |
| 78 | + GetConstOnes<CudaT>(m), 1, |
| 79 | + &zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop)); |
| 80 | + |
| 81 | + // Gemm, note that CUDA assumes col-major, so result(N, M) = 1 * weights x input + 1 x B. |
| 82 | + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( |
| 83 | + cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, |
| 84 | + reinterpret_cast<const CudaT*>(weights->template Data<T>()), n, |
| 85 | + reinterpret_cast<const CudaT*>(input->template Data<T>()), k, |
| 86 | + &one, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop)); |
| 87 | + |
| 88 | + // TODO: calculate the exact value from global flags. |
| 89 | + int max_num_global = sequence_length; |
| 90 | + |
| 91 | + // Fully connection for global projection. |
| 92 | + // Note that Q only need handle global query tokens if we split GEMM to global Q/K/V separately. |
| 93 | + // When there is no global token, need not run glboal GEMM. |
| 94 | + auto global_gemm_buffer = GetScratchBuffer<T>(max_num_global > 0 ? qkv_size : 0); |
| 95 | + |
| 96 | + if (max_num_global > 0) { |
| 97 | + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( |
| 98 | + cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, |
| 99 | + reinterpret_cast<const CudaT*>(global_bias->template Data<T>()), n, |
| 100 | + GetConstOnes<CudaT>(m), 1, |
| 101 | + &zero, reinterpret_cast<CudaT*>(global_gemm_buffer.get()), n, device_prop)); |
| 102 | + |
| 103 | + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( |
| 104 | + cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, |
| 105 | + reinterpret_cast<const CudaT*>(global_weights->template Data<T>()), n, |
| 106 | + reinterpret_cast<const CudaT*>(input->template Data<T>()), k, |
| 107 | + &one, reinterpret_cast<CudaT*>(global_gemm_buffer.get()), n, device_prop)); |
| 108 | + } |
| 109 | + |
| 110 | + size_t workSpaceSize = GetLongformerAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size, sequence_length, max_num_global); |
| 111 | + auto workspace_buffer = GetScratchBuffer<void>(workSpaceSize); |
| 112 | + if (!LaunchLongformerAttentionKernel( |
| 113 | + device_prop, |
| 114 | + reinterpret_cast<const CudaT*>(gemm_buffer.get()), |
| 115 | + reinterpret_cast<const CudaT*>(mask->template Data<T>()), |
| 116 | + reinterpret_cast<const CudaT*>(global_gemm_buffer.get()), |
| 117 | + global_attention->template Data<int>(), |
| 118 | + output->template MutableData<T>(), |
| 119 | + batch_size, |
| 120 | + sequence_length, |
| 121 | + num_heads_, |
| 122 | + head_size, |
| 123 | + window_, |
| 124 | + max_num_global, |
| 125 | + workspace_buffer.get(), |
| 126 | + cublas, |
| 127 | + element_size)) { |
| 128 | + // Get last error to reset it to cudaSuccess. |
| 129 | + CUDA_CALL(cudaGetLastError()); |
| 130 | + return Status(common::ONNXRUNTIME, common::FAIL); |
| 131 | + } |
| 132 | + |
| 133 | + return Status::OK(); |
| 134 | +} |
| 135 | + |
| 136 | +} // namespace cuda |
| 137 | +} // namespace contrib |
| 138 | +} // namespace onnxruntime |
0 commit comments