
Commit 22c3d76

Adapt to constant PropertyType

Derive constant function arguments from the onednn_graph property types (via LogicalTensorInfo::queryPropertyType) instead of parsing the onednn_graph.const_args attribute in ConstantSubgraphAnalyser and ConstantTensorFolding, and extend the tests with the onednn_graph.property_types attribute accordingly.

1 parent 3f34e97 · commit 22c3d76


4 files changed, +55 −29 lines changed


lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp

+31 −16
@@ -6,6 +6,11 @@
 //
 //===----------------------------------------------------------------------===//
 #include "gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h"
+
+#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
+#include "gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h"
+#include "gc/Dialect/OneDNNGraph/Utils/Utils.h"
+
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -96,22 +101,32 @@ void ConstantSubgraphAnalyser::setToEntryState(
     Lattice<InConstantSubgraph> *lattice) {
   if (auto blockArg = cast<BlockArgument>(lattice->getPoint())) {
     auto parent_op = blockArg.getParentBlock()->getParentOp();
-    auto parent_op_attr = parent_op->getAttrDictionary();
-    std::optional<NamedAttribute> const_args =
-        parent_op_attr.getNamed("onednn_graph.const_args");
-    if (const_args.has_value()) {
-      ArrayAttr const_args_indexes =
-          llvm::dyn_cast<ArrayAttr>(const_args->getValue());
-      for (auto id : const_args_indexes) {
-        auto idint = llvm::cast<IntegerAttr>(id).getInt();
-        if (blockArg.getArgNumber() == idint) {
-          LLVM_DEBUG(llvm::dbgs() << "Block argument: " << blockArg
-                                  << " is marked as constant\n");
-          propagateIfChanged(lattice,
-                             lattice->join(InConstantSubgraph(true, true)));
-          return;
-        }
-      }
+    // auto parent_op_attr = parent_op->getAttrDictionary();
+    // std::optional<NamedAttribute> const_args =
+    //     parent_op_attr.getNamed("onednn_graph.const_args");
+    // if (const_args.has_value()) {
+    //   ArrayAttr const_args_indexes =
+    //       llvm::dyn_cast<ArrayAttr>(const_args->getValue());
+    //   for (auto id : const_args_indexes) {
+    //     auto idint = llvm::cast<IntegerAttr>(id).getInt();
+    //     if (blockArg.getArgNumber() == idint) {
+    //       LLVM_DEBUG(llvm::dbgs() << "Block argument: " << blockArg
+    //                               << " is marked as constant\n");
+    //       propagateIfChanged(lattice,
+    //                          lattice->join(InConstantSubgraph(true, true)));
+    //       return;
+    //     }
+    //   }
+    // }
+    auto funcOp = cast<func::FuncOp>(parent_op);
+    mlir::onednn_graph::LogicalTensorInfo info(funcOp);
+    if (info.queryPropertyType(blockArg) ==
+        mlir::onednn_graph::PropertyType::constant) {
+      LLVM_DEBUG(llvm::dbgs() << "Block argument: " << blockArg
+                              << " is marked as constant\n");
+      propagateIfChanged(lattice,
+                         lattice->join(InConstantSubgraph(true, true)));
+      return;
     }
     propagateIfChanged(lattice, lattice->join(InConstantSubgraph(true, false)));
   } else {
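For reference, the new constness check boils down to a single query per block argument. Below is a minimal sketch, not code from this commit: it condenses the added logic in setToEntryState, assumes the block argument belongs directly to a func::FuncOp (as in the hunk above), and the helper name isConstantFuncArg is ours.

// Minimal sketch of the new lookup (not part of the commit). Assumes the
// argument's enclosing op is a func::FuncOp; the helper name is hypothetical.
#include "gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h"
#include "gc/Dialect/OneDNNGraph/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"

static bool isConstantFuncArg(mlir::BlockArgument blockArg) {
  auto funcOp = llvm::cast<mlir::func::FuncOp>(
      blockArg.getParentBlock()->getParentOp());
  // LogicalTensorInfo answers property-type queries for values of the
  // function, per its usage in this commit.
  mlir::onednn_graph::LogicalTensorInfo info(funcOp);
  return info.queryPropertyType(blockArg) ==
         mlir::onednn_graph::PropertyType::constant;
}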

lib/gc/Transforms/ConstantTensorFolding.cpp

+19 −8
@@ -14,8 +14,11 @@
 #include <deque>
 #include <unordered_set>
 
-#include "mlir/Transforms/Passes.h"
+#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
+#include "gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h"
+#include "gc/Dialect/OneDNNGraph/Utils/Utils.h"
 
+#include "mlir/Transforms/Passes.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
@@ -386,14 +389,22 @@ static void addGlobalI32Array(ModuleOp &module, Location loc,
 }
 
 std::unordered_set<int> getConstArgsIndexes(Operation &topFunc) {
-  auto topFuncAttr = topFunc.getAttrDictionary();
-  std::optional<NamedAttribute> constArgs =
-      topFuncAttr.getNamed("onednn_graph.const_args");
   std::unordered_set<int> constArgsIndexes;
-  if (constArgs.has_value()) {
-    ArrayAttr constArgsArray = llvm::dyn_cast<ArrayAttr>(constArgs->getValue());
-    for (auto id : constArgsArray) {
-      constArgsIndexes.insert(llvm::cast<IntegerAttr>(id).getInt());
+  // auto topFuncAttr = topFunc.getAttrDictionary();
+  // std::optional<NamedAttribute> constArgs =
+  //     topFuncAttr.getNamed("onednn_graph.const_args");
+  // if (constArgs.has_value()) {
+  //   ArrayAttr constArgsArray = llvm::dyn_cast<ArrayAttr>(constArgs->getValue());
+  //   for (auto id : constArgsArray) {
+  //     constArgsIndexes.insert(llvm::cast<IntegerAttr>(id).getInt());
+  //   }
+  // }
+  auto funcOp = cast<func::FuncOp>(topFunc);
+  mlir::onednn_graph::LogicalTensorInfo info(funcOp);
+  for (int i = 0; i < funcOp.getArguments().size(); ++i) {
+    if (info.queryPropertyType(funcOp.getArguments()[i]) ==
+        mlir::onednn_graph::PropertyType::constant) {
+      constArgsIndexes.insert(i);
     }
   }
   return constArgsIndexes;
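The committed loop indexes the arguments with a signed int and compares it against getArguments().size(), which typically triggers a -Wsign-compare warning. An equivalent formulation is sketched below; it is not the function this commit adds, and the name collectConstArgIndexes is hypothetical.

// Sketch only: behaviorally equivalent to the new getConstArgsIndexes body,
// written with llvm::enumerate; the function name is hypothetical.
#include <unordered_set>

#include "gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h"
#include "gc/Dialect/OneDNNGraph/Utils/Utils.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"

static std::unordered_set<int>
collectConstArgIndexes(mlir::func::FuncOp funcOp) {
  std::unordered_set<int> indexes;
  mlir::onednn_graph::LogicalTensorInfo info(funcOp);
  for (const auto &en : llvm::enumerate(funcOp.getArguments())) {
    // Record the position of every argument whose property type is constant.
    if (info.queryPropertyType(en.value()) ==
        mlir::onednn_graph::PropertyType::constant)
      indexes.insert(static_cast<int>(en.index()));
  }
  return indexes;
}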

test/gc/Transforms/test_constant_tensor_folding-1.mlir

+2 −2
@@ -2,7 +2,7 @@
 
 // CHECK-LABEL: func.func @entry
 module {
-  func.func @entry(%a: tensor<128xf32>, %b: tensor<128xf32>, %c: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32] } {
+  func.func @entry(%a: tensor<128xf32>, %b: tensor<128xf32>, %c: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32], onednn_graph.property_types = [#onednn_graph.property_type<constant>, #onednn_graph.property_type<constant>, #onednn_graph.property_type<variable>] } {
     %c0 = arith.constant 0 : index
     cpuruntime.printf "HI%zu\n" %c0 : index
     %ax2 = tensor.empty() : tensor<128xf32>
@@ -36,7 +36,7 @@ module {
 // COM: llvm.mlir.global external constant @__compute_args(dense<[3, 2, 3, 4]> : tensor<4xi32>) {addr_space = 0 : i32} : !llvm.array<4 x i32>
 // COM: llvm.mlir.global external constant @__fold_args(dense<[4, 0, 1, 3, 4]> : tensor<5xi32>) {addr_space = 0 : i32} : !llvm.array<5 x i32>
 // COM: llvm.mlir.global external constant @__fold_buffer_ids(dense<[2, 0, 1]> : tensor<3xi64>) {addr_space = 0 : i32} : !llvm.array<3 x i64>
-// COM: func.func @entry(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes {llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32]} {
+// COM: func.func @entry(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes {llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32], onednn_graph.property_types = [#onednn_graph.property_type<constant>, #onednn_graph.property_type<constant>, #onednn_graph.property_type<variable>]} {
 // COM: %c0 = arith.constant 0 : index
 // COM: cpuruntime.printf "HI%zu\0A" %c0 : index
 // COM: %0 = tensor.empty() : tensor<128xf32>

test/gc/Transforms/test_constant_tensor_folding.mlir

+3 −3
@@ -9,7 +9,7 @@
 module {
   // COM: A two-layer mlp. arg0: input feature. arg1: weight of #1 linear. arg2: bias of #1 linear.
   // COM: arg3: weight of #2 linear. arg4: bias of #2 linear.
-  func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} {
+  func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32], onednn_graph.property_types = [#onednn_graph.property_type<variable>, #onednn_graph.property_type<constant>, #onednn_graph.property_type<constant>, #onednn_graph.property_type<constant>, #onednn_graph.property_type<constant>]} {
     %1 = tensor.empty() : tensor<2x16x32x32xbf16>
     %packed_arg0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<64x512xbf16> -> tensor<2x16x32x32xbf16>
     %2 = tensor.empty() : tensor<8x16x32x32xbf16>
@@ -78,5 +78,5 @@ module {
 // COM: llvm.mlir.global external constant @__compute_args(dense<[5, 0, 5, 6, 7, 8]> : tensor<6xi32>) {addr_space = 0 : i32} : !llvm.array<6 x i32>
 // COM: llvm.mlir.global external constant @__fold_args(dense<[8, 1, 2, 3, 4, 5, 6, 7, 8]> : tensor<9xi32>) {addr_space = 0 : i32} : !llvm.array<9 x i32>
 // COM: llvm.mlir.global external constant @__fold_buffer_ids(dense<[4, 0, 1, 2, 3]> : tensor<5xi64>) {addr_space = 0 : i32} : !llvm.array<5 x i64>
-// COM: func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32]}
-// COM: func.func @fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>) attributes {llvm.emit_c_interface}
+// COM: func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16>
+// COM: func.func @fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>)
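Both tests keep the existing onednn_graph.const_args attribute and add the positional onednn_graph.property_types array, which the reworked passes presumably consume through LogicalTensorInfo: entry i describes argument i, so [constant, constant, variable] marks %a and %b as constant in the first test (matching const_args = [0 : i32, 1 : i32]), and [variable, constant, constant, constant, constant] marks the weights and biases %arg1 through %arg4 in the MLP test (matching const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32]). The expected post-folding signatures in the final COM lines drop the attribute dictionary altogether.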
