Skip to content

Commit e0d286a

Browse files
icemelonzhiics
authored and committed
[Relay][Pass] Count MAC for BatchMatMul (#4157)
* count MAC for BatchMatMul
* update doc
1 parent d660e51 commit e0d286a

File tree

1 file changed

+29
-10
lines changed

1 file changed

+29
-10
lines changed

src/relay/pass/mac_count.cc

Lines changed: 29 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -66,21 +66,21 @@ int64_t ConvMacCount(const Call& call_node) {
6666
return 0;
6767
}
6868
Array<Expr> args = call_node->args;
69-
CHECK(args.size() == 2)
69+
CHECK_EQ(args.size(), 2)
7070
<< "The number of input arguments of a CONV 2D node should be 2.";
7171
const auto* conv_2d_attr = call_node->attrs.as<Conv2DAttrs>();
7272
const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
7373
Array<IndexExpr> data_shape = data_type->shape;
7474
std::string data_layout = conv_2d_attr->data_layout;
7575
int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
7676
int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
77-
CHECK(C_ind != -1)
77+
CHECK_NE(C_ind, -1)
7878
<< "There is no input channel dimension.";
7979
int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
8080
if (c_ind != -1)
8181
input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImm>()->value);
8282
Array<IndexExpr> kernel_size = conv_2d_attr->kernel_size;
83-
CHECK(kernel_size.size() == 2)
83+
CHECK_EQ(kernel_size.size(), 2)
8484
<< "The dimension of the kernel in Conv 2D should be 2.";
8585
const auto* expr = call_node->checked_type().as<TensorTypeNode>();
8686
Array<IndexExpr> output_tensor = expr->shape;
@@ -99,21 +99,21 @@ int64_t Conv2dTransposeMacCount(const Call& call_node) {
9999
return 0;
100100
}
101101
Array<Expr> args = call_node->args;
102-
CHECK(args.size() == 2)
102+
CHECK_EQ(args.size(), 2)
103103
<< "The number of input arguments of a CONV 2D Transpose node should be 2.";
104104
const auto* conv_2d_transpose_attr = call_node->attrs.as<Conv2DTransposeAttrs>();
105105
const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
106106
Array<IndexExpr> data_shape = data_type->shape;
107107
std::string data_layout = conv_2d_transpose_attr->data_layout;
108108
int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
109109
int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
110-
CHECK(C_ind != -1)
110+
CHECK_NE(C_ind, -1)
111111
<< "There is no input channel dimension.";
112112
int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
113113
if (c_ind != -1)
114114
input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImm>()->value);
115115
Array<IndexExpr> kernel_size = conv_2d_transpose_attr->kernel_size;
116-
CHECK(kernel_size.size() == 2)
116+
CHECK_EQ(kernel_size.size(), 2)
117117
<< "The dimension of the kernel in Conv 2D Transpose should be 2.";
118118
const auto* expr = call_node->checked_type().as<TensorTypeNode>();
119119
Array<IndexExpr> output_tensor = expr->shape;
@@ -132,7 +132,7 @@ int64_t DenseMacCount(const Call& call_node) {
132132
return 0;
133133
}
134134
Array<Expr> args = call_node->args;
135-
CHECK(args.size() == 2)
135+
CHECK_EQ(args.size(), 2)
136136
<< "The number of input arguments of a Dense node should be 2.";
137137
const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
138138
const auto* weight_type = args[1]->checked_type().as<TensorTypeNode>();
@@ -144,12 +144,28 @@ int64_t DenseMacCount(const Call& call_node) {
144144
int64_t d2 = static_cast<int64_t>(data_shape[1].as<IntImm>()->value);
145145
int64_t d3 = static_cast<int64_t>(weight_shape[0].as<IntImm>()->value);
146146
int64_t d4 = static_cast<int64_t>(weight_shape[1].as<IntImm>()->value);
147-
CHECK(d2 == d4)
147+
CHECK_EQ(d2, d4)
148148
<< "The dimensions of input arguments do not match.";
149149
int64_t count = d1 * d2 * d3;
150150
return count;
151151
}
152152

153+
int64_t BatchMatmulMacCount(const Call& call_node) {
154+
if (!call_node->checked_type_.defined()) {
155+
LOG(WARNING) << "The infer type pass should be called before the mac count pass";
156+
return 0;
157+
}
158+
Array<Expr> args = call_node->args;
159+
CHECK_EQ(args.size(), 2);
160+
Array<IndexExpr> x_shape = args[0]->checked_type().as<TensorTypeNode>()->shape;
161+
Array<IndexExpr> y_shape = args[1]->checked_type().as<TensorTypeNode>()->shape;
162+
int64_t batch = x_shape[0].as<IntImm>()->value;
163+
int64_t m = x_shape[1].as<IntImm>()->value;
164+
int64_t k = x_shape[2].as<IntImm>()->value;
165+
int64_t n = y_shape[1].as<IntImm>()->value;
166+
return batch * m * k * n;
167+
}
168+
153169
RELAY_REGISTER_OP("nn.conv2d")
154170
.set_attr<FMacCount>("FMacCount", ConvMacCount);
155171

@@ -159,14 +175,17 @@ RELAY_REGISTER_OP("nn.conv2d_transpose")
159175
RELAY_REGISTER_OP("nn.dense")
160176
.set_attr<FMacCount>("FMacCount", DenseMacCount);
161177

178+
RELAY_REGISTER_OP("nn.batch_matmul")
179+
.set_attr<FMacCount>("FMacCount", BatchMatmulMacCount);
180+
162181
class MacCounter : private ExprVisitor {
163182
public:
164183
MacCounter() {
165184
count_ = 0;
166185
}
167186
static int64_t GetTotalMacNumber(const Expr& expr) {
168-
LOG(INFO) << "This pass only counts MACs in direct CONV 2D, "
169-
<< "CONV 2D Transpose and Dense ops";
187+
LOG(INFO) << "This pass only counts MACs in direct conv2d, "
188+
<< "conv2d_transpose, dense, and batch_matmul ops";
170189
MacCounter counter;
171190
counter(expr);
172191
return counter.count_;

0 commit comments

Comments
 (0)