Skip to content

Commit 78b571c

Browse files
authored
Merge pull request #1806 from mfeliz-cruise/michael.feliz/baddbmm
feat: add support for aten::baddbmm
2 parents 745af55 + a4e55da commit 78b571c

File tree

2 files changed

+148
-0
lines changed

2 files changed

+148
-0
lines changed

core/conversion/converters/impl/matrix_multiply.cpp

+79
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,85 @@ auto mm_registrations TORCHTRT_UNUSED =
7272
mm_layer->setName(util::node_info(n).c_str());
7373
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
7474

75+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
76+
return true;
77+
}})
78+
.pattern(
    {"aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor",
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       // Lowers aten::baddbmm(self, batch1, batch2, beta, alpha) as
       //   beta * self + alpha * (batch1 @ batch2)
       // eliding the scalar multiplies (and, for beta == 0, the add) when
       // they would be identities.
       auto self = args[0].ITensorOrFreeze(ctx);
       auto batch1 = args[1].ITensorOrFreeze(ctx);
       auto batch2 = args[2].ITensorOrFreeze(ctx);
       const nvinfer1::Dims dims1 = batch1->getDimensions();
       const nvinfer1::Dims dims2 = batch2->getDimensions();

       // Mirror aten::baddbmm's argument checks: both batched operands are
       // rank-3, share the batch count, and have compatible inner dims.
       TORCHTRT_CHECK(
           dims1.nbDims == 3,
           "Expected 3-dimensional tensor, but got "
               << dims1.nbDims
               << "-dimensional tensor for argument 'batch1' (while checking arguments for baddbmm)");
       TORCHTRT_CHECK(
           dims2.nbDims == 3,
           "Expected 3-dimensional tensor, but got "
               << dims2.nbDims
               << "-dimensional tensor for argument 'batch2' (while checking arguments for baddbmm)");
       TORCHTRT_CHECK(
           dims1.d[0] == dims2.d[0],
           "Expected tensor to have size " << dims1.d[0] << " at dimension 0, but got size "
                                           << dims2.d[0]
                                           << " for argument 'batch2' (while checking arguments for baddbmm)");
       TORCHTRT_CHECK(
           dims1.d[2] == dims2.d[1],
           "Expected tensor to have size " << dims1.d[2] << " at dimension 1, but got size "
                                           << dims2.d[1]
                                           << " for argument 'batch2' (while checking arguments for baddbmm)");

       // batch1 @ batch2 (the batched matrix product).
       auto mm_layer = ctx->net->addMatrixMultiply(
           *batch1, nvinfer1::MatrixOperation::kNONE, *batch2, nvinfer1::MatrixOperation::kNONE);
       TORCHTRT_CHECK(mm_layer, "Unable to create matrix multiplication for node: " << *n);
       mm_layer->setName((util::node_info(n) + "_matmul").c_str());
       auto result = mm_layer->getOutput(0);

       // Scale the product by alpha unless alpha == 1 (identity).
       const auto alpha = args[4].unwrapToScalar();
       if (alpha.to<float>() != 1.) {
         auto alpha_layer = add_elementwise(
             ctx,
             nvinfer1::ElementWiseOperation::kPROD,
             result,
             scalar_to_tensor(ctx, alpha),
             util::node_info(n) + std::string("_alpha_mul"));
         TORCHTRT_CHECK(alpha_layer, "Unable to create alpha_mul layer from node: " << *n);
         result = alpha_layer->getOutput(0);
       }

       // If beta is 0, then input will be ignored, and nan and inf in it will not be propagated.
       const auto beta = args[3].unwrapToScalar();
       if (beta.to<float>() != 0.) {
         // Scale self by beta first, unless beta == 1 (identity).
         if (beta.to<float>() != 1.) {
           auto beta_layer = add_elementwise(
               ctx,
               nvinfer1::ElementWiseOperation::kPROD,
               self,
               scalar_to_tensor(ctx, beta),
               util::node_info(n) + std::string("_beta_mul"));
           TORCHTRT_CHECK(beta_layer, "Unable to create beta_mul layer from node: " << *n);
           self = beta_layer->getOutput(0);
         }
         // Accumulate (possibly scaled) self onto the (possibly scaled) product.
         auto self_add_layer = add_elementwise(
             ctx,
             nvinfer1::ElementWiseOperation::kSUM,
             self,
             result,
             util::node_info(n) + std::string("_self_add"));
         TORCHTRT_CHECK(self_add_layer, "Unable to create self_add layer from node: " << *n);
         result = self_add_layer->getOutput(0);
       }

       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], result);
       LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
       return true;
     }});

tests/core/conversion/converters/test_matrix_multiply.cpp

+69
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,72 @@ TEST(Converters, ATenBMMConvertsCorrectly) {
6767

6868
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
6969
}
70+
71+
TEST(Converters, ATenBADDBMMConvertsCorrectly) {
  // baddbmm(self, batch1, batch2, beta, alpha) = beta * self + alpha * (batch1 @ batch2).
  // Exercises the general path where both alpha (1.5) and beta (0.2) are non-trivial.
  // The op's result is named %out rather than shadowing graph input %2.
  const auto graph = R"IR(
    graph(%0 : Tensor, %1 : Tensor, %2 : Tensor):
      %a : float = prim::Constant[value=1.5]()
      %b : float = prim::Constant[value=.2]()
      %out : Tensor = aten::baddbmm(%0, %1, %2, %b, %a)
      return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  // self is (10, 3, 5); batch1 @ batch2 is (10, 3, 4) x (10, 4, 5) -> (10, 3, 5).
  auto self = at::randn({10, 3, 5}, {at::kCUDA});
  auto bat1 = at::randn({10, 3, 4}, {at::kCUDA});
  auto bat2 = at::randn({10, 4, 5}, {at::kCUDA});

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {self, bat1, bat2});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {self, bat1, bat2});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
93+
94+
TEST(Converters, ATenBADDBMMAlphaBetaDisabledConvertsCorrectly) {
  // With beta == 0 the converter must ignore self entirely (no add layer), and
  // with alpha == 1 it must skip the alpha multiply — the result is a plain bmm.
  // The op's result is named %out rather than shadowing graph input %2.
  const auto graph = R"IR(
    graph(%0 : Tensor, %1 : Tensor, %2 : Tensor):
      %a : float = prim::Constant[value=1]()
      %b : float = prim::Constant[value=0]()
      %out : Tensor = aten::baddbmm(%0, %1, %2, %b, %a)
      return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  // self is (10, 3, 5); batch1 @ batch2 is (10, 3, 4) x (10, 4, 5) -> (10, 3, 5).
  auto self = at::randn({10, 3, 5}, {at::kCUDA});
  auto bat1 = at::randn({10, 3, 4}, {at::kCUDA});
  auto bat2 = at::randn({10, 4, 5}, {at::kCUDA});

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {self, bat1, bat2});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {self, bat1, bat2});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
116+
117+
TEST(Converters, ATenBADDBMMScalarDefaultsConvertsCorrectly) {
  // With beta == 1 and alpha == 1 (the schema defaults) both scalar multiplies
  // are identities, so the converter should emit only matmul + add.
  // The op's result is named %out rather than shadowing graph input %2.
  const auto graph = R"IR(
    graph(%0 : Tensor, %1 : Tensor, %2 : Tensor):
      %a : float = prim::Constant[value=1]()
      %b : float = prim::Constant[value=1]()
      %out : Tensor = aten::baddbmm(%0, %1, %2, %b, %a)
      return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  // self is (10, 3, 5); batch1 @ batch2 is (10, 3, 4) x (10, 4, 5) -> (10, 3, 5).
  auto self = at::randn({10, 3, 5}, {at::kCUDA});
  auto bat1 = at::randn({10, 3, 4}, {at::kCUDA});
  auto bat2 = at::randn({10, 4, 5}, {at::kCUDA});

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {self, bat1, bat2});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {self, bat1, bat2});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

0 commit comments

Comments
 (0)