Skip to content

Commit

Permalink
GPT2 Gelu Fusion & Test (#3009)
Browse files Browse the repository at this point in the history
* GPT2 Gelu Fusion & Test

* change header path

* Refine code & add missing test onnx file

* Fix builds & refine float/double/fp16 compare.

* Fix builds

* Add Bias Check and UTs

* Fix build and uts

* Fuse with second formula & test

* minor change

* disable FastGelu to see whether the builds can pass

* Verify where is wrong

* disable for debugging

* Revert "disable for debugging"

This reverts commit 535c081.

* Revert "Verify where is wrong"

This reverts commit ffc43ec.

* disable the transformer for inference currently

* Enable FastGeluFusion and fix segement fault when run bertsquad10.onnx test

* Add more Unit tests convering Gelu subgraph use graph input/output

(cherry picked from commit 0739ab985240c6d9acdb8f0afd40c5fb316166af)

* Mode Bias Fusion in BiasGelu.cc

Co-authored-by: Changming Sun <chasun@microsoft.com>
  • Loading branch information
pengwa and snnn authored Feb 21, 2020
1 parent 932ecae commit 92b8a7a
Show file tree
Hide file tree
Showing 17 changed files with 837 additions and 9 deletions.
14 changes: 11 additions & 3 deletions onnxruntime/core/optimizer/bias_gelu_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,28 @@ Status BiasGelu::ApplyImpl(Graph& graph, bool& modified, int graph_level, const
}

const Node& next_node = (*next_node_itr);
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) ||
if (!(graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "FastGelu", {1}, kMSDomain)) ||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
continue;
}

bool is_fast_gelu = next_node.OpType().compare("FastGelu") == 0;
if (is_fast_gelu && next_node.InputDefs().size() > 1) {
continue;
}

if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
continue;
}

Node& add_node = node;
Node& gelu_node = const_cast<Node&>(next_node);
std::string op_type = "BiasGelu";
if (is_fast_gelu) op_type = "FastGelu";

Node& gelu_add_fusion_node = graph.AddNode(graph.GenerateNodeName("BiasGelu"),
"BiasGelu",
Node& gelu_add_fusion_node = graph.AddNode(graph.GenerateNodeName(op_type),
op_type,
"fused Add and Gelu",
gelu_input,
{},
Expand Down
244 changes: 244 additions & 0 deletions onnxruntime/core/optimizer/fast_gelu_fusion.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/optimizer/initializer.h"
#include "core/optimizer/fast_gelu_fusion.h"
#include "core/optimizer/utils.h"
#include "core/graph/graph_utils.h"
#include "float.h"
#include <deque>

using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {

// FastGelu supports limited data types.
static std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(float)"};

static bool CheckNode(const Node& node, const std::string& op_name, int32_t opset_version, ProviderType provider,
bool require_single_output=false){
return graph_utils::IsSupportedOptypeVersionAndDomain(node, op_name, {opset_version}) &&
node.GetExecutionProviderType() == provider &&
optimizer_utils::IsSupportedDataType(node, supported_data_types) &&
(!require_single_output || node.GetOutputEdgesCount() == 1);
}

MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
std::vector<std::reference_wrapper<Node>>& nodes_to_fuse) const {
MatchResult matchResult{false, nullptr, nullptr};
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul1_node, "Mul", {7}) ||
!graph_utils::IsSupportedProvider(mul1_node, GetCompatibleExecutionProviders()) ||
mul1_node.GetOutputEdgesCount() != 1 ||
!optimizer_utils::IsSupportedDataType(mul1_node, supported_data_types)) {
return matchResult;
}

int32_t input_index = -1;
const float mul_val = 0.044715f;
for (auto i = 0; i < 2; i++) {
if (optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul1_node.InputDefs()[i]), mul_val, true)){
input_index = i;
break;
}
}

if (input_index == -1) return matchResult;

NodeArg* gelu_without_bias_input_arg = mul1_node.MutableInputDefs()[(input_index + 1) % 2];
nodes_to_fuse.push_back(mul1_node);


Node& mul2_node = *graph.GetNode(mul1_node.OutputNodesBegin()->Index());
input_index = optimizer_utils::IndexOfNodeInput(mul2_node, *mul1_node.MutableOutputDefs()[0]);
if (!CheckNode(mul2_node, "Mul", 7, mul1_node.GetExecutionProviderType(), true) ||
mul2_node.MutableInputDefs()[(input_index + 1) % 2]->Name() != gelu_without_bias_input_arg->Name()) {
return matchResult;;
}
nodes_to_fuse.push_back(mul2_node);


Node& add1_node = *graph.GetNode(mul2_node.OutputNodesBegin()->Index());
input_index = optimizer_utils::IndexOfNodeInput(add1_node, *mul2_node.MutableOutputDefs()[0]);
if (!CheckNode(add1_node, "Add", 7, mul1_node.GetExecutionProviderType(), true) ||
!optimizer_utils::IsInitializerWithExpectedValue(graph, *(add1_node.InputDefs()[(input_index + 1) % 2]), 1.0f, true)) {
return matchResult;
}
nodes_to_fuse.push_back(add1_node);


Node& mul3_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index());
if (!CheckNode(mul3_node, "Mul", 7, mul1_node.GetExecutionProviderType(), true)) {
return matchResult;
}
nodes_to_fuse.push_back(mul3_node);

input_index = optimizer_utils::IndexOfNodeInput(mul3_node, *add1_node.MutableOutputDefs()[0]);
const Node* p_mul3_input_node = graph_utils::GetInputNode(mul3_node, (input_index + 1) % 2);
if (p_mul3_input_node == nullptr) return matchResult;
Node& mul4_node = const_cast<Node&>(*p_mul3_input_node);
if (!CheckNode(mul4_node, "Mul", 7, mul1_node.GetExecutionProviderType(), true)) {
return matchResult;
}

input_index = -1;
const float mul4_val = 0.7978845834732056f;
for (auto i = 0; i < 2; i++) {
if (optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul4_node.InputDefs()[i]), mul4_val, true)){
input_index = i;
break;
}
}

if (input_index == -1 || mul4_node.InputDefs()[(input_index + 1) % 2]->Name() != gelu_without_bias_input_arg->Name())
return matchResult;
nodes_to_fuse.push_back(mul4_node);

matchResult.matched = true;
matchResult.gelu_without_bias_input_arg = gelu_without_bias_input_arg;
matchResult.tanh_input_node = &mul3_node;
return matchResult;
}

MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
std::vector<std::reference_wrapper<Node>>& nodes_to_fuse) const {
MatchResult matchResult{false, nullptr, nullptr};
if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow1_node, "Pow", {7}) ||
!graph_utils::IsSupportedProvider(pow1_node, GetCompatibleExecutionProviders()) ||
pow1_node.GetOutputEdgesCount() != 1 ||
!optimizer_utils::IsSupportedDataType(pow1_node, supported_data_types)) {
return matchResult;
}

if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(pow1_node.InputDefs()[1]), 3.0f, true)){
return matchResult;
}

NodeArg* pow_input_arg = pow1_node.MutableInputDefs()[0];
nodes_to_fuse.push_back(pow1_node);

Node& mul1_node = *graph.GetNode(pow1_node.OutputNodesBegin()->Index());
auto input_index = optimizer_utils::IndexOfNodeInput(mul1_node, *pow1_node.MutableOutputDefs()[0]);
if (!CheckNode(mul1_node, "Mul", 7, pow1_node.GetExecutionProviderType(), true) ||
!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul1_node.InputDefs()[(input_index + 1) % 2]),
0.044714998453855515f, true)) {
return matchResult;
}
nodes_to_fuse.push_back(mul1_node);


Node& add1_node = *graph.GetNode(mul1_node.OutputNodesBegin()->Index());
input_index = optimizer_utils::IndexOfNodeInput(add1_node, *mul1_node.MutableOutputDefs()[0]);
if (!CheckNode(add1_node, "Add", 7, pow1_node.GetExecutionProviderType(), true) ||
add1_node.MutableInputDefs()[(input_index + 1) % 2]->Name() != pow_input_arg->Name()) {
return matchResult;
}
nodes_to_fuse.push_back(add1_node);


Node& mul2_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index());
input_index = optimizer_utils::IndexOfNodeInput(mul2_node, *add1_node.MutableOutputDefs()[0]);
if (!CheckNode(mul2_node, "Mul", 7, pow1_node.GetExecutionProviderType(), true) ||
!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul2_node.InputDefs()[(input_index + 1) % 2]),
0.7978845834732056f, true)) {
return matchResult;
}
nodes_to_fuse.push_back(mul2_node);

matchResult.matched = true;
matchResult.gelu_without_bias_input_arg = pow_input_arg;
matchResult.tanh_input_node = &mul2_node;
return matchResult;
}

Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

for (auto node_index : node_topology_list) {
auto* p_node = graph.GetNode(node_index);
if (p_node == nullptr)
continue;

Node& node = *p_node;
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));

std::vector<std::reference_wrapper<Node>> nodes_to_fuse;
MatchResult matchRet = CheckFirstFormula(graph, node, nodes_to_fuse);
if (!matchRet.matched) {
nodes_to_fuse.clear();
matchRet = CheckSecondFormula(graph, node, nodes_to_fuse);

if(!matchRet.matched) continue;
};

Node& tanh_node = *graph.GetNode(matchRet.tanh_input_node->OutputNodesBegin()->Index());
if (!CheckNode(tanh_node, "Tanh", 6, node.GetExecutionProviderType(), true)) {
continue;
}


Node& add2_node = *graph.GetNode(tanh_node.OutputNodesBegin()->Index());
if (!CheckNode(add2_node, "Add", 7, node.GetExecutionProviderType(), true)) {
continue;
}

auto input_index = optimizer_utils::IndexOfNodeInput(add2_node, *tanh_node.MutableOutputDefs()[0]);
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(add2_node.InputDefs()[(input_index + 1) % 2]), 1.0f, true)) {
continue;
}

Node& mul5_node = *graph.GetNode(add2_node.OutputNodesBegin()->Index());
// This is the output of the Gelu subgraph, we don't need check it has single edge.
if (!CheckNode(mul5_node, "Mul", 7, node.GetExecutionProviderType(), false)) {
continue;
}

// ingnore the transformer if Gelu's output is the graph's output.
if (!graph.GetNodeOutputsInGraphOutputs(mul5_node).empty()) {
continue;
}

input_index = optimizer_utils::IndexOfNodeInput(mul5_node, *add2_node.MutableOutputDefs()[0]);
const Node* p_mul5_input_node = graph_utils::GetInputNode(mul5_node, (input_index + 1) % 2);
if (p_mul5_input_node == nullptr) continue;
Node& mul6_node = const_cast<Node&>(*p_mul5_input_node);
if (!CheckNode(mul6_node, "Mul", 7, node.GetExecutionProviderType(), false)) {
continue;
}

input_index = -1;
for (auto i = 0; i < 2; i++) {
if (optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul6_node.InputDefs()[i]), 0.5f, true)){
input_index = i;
break;
}
}

if (input_index == -1 || mul6_node.InputDefs()[(input_index + 1) % 2]->Name() != matchRet.gelu_without_bias_input_arg->Name())
continue;

std::vector<NodeArg*> gelu_input_defs{matchRet.gelu_without_bias_input_arg};
nodes_to_fuse.insert(nodes_to_fuse.end(), {tanh_node, add2_node, mul6_node, mul5_node});

auto type_info = *node.MutableOutputDefs()[0]->TypeAsProto();
auto& shape_output = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("fast_gelu_output"), &type_info);
Node& fast_gelu_node = graph.AddNode(graph.GenerateNodeName("GPT2Gelu"),
"FastGelu",
"fused GPT2Gelu subgraphs ",
gelu_input_defs,
{&shape_output}, {}, kMSDomain);

// assign provider to this new node, provider should be same as the provider for old node.
fast_gelu_node.SetExecutionProviderType(node.GetExecutionProviderType());

// move input edges to node (first in list) across to the fast_gelu_node.
// move output definitions and output edges from mul5_node (last in list) to fast_gelu_node.
// remove all nodes.
graph_utils::FinalizeNodeFusion(graph, nodes_to_fuse, fast_gelu_node);

modified = true;
}

return Status::OK();
}
} // namespace onnxruntime
39 changes: 39 additions & 0 deletions onnxruntime/core/optimizer/fast_gelu_fusion.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/optimizer/graph_transformer.h"

namespace onnxruntime {

struct MatchResult {
public:
bool matched;
NodeArg* gelu_without_bias_input_arg; // The Gelu input arg if not considering bias node.
Node* tanh_input_node;
};

/**
@Class FastGeluFusion
Rewrite graph fusing Gelu activation subgraph to a single Gelu node.
The formula corresponding to Gelu activation subgraph:
x * 0.5 * (1.0 + tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) or
x * 0.5 * (1.0 + tanh((sqrt(2 / pi) * (x + 0.044715 * pow(x, 3))))), where x is the input.
*/
class FastGeluFusion : public GraphTransformer {
public:
FastGeluFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
: GraphTransformer("FastGeluFusion", compatible_execution_providers) {}

Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;

MatchResult CheckFirstFormula(Graph& graph, Node& node, std::vector<std::reference_wrapper<Node>>& nodes_to_fuse) const;

MatchResult CheckSecondFormula(Graph& graph, Node& nodes, std::vector<std::reference_wrapper<Node>>& nodes_to_fuse) const;
};

} // namespace onnxruntime
2 changes: 2 additions & 0 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "core/optimizer/bias_gelu_fusion.h"
#include "core/optimizer/gelu_fusion.h"
#include "core/optimizer/gelu_approximation.h"
#include "core/optimizer/fast_gelu_fusion.h"
#include "core/optimizer/layer_norm_fusion.h"
#include "core/optimizer/skip_layer_norm_fusion.h"
#include "core/optimizer/embed_layer_norm_fusion.h"
Expand Down Expand Up @@ -135,6 +136,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL

std::unordered_set<std::string> cuda_execution_providers = {onnxruntime::kCudaExecutionProvider};
transformers.emplace_back(onnxruntime::make_unique<GeluApproximation>(cuda_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<FastGeluFusion>(cuda_execution_providers));
#endif
} break;

Expand Down
Loading

0 comments on commit 92b8a7a

Please sign in to comment.