@@ -47,44 +47,52 @@ Graph InferAttr(Graph &&ret,
47
47
}
48
48
49
49
// temp space for shape inference.
50
- std::vector<AttrType* > ishape, oshape;
50
+ std::vector<AttrType> ishape, oshape;
51
51
// number of completed nodes
52
52
size_t num_unknown = 0 ;
53
53
for (uint32_t nid = 0 ; nid < idx.num_nodes (); ++nid) {
54
54
const auto & inode = idx[nid];
55
+ uint32_t num_inputs = inode.inputs .size ();
56
+ uint32_t num_outputs = inode.source ->num_outputs ();
55
57
if (inode.source ->is_variable ()) {
56
58
if (shape_attr_key.length () != 0 && fis_none (rshape[idx.entry_id (nid, 0 )])) {
57
59
auto it = inode.source ->attrs .dict .find (shape_attr_key);
58
60
if (it != inode.source ->attrs .dict .end ()) {
59
- CHECK_EQ (inode. source -> num_outputs () , 1 );
61
+ CHECK_EQ (num_outputs, 1 );
60
62
std::istringstream is (it->second );
61
63
CHECK (is >> rshape[idx.entry_id (nid, 0 )]) << " Invalid attribute" ;
62
64
}
63
65
}
64
66
continue ;
65
67
}
66
- ishape.resize (inode.inputs .size ());
67
- for (uint32_t i = 0 ; i < ishape.size (); ++i) {
68
- ishape[i] = &rshape[idx.entry_id (inode.inputs [i])];
69
- }
70
- oshape.resize (inode.source ->num_outputs ());
71
- for (uint32_t i = 0 ; i < oshape.size (); ++i) {
72
- oshape[i] = &rshape[idx.entry_id (nid, i)];
73
- }
74
68
if (finfer_shape.count (inode.source ->op )) {
69
+ ishape.resize (num_inputs, def_value);
70
+ for (uint32_t i = 0 ; i < ishape.size (); ++i) {
71
+ ishape[i] = rshape[idx.entry_id (inode.inputs [i])];
72
+ }
73
+ oshape.resize (num_outputs, def_value);
74
+ for (uint32_t i = 0 ; i < oshape.size (); ++i) {
75
+ oshape[i] = rshape[idx.entry_id (nid, i)];
76
+ }
75
77
num_unknown +=
76
- !(finfer_shape[inode.source ->op ](inode.source ->attrs , ishape, oshape));
78
+ !(finfer_shape[inode.source ->op ](inode.source ->attrs , &ishape, &oshape));
79
+ for (uint32_t i = 0 ; i < num_inputs; ++i) {
80
+ rshape[idx.entry_id (inode.inputs [i])] = ishape[i];
81
+ }
82
+ for (uint32_t i = 0 ; i < num_outputs; ++i) {
83
+ rshape[idx.entry_id (nid, i)] = oshape[i];
84
+ }
77
85
} else if (is_backward.get (inode.source ->op , false )) {
78
86
// backward operator inference.
79
87
CHECK_GE (inode.control_deps .size (), 1 )
80
88
<< " BackwardOp need to have control_deps to its forward op" ;
81
89
const auto & fnode = idx[inode.control_deps [0 ]];
82
- CHECK_EQ (fnode.inputs .size (), inode. source -> num_outputs () )
90
+ CHECK_EQ (fnode.inputs .size (), num_outputs)
83
91
<< " BackwardOp need to correspond to the forward node" ;
84
92
bool known = true ;
85
93
for (size_t i = 0 ; i < fnode.inputs .size (); ++i) {
86
- *oshape[i ] = rshape[idx.entry_id (fnode.inputs [i])];
87
- if (fis_none (*oshape[i ])) known = false ;
94
+ rshape[idx. entry_id (nid, i) ] = rshape[idx.entry_id (fnode.inputs [i])];
95
+ if (fis_none (rshape[idx. entry_id (nid, i) ])) known = false ;
88
96
}
89
97
num_unknown += !known;
90
98
}
0 commit comments