|
30 | 30 | #include <tvm/relay/op.h>
|
31 | 31 | #include <tvm/relay/attrs/nn.h>
|
32 | 32 | #include <tvm/relay/expr_functor.h>
|
| 33 | +#include <tvm/relay/pass.h> |
33 | 34 | #include <tvm/data_layout.h>
|
| 35 | +#include "pattern_util.h" |
34 | 36 |
|
35 | 37 | namespace tvm {
|
36 | 38 | namespace relay {
|
@@ -65,26 +67,29 @@ int64_t ConvMacCount(const Call& call_node) {
|
65 | 67 | }
|
66 | 68 | Array<Expr> args = call_node->args;
|
67 | 69 | CHECK(args.size() == 2)
|
68 |
| - << "The number of input arguments of a CONV 2D node should be 2."; |
| 70 | + << "The number of input arguments of a CONV 2D node should be 2."; |
69 | 71 | const auto* conv_2d_attr = call_node->attrs.as<Conv2DAttrs>();
|
70 | 72 | const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
|
71 | 73 | Array<IndexExpr> data_shape = data_type->shape;
|
72 | 74 | std::string data_layout = conv_2d_attr->data_layout;
|
73 | 75 | int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
|
74 | 76 | int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
|
75 | 77 | CHECK(C_ind != -1)
|
76 |
| - << "There is no input channel dimension."; |
| 78 | + << "There is no input channel dimension."; |
77 | 79 | int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
|
78 | 80 | if (c_ind != -1)
|
79 | 81 | input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImm>()->value);
|
80 | 82 | Array<IndexExpr> kernel_size = conv_2d_attr->kernel_size;
|
81 | 83 | CHECK(kernel_size.size() == 2)
|
82 |
| - << "The dimension of the kernel size in Conv 2D should be 2."; |
| 84 | + << "The dimension of the kernel in Conv 2D should be 2."; |
83 | 85 | const auto* expr = call_node->checked_type().as<TensorTypeNode>();
|
84 | 86 | Array<IndexExpr> output_tensor = expr->shape;
|
85 | 87 | CHECK(output_tensor.size() == 4 || output_tensor.size() == 5)
|
86 |
| - << "The dimension of the output tensor in Conv 2D should be 4 or 5."; |
87 |
| - int64_t count = input_channel * GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); |
| 88 | + << "The dimension of the output tensor in Conv 2D should be 4 or 5."; |
| 89 | + int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); |
| 90 | + CHECK_EQ(input_channel % conv_2d_attr->groups, 0) |
| 91 | + << "The number of input channels is not divisble by groups."; |
| 92 | + count *= input_channel/conv_2d_attr->groups; |
88 | 93 | return count;
|
89 | 94 | }
|
90 | 95 |
|
|
0 commit comments