[feat] Support conversion of scaled_dot_product_attention #2549

Merged · 4 commits · Jan 10, 2024
1 change: 1 addition & 0 deletions core/conversion/converters/BUILD
100755 → 100644
@@ -66,6 +66,7 @@ cc_library(
"impl/einsum.cpp",
"impl/element_wise.cpp",
"impl/expand.cpp",
"impl/internal_ops.cpp",
"impl/interpolate.cpp",
"impl/layer_norm.cpp",
"impl/linear.cpp",
46 changes: 46 additions & 0 deletions core/conversion/converters/impl/internal_ops.cpp
@@ -0,0 +1,46 @@
#include "core/conversion/converters/converters.h"
#include "core/util/prelude.h"
#include "torch/torch.h"

namespace torch_tensorrt {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

auto internal_op_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
{"trt::attn_bias_from_attn_mask(Tensor attn_mask) -> Tensor",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// Converter for the internal op used in unpack_scaled_dot_product_attention.
// Tensor types are not visible during lowering and the rewritten graph cannot
// contain conditionals, so the bool/float-specific specialization happens here.
auto in = args[0].ITensorOrFreeze(ctx);
auto out = in;
if (in->getType() == nvinfer1::DataType::kBOOL) {
auto not_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kNOT);
TORCHTRT_CHECK(not_layer, "Unable to create not layer for attn_bias_from_attn_mask");
not_layer->setName((util::node_info(n) + "_not").c_str());
auto neg_inf = torch::tensor(-std::numeric_limits<float>::infinity());
auto neg_inf_itensor = tensor_to_const(ctx, neg_inf);
auto prod_layer = add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kPROD,
not_layer->getOutput(0),
neg_inf_itensor,
util::node_info(n) + "_mul");
auto add_layer = add_elementwise(
ctx, nvinfer1::ElementWiseOperation::kSUM, prod_layer->getOutput(0), in, util::node_info(n) + "_add");
out = add_layer->getOutput(0);
}
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out);
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
LOG_DEBUG("Output tensor type: " << out_tensor->getType());
return true;
}});
} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace torch_tensorrt
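
For reference, the bias this converter produces for a bool mask can be written in a few lines of plain libtorch. A minimal sketch, not part of the PR (attn_bias_sketch is a hypothetical name):

#include <limits>
#include "torch/torch.h"

// Sketch of the trt::attn_bias_from_attn_mask transform for bool masks:
// unmasked (true) positions become the finite constant 1.0, masked (false)
// positions become -inf so the subsequent softmax zeroes them out.
torch::Tensor attn_bias_sketch(const torch::Tensor& mask) {
  auto bias = mask.to(torch::kFloat);
  bias.masked_fill_(mask.logical_not(), -std::numeric_limits<float>::infinity());
  return bias;  // e.g. mask [true, false] -> bias [1.0, -inf]
}

The converter above expresses the same transform with NOT, PROD, and SUM layers so the whole computation stays inside the TensorRT network; float masks pass through unchanged.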
17 changes: 16 additions & 1 deletion core/conversion/converters/impl/unary.cpp
@@ -79,6 +79,22 @@ auto logical_not_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
return true;
}});

auto sqrt_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
{"aten::sqrt(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensorOrFreeze(ctx);
if (in->getType() == nvinfer1::DataType::kINT32) {
// unary sqrt layer only supports float inputs
in = castITensor(ctx, in, nvinfer1::DataType::kFLOAT, util::node_info(n).c_str());
}
auto unary_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kSQRT);
TORCHTRT_CHECK(unary_layer, "Unable to create sqrt layer from node: " << *n);
unary_layer->setName(util::node_info(n).c_str());
unary_layer->setOutputType(0, in->getType());
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], unary_layer->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
return true;
}});

auto isfinite_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
{"aten::isfinite(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensorOrFreeze(ctx);
@@ -126,7 +142,6 @@ convert(atan, kATAN);
convert(floor, kFLOOR);
convert(log, kLOG);
convert(ceil, kCEIL);
convert(sqrt, kSQRT);
convert(exp, kEXP);
convert(neg, kNEG);
convert(erf, kERF);
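
The dedicated aten::sqrt converter added above replaces the generic convert(sqrt, kSQRT) macro registration (removed in the second hunk) because TRT's unary sqrt layer only supports float inputs. A minimal libtorch sketch, not part of the PR (sqrt_promotion_demo is a hypothetical name), of the PyTorch behavior the converter has to match:

#include "torch/torch.h"

// aten::sqrt promotes integer inputs to floating point, which is why the
// converter casts kINT32 ITensors to kFLOAT before applying kSQRT.
void sqrt_promotion_demo() {
  auto i = torch::randint(0, 100, {4}, torch::kInt32);
  auto s = torch::sqrt(i);
  // s.scalar_type() == torch::kFloat
}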
1 change: 1 addition & 0 deletions core/lowering/lowering.cpp
@@ -146,6 +146,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
if (lower_info.converting_to_trt_engine) {
passes::RemoveCollectionCast(g);
}
passes::UnpackScaledDotProductAttention(g);
passes::UnpackAndCastMaskedFill(g, lower_info.getGPUDeviceString());
passes::UnpackAndCastNumToTensor(g, lower_info.getGPUDeviceString());
passes::UnpackAndCastFull(g, lower_info.getGPUDeviceString());
1 change: 1 addition & 0 deletions core/lowering/passes/BUILD
@@ -38,6 +38,7 @@ cc_library(
"unpack_hardswish.cpp",
"unpack_log_softmax.cpp",
"unpack_rsqrt.cpp",
"unpack_scaled_dot_product_attention.cpp",
"unpack_std.cpp",
"unpack_var.cpp",
"view_to_reshape.cpp",
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
@@ -49,6 +49,7 @@ void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph);
void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph);
void ReplaceAtenPad(std::shared_ptr<torch::jit::Graph>& graph);
void ReplaceTileWithRepeat(std::shared_ptr<torch::jit::Graph>& graph);
94 changes: 94 additions & 0 deletions core/lowering/passes/unpack_scaled_dot_product_attention.cpp
@@ -0,0 +1,94 @@
#include "torch/csrc/jit/ir/subgraph_matcher.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"
#include "torch/csrc/jit/ir/irparser.h"

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {

// https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph) {
std::string sdpa_pattern = R"IR(
graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal):
%out: Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %dropout_p, %is_causal)
return (%out))IR";

std::string unpacked_sdpa_pattern = R"IR(
graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal):
%none : NoneType = prim::Constant()
%1 : int = prim::Constant[value=-1]()
%2 : int = prim::Constant[value=-2]()
%3 : int = aten::size(%query, %1)
%q_size : Long() = prim::NumToTensor(%3)
%sqrt : Tensor = aten::sqrt(%q_size)
%scale_factor : Tensor = aten::reciprocal(%sqrt)
%key_transpose : Tensor = aten::transpose(%key, %2, %1)
%matmul : Tensor = aten::matmul(%query, %key_transpose)
%attn_weight : Tensor = aten::mul(%matmul, %scale_factor)
%softmax : Tensor = aten::softmax(%attn_weight, %1, %none)
%out : Tensor = aten::matmul(%softmax, %value)
return(%out))IR";

std::string unpacked_sdpa_attn_biased_pattern = R"IR(
graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal):
%none : NoneType = prim::Constant()
%0 : int = prim::Constant[value=1]()
%1 : int = prim::Constant[value=-1]()
%2 : int = prim::Constant[value=-2]()
%3 : int = aten::size(%query, %1)
%q_size : Long() = prim::NumToTensor(%3)
%sqrt : Tensor = aten::sqrt(%q_size)
%scale_factor : Tensor = aten::reciprocal(%sqrt)
%key_transpose : Tensor = aten::transpose(%key, %2, %1)
%matmul : Tensor = aten::matmul(%query, %key_transpose)
%attn_weight : Tensor = aten::mul(%matmul, %scale_factor)
%attn_bias : Tensor = trt::attn_bias_from_attn_mask(%attn_mask)
%attn_weight_with_bias : Tensor = aten::add(%attn_weight, %attn_bias, %0)
%softmax : Tensor = aten::softmax(%attn_weight_with_bias, %1, %none)
%out : Tensor = aten::matmul(%softmax, %value)
return(%out))IR";

// Rewrite matches where attn_mask is None
torch::jit::SubgraphRewriter sdpa_rewriter;
sdpa_rewriter.RegisterRewritePattern(sdpa_pattern, unpacked_sdpa_pattern);
sdpa_rewriter.runOnGraph(
graph, [](const torch::jit::Match& match, const std::unordered_map<std::string, torch::jit::Value*>&) {
auto is_causal_node = match.anchor->inputs().at(5)->node();
if (is_causal_node->kind() != at::prim::Constant) {
LOG_WARNING("Could not unpack scaled_dot_product_attention with non constant is_causal: " << *is_causal_node);
return false;
}
if (is_causal_node->i(at::attr::value) == 1) {
LOG_WARNING("Could not unpack scaled_dot_product_attention with is_causal = True: " << *is_causal_node);
return false;
}
auto attn_mask_node = match.anchor->inputs().at(3)->node();
if (attn_mask_node->kind() != at::prim::Constant || !attn_mask_node->mustBeNone()) {
return false;
}
return true;
});

// Rewrite matches with a float/bool attn_mask. This uses a custom op to implement the divergent
// behavior between bool and float masks without introducing a conditional into the graph
torch::jit::SubgraphRewriter sdpa_attn_mask_rewriter;
sdpa_attn_mask_rewriter.RegisterRewritePattern(sdpa_pattern, unpacked_sdpa_attn_biased_pattern);
sdpa_attn_mask_rewriter.runOnGraph(
graph, [](const torch::jit::Match& match, const std::unordered_map<std::string, torch::jit::Value*>&) {
auto is_causal_node = match.anchor->inputs().at(5)->node();
if (is_causal_node->kind() != at::prim::Constant || is_causal_node->i(at::attr::value) == 1) {
// warnings were already emitted in the first pass; do not emit them again
return false;
}
return true;
});
LOG_GRAPH("Post unpack scaled_dot_product_attention: " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt
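
Numerically, the mask-free rewrite above computes standard scaled dot-product attention, softmax(QK^T / sqrt(d_k)) V. A minimal libtorch sketch of the same computation, illustrative only (sdpa_unpacked is a hypothetical name):

#include <cmath>
#include "torch/torch.h"

// The computation the unpacked graph performs when attn_mask is None,
// dropout_p is 0, and is_causal is false.
torch::Tensor sdpa_unpacked(
    const torch::Tensor& query,
    const torch::Tensor& key,
    const torch::Tensor& value) {
  auto scale = 1.0 / std::sqrt(static_cast<double>(query.size(-1)));
  auto attn_weight = torch::matmul(query, key.transpose(-2, -1)) * scale;
  return torch::matmul(torch::softmax(attn_weight, -1), value);
}

The attn_mask variant is identical except that trt::attn_bias_from_attn_mask(%attn_mask) is added to the attention weights before the softmax.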
12 changes: 12 additions & 0 deletions core/lowering/register_trt_placeholder_ops.cpp
@@ -1,3 +1,4 @@
#include <limits>
#include "torch/csrc/jit/runtime/custom_operator.h"

namespace torch {
@@ -14,6 +15,17 @@ RegisterOperators trt_placeholder_ops_reg({
"trt::const(Tensor val) -> Tensor",
[](Stack& stack) { /*noop*/ },
aliasAnalysisFromSchema()),
Operator(
"trt::attn_bias_from_attn_mask(Tensor attn_mask) -> Tensor",
[](Stack& stack) {
auto attn_mask = pop(stack).to<at::Tensor>();
if (attn_mask.scalar_type() == at::kBool) {
// Cast the bool mask to float so masked (false) positions can hold -inf
attn_mask = attn_mask.to(at::kFloat);
attn_mask.masked_fill_(attn_mask.logical_not(), -std::numeric_limits<float>::infinity());
}
push(stack, attn_mask);
},
c10::AliasAnalysisKind::CONSERVATIVE),
});

} // namespace jit
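
One subtlety: for a bool mask this placeholder yields 1.0 (not 0.0) at unmasked positions. That is harmless because softmax is invariant to a uniform shift within a row. A tiny libtorch check, illustrative only (softmax_shift_invariance_demo is a hypothetical name):

#include "torch/torch.h"

// softmax(x) == softmax(x + c) for a constant c applied to a whole row, so
// adding 1.0 at every unmasked position leaves the attention weights unchanged.
void softmax_shift_invariance_demo() {
  auto x = torch::rand({8});
  auto a = torch::softmax(x, 0);
  auto b = torch::softmax(x + 1.0, 0);
  // torch::allclose(a, b) holds
}

CONSERVATIVE alias analysis is a safe choice here since masked_fill_ mutates the tensor in place.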
5 changes: 5 additions & 0 deletions tests/core/conversion/converters/BUILD
@@ -203,6 +203,10 @@ converter_test(
name = "test_where",
)

converter_test(
name = "test_scaled_dot_product_attention",
)

test_suite(
name = "converter_tests",
tests = [
@@ -238,6 +242,7 @@ test_suite(
":test_reduce",
":test_replication_pad",
":test_roll",
":test_scaled_dot_product_attention",
":test_scatter",
":test_select",
":test_shuffle",
84 changes: 84 additions & 0 deletions tests/core/conversion/converters/test_scaled_dot_product_attention.cpp
@@ -0,0 +1,84 @@
#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

TEST(Converters, ATenScaledDotProductAttentionConvertsCorrectly) {
const auto graph = R"IR(
graph(%query : Tensor, %key : Tensor, %value : Tensor):
%none : NoneType = prim::Constant()
%0 : float = prim::Constant[value=0.]()
%false : bool = prim::Constant[value=0]()
%3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %none, %0, %false)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto query = at::rand({32, 8, 128, 64}, {at::kCUDA});
auto key = at::rand({32, 8, 128, 64}, {at::kCUDA});
auto value = at::rand({32, 8, 128, 64}, {at::kCUDA});
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {query, key, value});

torch_tensorrt::core::lowering::passes::UnpackScaledDotProductAttention(g);

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {query, key, value});

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 1e-5));
}

TEST(Converters, ATenScaledDotProductAttnMaskFloatConvertsCorrectly) {
const auto graph = R"IR(
graph(%query : Tensor, %key : Tensor, %value : Tensor, %attn_mask : Tensor):
%0 : float = prim::Constant[value=0.]()
%false : bool = prim::Constant[value=0]()
%3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto query = at::rand({32, 8, 128, 64}, {at::kCUDA});
auto key = at::rand({32, 8, 128, 64}, {at::kCUDA});
auto value = at::rand({32, 8, 128, 64}, {at::kCUDA});
auto attn_mask = at::rand({32, 8, 128, 128}, {at::kCUDA});
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {query, key, value, attn_mask});

torch_tensorrt::core::lowering::passes::UnpackScaledDotProductAttention(g);

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {query, key, value, attn_mask});

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 1e-5));
}

TEST(Converters, ATenScaledDotProductAttnMaskBoolConvertsCorrectly) {
const auto graph = R"IR(
graph(%query : Tensor, %key : Tensor, %value : Tensor, %attn_mask : Tensor):
%0 : float = prim::Constant[value=0.]()
%false : bool = prim::Constant[value=0]()
%3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto query = at::rand({32, 8, 128, 64}, {at::kCUDA});
auto key = at::rand({32, 8, 128, 64}, {at::kCUDA});
auto value = at::rand({32, 8, 128, 64}, {at::kCUDA});
auto attn_mask = at::randint(0, 2, {32, 8, 128, 128}, at::kCUDA).to(at::kBool);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {query, key, value, attn_mask});

torch_tensorrt::core::lowering::passes::UnpackScaledDotProductAttention(g);

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {query, key, value, attn_mask});

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 1e-5));
}
15 changes: 15 additions & 0 deletions tests/core/conversion/converters/test_unary.cpp
@@ -111,6 +111,21 @@ TEST(Converters, ATenLogicalNotBoolConvertsCorrectly) {
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenSqrtIntConvertsCorrectly) {
const auto graph = gen_test_graph("sqrt");
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());
auto in = at::randint(0, 100, {7, 3, 1, 5}, {at::kCUDA});

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenFiniteConvertsCorrectly) {
const auto graph = gen_test_graph("isfinite");
auto g = std::make_shared<torch::jit::Graph>();