|
6 | 6 | //
|
7 | 7 | //===----------------------------------------------------------------------===//
|
8 | 8 | #include "gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h"
|
| 9 | + |
| 10 | +#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" |
| 11 | +#include "gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h" |
| 12 | +#include "gc/Dialect/OneDNNGraph/Utils/Utils.h" |
| 13 | + |
9 | 14 | #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
|
10 | 15 | #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
|
11 | 16 | #include "mlir/Dialect/Arith/IR/Arith.h"
|
@@ -96,22 +101,32 @@ void ConstantSubgraphAnalyser::setToEntryState(
|
96 | 101 | Lattice<InConstantSubgraph> *lattice) {
|
97 | 102 | if (auto blockArg = cast<BlockArgument>(lattice->getPoint())) {
|
98 | 103 | auto parent_op = blockArg.getParentBlock()->getParentOp();
|
99 |
| - auto parent_op_attr = parent_op->getAttrDictionary(); |
100 |
| - std::optional<NamedAttribute> const_args = |
101 |
| - parent_op_attr.getNamed("onednn_graph.const_args"); |
102 |
| - if (const_args.has_value()) { |
103 |
| - ArrayAttr const_args_indexes = |
104 |
| - llvm::dyn_cast<ArrayAttr>(const_args->getValue()); |
105 |
| - for (auto id : const_args_indexes) { |
106 |
| - auto idint = llvm::cast<IntegerAttr>(id).getInt(); |
107 |
| - if (blockArg.getArgNumber() == idint) { |
108 |
| - LLVM_DEBUG(llvm::dbgs() << "Block argument: " << blockArg |
109 |
| - << " is marked as constant\n"); |
110 |
| - propagateIfChanged(lattice, |
111 |
| - lattice->join(InConstantSubgraph(true, true))); |
112 |
| - return; |
113 |
| - } |
114 |
| - } |
| 104 | + // auto parent_op_attr = parent_op->getAttrDictionary(); |
| 105 | + // std::optional<NamedAttribute> const_args = |
| 106 | + // parent_op_attr.getNamed("onednn_graph.const_args"); |
| 107 | + // if (const_args.has_value()) { |
| 108 | + // ArrayAttr const_args_indexes = |
| 109 | + // llvm::dyn_cast<ArrayAttr>(const_args->getValue()); |
| 110 | + // for (auto id : const_args_indexes) { |
| 111 | + // auto idint = llvm::cast<IntegerAttr>(id).getInt(); |
| 112 | + // if (blockArg.getArgNumber() == idint) { |
| 113 | + // LLVM_DEBUG(llvm::dbgs() << "Block argument: " << blockArg |
| 114 | + // << " is marked as constant\n"); |
| 115 | + // propagateIfChanged(lattice, |
| 116 | + // lattice->join(InConstantSubgraph(true, true))); |
| 117 | + // return; |
| 118 | + // } |
| 119 | + // } |
| 120 | + // } |
| 121 | + auto funcOp = cast<func::FuncOp>(parent_op); |
| 122 | + mlir::onednn_graph::LogicalTensorInfo info(funcOp); |
| 123 | + if (info.queryPropertyType(blockArg) == |
| 124 | + mlir::onednn_graph::PropertyType::constant) { |
| 125 | + LLVM_DEBUG(llvm::dbgs() << "Block argument: " << blockArg |
| 126 | + << " is marked as constant\n"); |
| 127 | + propagateIfChanged(lattice, |
| 128 | + lattice->join(InConstantSubgraph(true, true))); |
| 129 | + return; |
115 | 130 | }
|
116 | 131 | propagateIfChanged(lattice, lattice->join(InConstantSubgraph(true, false)));
|
117 | 132 | } else {
|
|
0 commit comments