@@ -67,3 +67,72 @@ TEST(Converters, ATenBMMConvertsCorrectly) {
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
+
+TEST(Converters, ATenBADDBMMConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor, %1 : Tensor, %2 : Tensor):
+      %a : float = prim::Constant[value=1.5]()
+      %b : float = prim::Constant[value=.2]()
+      %3 : Tensor = aten::baddbmm(%0, %1, %2, %b, %a)
+      return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto self = at::randn({10, 3, 5}, {at::kCUDA});
+  auto bat1 = at::randn({10, 3, 4}, {at::kCUDA});
+  auto bat2 = at::randn({10, 4, 5}, {at::kCUDA});
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {self, bat1, bat2});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {self, bat1, bat2});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenBADDBMMAlphaBetaDisabledConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor, %1 : Tensor, %2 : Tensor):
+      %a : float = prim::Constant[value=1.0]()
+      %b : float = prim::Constant[value=0.0]()
+      %3 : Tensor = aten::baddbmm(%0, %1, %2, %b, %a)
+      return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto self = at::randn({10, 3, 5}, {at::kCUDA});
+  auto bat1 = at::randn({10, 3, 4}, {at::kCUDA});
+  auto bat2 = at::randn({10, 4, 5}, {at::kCUDA});
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {self, bat1, bat2});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {self, bat1, bat2});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenBADDBMMScalarDefaultsConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor, %1 : Tensor, %2 : Tensor):
+      %a : float = prim::Constant[value=1.0]()
+      %b : float = prim::Constant[value=1.0]()
+      %3 : Tensor = aten::baddbmm(%0, %1, %2, %b, %a)
+      return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto self = at::randn({10, 3, 5}, {at::kCUDA});
+  auto bat1 = at::randn({10, 3, 4}, {at::kCUDA});
+  auto bat2 = at::randn({10, 4, 5}, {at::kCUDA});
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {self, bat1, bat2});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {self, bat1, bat2});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
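For reference, aten::baddbmm(self, batch1, batch2, beta, alpha) computes beta * self + alpha * bmm(batch1, batch2), which is the behavior the tests above compare between the TorchScript interpreter and the TensorRT engine. Below is a minimal libtorch sketch of that reference semantics; it is not part of the commit, and it only assumes the standard ATen signatures for baddbmm, bmm, and allclose. The tensor shapes and scalars mirror the first test.

// Sketch only: reference semantics of aten::baddbmm, assuming standard ATen behavior.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto self = torch::randn({10, 3, 5});
  auto bat1 = torch::randn({10, 3, 4});
  auto bat2 = torch::randn({10, 4, 5});
  double beta = 0.2, alpha = 1.5;  // same scalars as ATenBADDBMMConvertsCorrectly

  auto fused = torch::baddbmm(self, bat1, bat2, beta, alpha);   // fused ATen op
  auto manual = beta * self + alpha * torch::bmm(bat1, bat2);   // manual decomposition

  std::cout << std::boolalpha << torch::allclose(fused, manual, 1e-5, 1e-6) << std::endl;
  return 0;
}

With beta = 0 and alpha = 1 (the AlphaBetaDisabled case) the op reduces to a plain batched matmul, and with beta = alpha = 1 it matches the ATen defaults, which is why those two cases are exercised as separate tests.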