
Commit 0bb4ea6

Update BiasGelu fusion and related ops (#23518)
### Description

1. Update BiasGelu fusion to support onnx Gelu-20. Since onnx Gelu-20 supports float/double/bf16/fp16, the related ops are updated to support these data types in the CUDA and ROCm execution providers:
2. Add double support for the Gelu/FastGelu ops in the CUDA and ROCm execution providers.
3. Add BFloat16 support for the Gelu op in the CUDA execution provider.
4. Add unit tests.
5. Update operator documents.

### Motivation and Context

#23491
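For context on the two Gelu formulations involved: onnx Gelu-20 folds both variants into one op via its string `approximate` attribute. `approximate="none"` (the default) is the exact erf form, which the fused `com.microsoft` BiasGelu computes, while `approximate="tanh"` is the approximation that `com.microsoft` FastGelu computes:

$$
\mathrm{Gelu}(x) = \frac{x}{2}\left(1 + \operatorname{erf}\!\Bigl(\tfrac{x}{\sqrt{2}}\Bigr)\right),
\qquad
\mathrm{FastGelu}(x) = \frac{x}{2}\left(1 + \tanh\!\Bigl(\sqrt{\tfrac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr)\Bigr)\right).
$$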
Parent: 4dde74a

File tree: 18 files changed (+193 −11 lines)


docs/ContribOperators.md
Lines changed: 1 addition & 1 deletion

```diff
@@ -1754,7 +1754,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 #### Type Constraints
 
 <dl>
-<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
+<dt><tt>T</tt> : tensor(float), tensor(double), tensor(float16), tensor(bfloat16)</dt>
 <dd>Constrain input and output types to float or half tensors.</dd>
 </dl>
 
```

docs/OperatorKernels.md
Lines changed: 2 additions & 2 deletions

```diff
@@ -912,11 +912,11 @@ Do not modify directly.*
 |DequantizeWithOrder|*in* input:**Q**<br> *in* scale_input:**S**<br> *out* output:**F**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
 |DynamicTimeWarping|*in* input:**F**<br> *out* output:**I**|1+|**F** = tensor(float)<br/> **I** = tensor(int32)|
 |EmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding:**T**<br> *in* position_embedding:**T**<br> *in* segment_embedding:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* mask:**T1**<br> *in* position_ids:**T1**<br> *out* output:**T**<br> *out* mask_index:**T1**<br> *out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
-|FastGelu|*in* X:**T**<br> *in* bias:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)|
+|FastGelu|*in* X:**T**<br> *in* bias:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |FusedConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *in* Z:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |GatedRelativePositionBias|*in* query_layer:**T**<br> *in* query_bias:**T**<br> *in* rel_pos:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* eco_a:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
-|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |GemmFloat8|*in* A:**TA**<br> *in* B:**TB**<br> *in* C:**TC**<br> *in* scaleA:**TS**<br> *in* scaleB:**TS**<br> *in* scaleY:**TS**<br> *out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TS** = tensor(float)|
 |GemmaRotaryEmbedding|*in* emb:**U**<br> *in* q:**T**<br> *in* q_rot:**T**<br> *in* k:**T**<br> *in* k_rot:**T**<br> *out* output1:**T**<br> *out* output2:**T**|1+|**T** = tensor(float16)<br/> **U** = tensor(float)|
 |GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
```

onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
Lines changed: 1 addition & 0 deletions

```diff
@@ -30,6 +30,7 @@ namespace cuda {
 REGISTER_KERNEL_TYPED(float)
 REGISTER_KERNEL_TYPED(MLFloat16)
 REGISTER_KERNEL_TYPED(BFloat16)
+REGISTER_KERNEL_TYPED(double)
 
 using namespace ONNX_NAMESPACE;
 
```
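For readers unfamiliar with the idiom: `REGISTER_KERNEL_TYPED` is a file-local macro, so adding double support for the kernel is a one-line change. A sketch of the usual onnxruntime expansion (the actual macro sits at the top of fast_gelu.cc and may differ in detail):

```cpp
// Sketch of the typical expansion: one typed CUDA kernel registration per
// instantiated data type T, constraining the schema's "T" to that type.
#define REGISTER_KERNEL_TYPED(T)                                  \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                  \
      FastGelu,                                                   \
      kMSDomain,                                                  \
      1,                                                          \
      T,                                                          \
      kCudaExecutionProvider,                                     \
      (*KernelDefBuilder::Create())                               \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
      FastGelu<T>);
```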

onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
Lines changed: 7 additions & 2 deletions

```diff
@@ -25,10 +25,14 @@ namespace onnxruntime {
 namespace contrib {
 namespace cuda {
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, GridSample);
+
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FastGelu);
+class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, FastGelu);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, FastGelu);
+class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Gelu);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Gelu);
+class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, Gelu);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Gelu);
 class CUDA_MS_OP_CLASS_NAME(1, BiasGelu);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, BiasSplitGelu);
@@ -154,7 +158,6 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, uint8_t_MLFloat16, DequantizeLinear);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_int8_t, QAttention);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_int8_t, QAttention);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FusedConv);
-class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, TransposeMatMul);  // backward compatibility
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FusedMatMul);
 class CUDA_MS_OP_CLASS_NAME(1, QOrderedMatMul);
@@ -234,10 +237,13 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
 BuildKernelCreateInfo<void>,  // default entry to avoid the list become empty after ops-reducing
 BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, GridSample)>,
 BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FastGelu)>,
+BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, FastGelu)>,
 BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, FastGelu)>,
+BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu)>,
 BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Gelu)>,
 BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Gelu)>,
 BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Gelu)>,
+BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, Gelu)>,
 BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BiasGelu)>,
 BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, BiasSplitGelu)>,
 BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, BiasSplitGelu)>,
@@ -362,7 +368,6 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
 BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, UnfoldTensor)>,
 BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, DynamicTimeWarping)>,
 BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, Trilu)>,
-BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu)>,
 // TransposedMatMul is still here for backward compatibility
 BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, TransposeMatMul)>,  // backward compatibility
 BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FusedMatMul)>,
```
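Note that the removed BFloat16 FastGelu lines are relocations, not regressions: that declaration and registration previously sat in a separate block further down and now sit beside the other FastGelu entries. The genuinely new CUDA registrations here are double FastGelu and BFloat16 Gelu. The ROCm registration file below gets the same treatment.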

onnxruntime/contrib_ops/rocm/bert/elementwise.h
Lines changed: 2 additions & 0 deletions

```diff
@@ -66,10 +66,12 @@ class ElementwiseTunableOp : public TunableOp<ElementwiseParams<T>> {
 }
 
 ELEMENTWISE_FWD_DECL(FastGeLU, float);
+ELEMENTWISE_FWD_DECL(FastGeLU, double);
 ELEMENTWISE_FWD_DECL(FastGeLU, half);
 ELEMENTWISE_FWD_DECL(FastGeLU, BFloat16);
 
 ELEMENTWISE_FWD_DECL(GeLU, float);
+ELEMENTWISE_FWD_DECL(GeLU, double);
 ELEMENTWISE_FWD_DECL(GeLU, half);
 ELEMENTWISE_FWD_DECL(GeLU, BFloat16);
 
```

onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu
Lines changed: 1 addition & 0 deletions

```diff
@@ -4,5 +4,6 @@
 #include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh"
 
 ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, float);
+ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, double);
 ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, half);
 ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, BFloat16);
```
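The `ELEMENTWISE_FWD_DECL` lines in the header and the `ELEMENTWISE_KERNEL_IMPL` lines in the .cu files pair up: the kernels are templates, so device code exists only for types that are explicitly instantiated, and without the double line there is nothing for a double-typed op to link against. A minimal self-contained illustration of that mechanism (not the actual impl.cuh code):

```cpp
#include <cuda_runtime.h>
#include <math.h>

// Templated elementwise kernel; tanh-approximate Gelu as in FastGelu.
// For the sketch we widen to double; the real code computes in
// type-appropriate precision per instantiation.
template <typename T>
__global__ void FastGeluKernel(const T* x, T* y, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    double v = static_cast<double>(x[i]);
    double t = tanh(0.7978845608028654 * (v + 0.044715 * v * v * v));  // sqrt(2/pi)
    y[i] = static_cast<T>(0.5 * v * (1.0 + t));
  }
}

// One explicit instantiation per supported dtype, mirroring the diff above;
// each line makes the compiler emit device code for that type.
template __global__ void FastGeluKernel<float>(const float*, float*, int);
template __global__ void FastGeluKernel<double>(const double*, double*, int);
```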

onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu
Lines changed: 1 addition & 0 deletions

```diff
@@ -3,6 +3,7 @@
 
 #include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh"
 
+ELEMENTWISE_KERNEL_IMPL(functor::GeLU, double);
 ELEMENTWISE_KERNEL_IMPL(functor::GeLU, float);
 ELEMENTWISE_KERNEL_IMPL(functor::GeLU, half);
 ELEMENTWISE_KERNEL_IMPL(functor::GeLU, BFloat16);
```

onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
Lines changed: 6 additions & 2 deletions

```diff
@@ -11,10 +11,13 @@ namespace contrib {
 namespace rocm {
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GridSample);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FastGelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FastGelu);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Gelu);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Gelu);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Gelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Gelu);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasGelu);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasSplitGelu);
@@ -126,7 +129,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedConv);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedConv);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul);  // backward compatibility
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul);
 // class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedMatMul);
@@ -173,10 +175,13 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
 BuildKernelCreateInfo<void>,  // default entry to avoid the list become empty after ops-reducing
 BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GridSample)>,
 BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FastGelu)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FastGelu)>,
 BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu)>,
 BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Gelu)>,
 BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Gelu)>,
 BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Gelu)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Gelu)>,
 BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasGelu)>,
 BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu)>,
 BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasSplitGelu)>,
@@ -287,7 +292,6 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
 // BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int8_t, QAttention)>,
 // BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention)>,
 BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Trilu)>,
-BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu)>,
 // TransposedMatMul is still here for backward compatibility
 BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul)>,  // backward compatibility
 BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul)>,
```

onnxruntime/core/graph/contrib_ops/bert_defs.cc
Lines changed: 1 addition & 1 deletion

```diff
@@ -1490,7 +1490,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
     .Input(0, "X", "input tensor", "T")
     .Input(1, "bias", "bias tensor", "T", OpSchema::Optional)
     .Output(0, "Y", "output tensor", "T")
-    .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float or half tensors.")
+    .TypeConstraint("T", {"tensor(float)", "tensor(double)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float or half tensors.")
     .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
     .SetContextDependentFunctionBodyBuilder([](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
       // fastgelu(x) =
```

onnxruntime/core/optimizer/bias_gelu_fusion.cc
Lines changed: 11 additions & 2 deletions

```diff
@@ -61,7 +61,10 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
     }
 
     const Node& next_node = (*next_node_itr);
-    if (!(graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) ||
+
+    bool is_onnx_gelu = graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {20}, kOnnxDomain);
+    if (!(is_onnx_gelu ||
+          graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) ||
           graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "FastGelu", {1}, kMSDomain)) ||
         next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
       continue;
@@ -72,14 +75,20 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
       continue;
     }
 
+    bool is_approximate = is_fast_gelu;
+    if (is_onnx_gelu) {
+      const ONNX_NAMESPACE::AttributeProto* attribute = graph_utils::GetNodeAttribute(next_node, "approximate");
+      is_approximate = (attribute != nullptr) && utils::HasString(*attribute) && (attribute->s() == "tanh");
+    }
+
     if (graph.NodeProducesGraphOutput(node)) {
       continue;
     }
 
     Node& add_node = node;
     Node& gelu_node = const_cast<Node&>(next_node);
     std::string op_type = "BiasGelu";
-    if (is_fast_gelu) op_type = "FastGelu";
+    if (is_approximate) op_type = "FastGelu";
 
     Node& gelu_add_fusion_node = graph.AddNode(graph.GenerateNodeName(op_type),
                                                op_type,
```
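Per the ONNX spec, Gelu-20's `approximate` attribute defaults to `"none"`, so an absent attribute correctly falls through to the exact-Gelu (BiasGelu) path above. A minimal standalone restatement of the dispatch (the helper name and signature are hypothetical, for illustration only):

```cpp
#include <string>

// Hypothetical helper mirroring the hunk above: the fused op becomes
// FastGelu only for tanh-approximate variants, otherwise BiasGelu.
// `approximate` holds onnx Gelu-20's "approximate" attribute value,
// or nullptr when the attribute is absent (the spec default is "none").
std::string FusedOpType(bool is_fast_gelu, bool is_onnx_gelu,
                        const std::string* approximate) {
  bool is_approximate = is_fast_gelu;  // com.microsoft FastGelu is tanh-based
  if (is_onnx_gelu) {
    is_approximate = (approximate != nullptr) && (*approximate == "tanh");
  }
  return is_approximate ? "FastGelu" : "BiasGelu";
}
```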
