@@ -66,21 +66,21 @@ int64_t ConvMacCount(const Call& call_node) {
6666 return 0 ;
6767 }
6868 Array<Expr> args = call_node->args ;
69- CHECK (args.size () == 2 )
69+ CHECK_EQ (args.size (), 2 )
7070 << " The number of input arguments of a CONV 2D node should be 2." ;
7171 const auto * conv_2d_attr = call_node->attrs .as <Conv2DAttrs>();
7272 const auto * data_type = args[0 ]->checked_type ().as <TensorTypeNode>();
7373 Array<IndexExpr> data_shape = data_type->shape ;
7474 std::string data_layout = conv_2d_attr->data_layout ;
7575 int32_t C_ind = Layout (data_layout).IndexOf (LayoutAxis::Get (' C' ));
7676 int32_t c_ind = Layout (data_layout).IndexOf (LayoutAxis::Get (' c' ));
77- CHECK (C_ind != -1 )
77+ CHECK_NE (C_ind, -1 )
7878 << " There is no input channel dimension." ;
7979 int64_t input_channel = static_cast <int64_t >(data_shape[C_ind].as <IntImm>()->value );
8080 if (c_ind != -1 )
8181 input_channel *= static_cast <int64_t >(data_shape[c_ind].as <IntImm>()->value );
8282 Array<IndexExpr> kernel_size = conv_2d_attr->kernel_size ;
83- CHECK (kernel_size.size () == 2 )
83+ CHECK_EQ (kernel_size.size (), 2 )
8484 << " The dimension of the kernel in Conv 2D should be 2." ;
8585 const auto * expr = call_node->checked_type ().as <TensorTypeNode>();
8686 Array<IndexExpr> output_tensor = expr->shape ;
@@ -99,21 +99,21 @@ int64_t Conv2dTransposeMacCount(const Call& call_node) {
9999 return 0 ;
100100 }
101101 Array<Expr> args = call_node->args ;
102- CHECK (args.size () == 2 )
102+ CHECK_EQ (args.size (), 2 )
103103 << " The number of input arguments of a CONV 2D Transpose node should be 2." ;
104104 const auto * conv_2d_transpose_attr = call_node->attrs .as <Conv2DTransposeAttrs>();
105105 const auto * data_type = args[0 ]->checked_type ().as <TensorTypeNode>();
106106 Array<IndexExpr> data_shape = data_type->shape ;
107107 std::string data_layout = conv_2d_transpose_attr->data_layout ;
108108 int32_t C_ind = Layout (data_layout).IndexOf (LayoutAxis::Get (' C' ));
109109 int32_t c_ind = Layout (data_layout).IndexOf (LayoutAxis::Get (' c' ));
110- CHECK (C_ind != -1 )
110+ CHECK_NE (C_ind, -1 )
111111 << " There is no input channel dimension." ;
112112 int64_t input_channel = static_cast <int64_t >(data_shape[C_ind].as <IntImm>()->value );
113113 if (c_ind != -1 )
114114 input_channel *= static_cast <int64_t >(data_shape[c_ind].as <IntImm>()->value );
115115 Array<IndexExpr> kernel_size = conv_2d_transpose_attr->kernel_size ;
116- CHECK (kernel_size.size () == 2 )
116+ CHECK_EQ (kernel_size.size (), 2 )
117117 << " The dimension of the kernel in Conv 2D Transpose should be 2." ;
118118 const auto * expr = call_node->checked_type ().as <TensorTypeNode>();
119119 Array<IndexExpr> output_tensor = expr->shape ;
@@ -132,7 +132,7 @@ int64_t DenseMacCount(const Call& call_node) {
132132 return 0 ;
133133 }
134134 Array<Expr> args = call_node->args ;
135- CHECK (args.size () == 2 )
135+ CHECK_EQ (args.size (), 2 )
136136 << " The number of input arguments of a Dense node should be 2." ;
137137 const auto * data_type = args[0 ]->checked_type ().as <TensorTypeNode>();
138138 const auto * weight_type = args[1 ]->checked_type ().as <TensorTypeNode>();
@@ -144,12 +144,28 @@ int64_t DenseMacCount(const Call& call_node) {
144144 int64_t d2 = static_cast <int64_t >(data_shape[1 ].as <IntImm>()->value );
145145 int64_t d3 = static_cast <int64_t >(weight_shape[0 ].as <IntImm>()->value );
146146 int64_t d4 = static_cast <int64_t >(weight_shape[1 ].as <IntImm>()->value );
147- CHECK (d2 == d4)
147+ CHECK_EQ (d2, d4)
148148 << " The dimensions of input arguments do not match." ;
149149 int64_t count = d1 * d2 * d3;
150150 return count;
151151}
152152
153+ int64_t BatchMatmulMacCount (const Call& call_node) {
154+ if (!call_node->checked_type_ .defined ()) {
155+ LOG (WARNING) << " The infer type pass should be called before the mac count pass" ;
156+ return 0 ;
157+ }
158+ Array<Expr> args = call_node->args ;
159+ CHECK_EQ (args.size (), 2 );
160+ Array<IndexExpr> x_shape = args[0 ]->checked_type ().as <TensorTypeNode>()->shape ;
161+ Array<IndexExpr> y_shape = args[1 ]->checked_type ().as <TensorTypeNode>()->shape ;
162+ int64_t batch = x_shape[0 ].as <IntImm>()->value ;
163+ int64_t m = x_shape[1 ].as <IntImm>()->value ;
164+ int64_t k = x_shape[2 ].as <IntImm>()->value ;
165+ int64_t n = y_shape[1 ].as <IntImm>()->value ;
166+ return batch * m * k * n;
167+ }
168+
153169RELAY_REGISTER_OP (" nn.conv2d" )
154170.set_attr<FMacCount>(" FMacCount" , ConvMacCount);
155171
@@ -159,14 +175,17 @@ RELAY_REGISTER_OP("nn.conv2d_transpose")
159175RELAY_REGISTER_OP (" nn.dense" )
160176.set_attr<FMacCount>(" FMacCount" , DenseMacCount);
161177
178+ RELAY_REGISTER_OP (" nn.batch_matmul" )
179+ .set_attr<FMacCount>(" FMacCount" , BatchMatmulMacCount);
180+
162181class MacCounter : private ExprVisitor {
163182 public:
164183 MacCounter () {
165184 count_ = 0 ;
166185 }
167186 static int64_t GetTotalMacNumber (const Expr& expr) {
168- LOG (INFO) << " This pass only counts MACs in direct CONV 2D , "
169- << " CONV 2D Transpose and Dense ops" ;
187+ LOG (INFO) << " This pass only counts MACs in direct conv2d , "
188+ << " conv2d_transpose, dense, and batch_matmul ops" ;
170189 MacCounter counter;
171190 counter (expr);
172191 return counter.count_ ;
0 commit comments