Skip to content

Commit fafe9bb

Browse files
eqywweic
authored andcommitted
[RELAY][PASS] detect depthwise conv2d in mac_count pass (apache#3083)
* check in * use groups * CHECK_EQ * trigger CI * Update mac_count.cc * trigger CI * trigger CI
1 parent 8bc2df1 commit fafe9bb

File tree

2 files changed

+39
-5
lines changed

2 files changed

+39
-5
lines changed

src/relay/pass/mac_count.cc

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
#include <tvm/relay/op.h>
3131
#include <tvm/relay/attrs/nn.h>
3232
#include <tvm/relay/expr_functor.h>
33+
#include <tvm/relay/pass.h>
3334
#include <tvm/data_layout.h>
35+
#include "pattern_util.h"
3436

3537
namespace tvm {
3638
namespace relay {
@@ -65,26 +67,29 @@ int64_t ConvMacCount(const Call& call_node) {
6567
}
6668
Array<Expr> args = call_node->args;
6769
CHECK(args.size() == 2)
68-
<< "The number of input arguments of a CONV 2D node should be 2.";
70+
<< "The number of input arguments of a CONV 2D node should be 2.";
6971
const auto* conv_2d_attr = call_node->attrs.as<Conv2DAttrs>();
7072
const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
7173
Array<IndexExpr> data_shape = data_type->shape;
7274
std::string data_layout = conv_2d_attr->data_layout;
7375
int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
7476
int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
7577
CHECK(C_ind != -1)
76-
<< "There is no input channel dimension.";
78+
<< "There is no input channel dimension.";
7779
int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
7880
if (c_ind != -1)
7981
input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImm>()->value);
8082
Array<IndexExpr> kernel_size = conv_2d_attr->kernel_size;
8183
CHECK(kernel_size.size() == 2)
82-
<< "The dimension of the kernel size in Conv 2D should be 2.";
84+
<< "The dimension of the kernel in Conv 2D should be 2.";
8385
const auto* expr = call_node->checked_type().as<TensorTypeNode>();
8486
Array<IndexExpr> output_tensor = expr->shape;
8587
CHECK(output_tensor.size() == 4 || output_tensor.size() == 5)
86-
<< "The dimension of the output tensor in Conv 2D should be 4 or 5.";
87-
int64_t count = input_channel * GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size);
88+
<< "The dimension of the output tensor in Conv 2D should be 4 or 5.";
89+
int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size);
90+
CHECK_EQ(input_channel % conv_2d_attr->groups, 0)
91+
<< "The number of input channels is not divisble by groups.";
92+
count *= input_channel/conv_2d_attr->groups;
8893
return count;
8994
}
9095

tests/python/relay/test_pass_mac_count.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Unit tests for MAC counter."""
18+
import numpy as np
1819
import tvm
1920
from tvm import relay
2021

@@ -99,7 +100,35 @@ def test_simple_network():
99100
expect_count = 231411712
100101
assert compute_count == expect_count
101102

103+
def test_depthwise_conv2d():
104+
batch_size = 1
105+
dshape = (batch_size, 64, 56, 56)
106+
weight_conv = relay.var("weight_depthwiseconv", shape=(64, 1, 3, 3))
107+
data1 = relay.var("data1", shape=dshape)
108+
data2 = relay.var("data2", shape=dshape)
109+
depthwise_conv2d_1 = relay.nn.conv2d(
110+
data1,
111+
weight_conv,
112+
kernel_size=(3, 3),
113+
padding=(1, 1),
114+
groups=64)
115+
depthwise_conv2d_2 = relay.nn.conv2d(
116+
data2,
117+
weight_conv,
118+
kernel_size=(3, 3),
119+
padding=(1, 1),
120+
groups=64)
121+
add = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
122+
func = relay.Function([data1, data2, weight_conv],
123+
relay.Tuple(tvm.convert([depthwise_conv2d_1,
124+
depthwise_conv2d_2,
125+
add])))
126+
func = relay.ir_pass.infer_type(func)
127+
compute_count = relay.ir_pass.get_total_mac_number(func)
128+
assert compute_count == 2 * np.prod(dshape) * 3*3
129+
102130
if __name__ == "__main__":
103131
test_conv()
104132
test_gemm()
105133
test_simple_network()
134+
test_depthwise_conv2d()

0 commit comments

Comments
 (0)