Skip to content

Commit 37c7247

Browse files
Minor enhancements
1 parent f3b57a6 commit 37c7247

File tree

2 files changed

+22
-14
lines changed

2 files changed

+22
-14
lines changed

src/dnnl/JsonParser.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ JsonParser::parse(llvm::SmallVector<size_t> &outputIds,
107107
auto ret = _builder.create<mlir::func::ReturnOp>(_loc, outputs);
108108

109109
// Copying the strides for the inputs and outputs.
110-
for (auto &ids : {_inputIds, outputIds}) {
111-
for (auto id : ids) {
110+
for (auto &ids : {&_inputIds, &outputIds}) {
111+
for (auto id : *ids) {
112112
auto entry = _strides.find(id);
113113
if (entry != _strides.end()) {
114114
strides[id] = entry->second;

src/dnnl/dnnl_graph_compiler.cpp

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,20 +74,28 @@ struct dnnl_graph_compiler_executable {
7474
std::make_pair(&outputIds, outputs)}) {
7575
auto ids = pair.first;
7676
auto tensors = pair.second;
77-
for (auto id : *ids) {
78-
bool found = false;
79-
for (size_t i = 0; i < ids->size(); i++) {
80-
if (tensors[i].id == id) {
81-
auto s = strides.find(id);
82-
memRefs.emplace_back(&tensors[i],
83-
s == strides.end() ? nullptr : &s->second);
84-
found = true;
85-
break;
77+
for (size_t i = 0, n = ids->size(); i < n; i++) {
78+
auto id = (*ids)[i];
79+
dnnl_graph_compiler_tensor *tensor;
80+
81+
if (tensors[i].id == id) {
82+
tensor = &tensors[i];
83+
} else {
84+
// The order of inputs/outputs may not match the function args order.
85+
tensor = nullptr;
86+
for (size_t j = 0; j < n; j++) {
87+
if (tensors[j].id == id) {
88+
tensor = &tensors[i];
89+
break;
90+
}
91+
}
92+
if (!tensor) {
93+
throw std::invalid_argument("Tensor not found");
8694
}
8795
}
88-
if (!found) {
89-
throw std::invalid_argument("Tensor not found");
90-
}
96+
97+
auto s = strides.find((*ids)[i]);
98+
memRefs.emplace_back(tensor, s == strides.end() ? nullptr : &s->second);
9199
}
92100
}
93101

0 commit comments

Comments
 (0)