@@ -13,7 +13,7 @@ namespace {
13
13
14
14
template <typename AttrType, typename IsNone>
15
15
Graph InferAttr (Graph &&ret,
16
- const AttrType def_value ,
16
+ const AttrType default_val ,
17
17
const char * infer_name,
18
18
const char * input_name,
19
19
const char * attr_key_name,
@@ -23,16 +23,16 @@ Graph InferAttr(Graph &&ret,
23
23
using AttrVector = std::vector<AttrType>;
24
24
const IndexedGraph& idx = ret.indexed_graph ();
25
25
static auto & finfer_shape =
26
- Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
26
+ Op::GetAttr<FInferNodeEntryAttr<AttrType>>(infer_name);
27
27
static auto & backward_map =
28
28
Op::GetAttr<FBackwardOutToInIndex>(" FBackwardOutToInIndex" );
29
29
// reshape shape vector
30
- AttrVector rshape (idx.num_node_entries (), def_value );
30
+ AttrVector rshape (idx.num_node_entries (), default_val );
31
31
32
32
if (ret.attrs .count (input_name) != 0 ) {
33
33
const AttrVector& shape_args = ret.GetAttr <AttrVector>(input_name);
34
34
CHECK_LE (shape_args.size (), idx.input_nodes ().size ())
35
- << " shape args is more than number of arguments" ;
35
+ << " More provided shapes than number of arguments. " ;
36
36
for (size_t i = 0 ; i < shape_args.size (); ++i) {
37
37
rshape[idx.entry_id (idx.input_nodes ()[i], 0 )] = shape_args[i];
38
38
}
@@ -46,47 +46,54 @@ Graph InferAttr(Graph &&ret,
46
46
ret.attrs .erase (attr_key_name);
47
47
}
48
48
49
- // temp space for shape inference.
49
+ // Temp space for shape inference.
50
50
std::vector<AttrType> ishape, oshape;
51
51
// number of completed nodes
52
52
size_t num_unknown = 0 ;
53
53
for (uint32_t nid = 0 ; nid < idx.num_nodes (); ++nid) {
54
54
const auto & inode = idx[nid];
55
- uint32_t num_inputs = inode.inputs .size ();
56
- uint32_t num_outputs = inode.source ->num_outputs ();
55
+ const uint32_t num_inputs = inode.inputs .size ();
56
+ const uint32_t num_outputs = inode.source ->num_outputs ();
57
57
if (inode.source ->is_variable ()) {
58
- if (shape_attr_key.length () != 0 && fis_none (rshape[idx.entry_id (nid, 0 )])) {
58
+ // Variable node. No operator. Only one output entry.
59
+ CHECK (inode.source ->op () == nullptr );
60
+ CHECK_EQ (num_outputs, 1 );
61
+ const uint32_t out_ent_id = idx.entry_id (nid, 0 );
62
+ if (shape_attr_key.length () != 0 && fis_none (rshape[out_ent_id])) {
59
63
auto it = inode.source ->attrs .dict .find (shape_attr_key);
60
64
if (it != inode.source ->attrs .dict .end ()) {
61
- CHECK_EQ (num_outputs, 1 );
62
65
std::istringstream is (it->second );
63
- CHECK (is >> rshape[idx. entry_id (nid, 0 ) ]) << " Invalid attribute" ;
66
+ CHECK (is >> rshape[out_ent_id ]) << " Invalid attribute" ;
64
67
}
65
68
}
66
- continue ;
67
- }
68
- if (finfer_shape.count (inode.source ->op ())) {
69
- ishape.resize (num_inputs, def_value);
69
+ } else if (finfer_shape.count (inode.source ->op ())) {
70
+ // Forward operator inference.
71
+ ishape.resize (num_inputs, default_val);
70
72
for (uint32_t i = 0 ; i < ishape.size (); ++i) {
71
73
ishape[i] = rshape[idx.entry_id (inode.inputs [i])];
72
74
}
73
- oshape.resize (num_outputs, def_value );
75
+ oshape.resize (num_outputs, default_val );
74
76
for (uint32_t i = 0 ; i < oshape.size (); ++i) {
75
77
oshape[i] = rshape[idx.entry_id (nid, i)];
76
78
}
77
- num_unknown +=
78
- !(finfer_shape[inode.source ->op ()](inode.source ->attrs , &ishape, &oshape));
79
+ // Call inference function of the operator.
80
+ bool forward_known = finfer_shape[inode.source ->op ()](
81
+ inode.source ->attrs , &ishape, &oshape);
82
+ num_unknown += !forward_known;
83
+ // Save to the result map.
79
84
for (uint32_t i = 0 ; i < num_inputs; ++i) {
80
85
rshape[idx.entry_id (inode.inputs [i])] = ishape[i];
81
86
}
82
87
for (uint32_t i = 0 ; i < num_outputs; ++i) {
83
88
rshape[idx.entry_id (nid, i)] = oshape[i];
84
89
}
85
90
} else if (backward_map.count (inode.source ->op ())) {
86
- // backward operator inference.
91
+ // Backward operator inference.
87
92
CHECK_GE (inode.control_deps .size (), 1 )
88
93
<< " BackwardOp need to have control_deps to its forward op" ;
89
- const auto & fnode = idx[inode.control_deps [0 ]];
94
+ const IndexedGraph::Node& fnode = idx[inode.control_deps [0 ]];
95
+ // Inference the outputs of backward operator (equal to the inputs
96
+ // of its corresponding forward operator).
90
97
std::vector<uint32_t > out_map =
91
98
backward_map[inode.source ->op ()](inode.source ->attrs );
92
99
bool known = true ;
0 commit comments