Skip to content

Commit a589eb7

Browse files
authored
Copy LinalgToXeGPU pass from tpp-mlir (#165)
* Copy LinalgToXeGPU pass from tpp-mlir * fix clang-format * fix clang-tidy * fix build warnings Signed-off-by: Dmitry Chigarev <dmitry.chigarev@intel.com> * remove unused utils Signed-off-by: Dmitry Chigarev <dmitry.chigarev@intel.com> * licenses Signed-off-by: Dmitry Chigarev <dmitry.chigarev@intel.com> --------- Signed-off-by: Dmitry Chigarev <dmitry.chigarev@intel.com>
1 parent 1b1ebf2 commit a589eb7

File tree

18 files changed

+2977
-1
lines changed

18 files changed

+2977
-1
lines changed

include/gc/Transforms/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ namespace MemRef {
3636
class MemRefDialect;
3737
}
3838

39+
namespace xegpu {
40+
class XeGPUDialect;
41+
}
42+
3943
class PassManager;
4044

4145
namespace gc {

include/gc/Transforms/Passes.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,29 @@ def GCCPUPipeline: Pass<"gc-cpu-pipeline"> {
4646
"vector::VectorDialect"];
4747
}
4848

49+
def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
50+
let summary = "Convert linalg dialect to XeGPU dialect.";
51+
let description = [{
52+
Lower linalg ops to XeGPU dialect.
53+
}];
54+
let dependentDialects = ["linalg::LinalgDialect",
55+
"gpu::GPUDialect",
56+
"xegpu::XeGPUDialect",
57+
"scf::SCFDialect",
58+
"memref::MemRefDialect",
59+
"arith::ArithDialect",
60+
"math::MathDialect",
61+
"vector::VectorDialect"];
62+
let options = [
63+
Option<"kTile", "k-tile", "int64_t",
64+
/*default=*/"32",
65+
"GEMM tile size for reduction dimension.">,
66+
Option<"stages", "stages", "int64_t",
67+
/*default=*/"1",
68+
"Number of cooperative prefetch stages.">,
69+
ListOption<"dpasTile", "dpas-tile", "int64_t",
70+
"DPAS register block sizes MxNxK">,
71+
];
72+
}
73+
4974
#endif // GC_DIALECT_GC_PASSES
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
//===- MatcherUtils.h - Matcher utils ---------------------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef GC_MATCHERUTILS_H
10+
#define GC_MATCHERUTILS_H
11+
12+
namespace mlir {
13+
class Value;
14+
namespace linalg {
15+
class LinalgOp;
16+
class GenericOp;
17+
} // namespace linalg
18+
namespace structured_match {
19+
namespace utils {
20+
21+
// Returns true if the linalg operation is a 2d eltwsie floating point addition.
22+
bool isTwoDAddOp(linalg::LinalgOp linalgOp,
23+
SmallVectorImpl<Value> *capturedOperands = nullptr);
24+
25+
// Returns true if the linalg.generic is a 2d eltwise floating point relu
26+
// operation.
27+
bool isTwoDReluOp(linalg::LinalgOp linalgOp,
28+
SmallVectorImpl<Value> *capturedOperands = nullptr);
29+
30+
} // namespace utils
31+
} // namespace structured_match
32+
} // namespace mlir
33+
34+
#endif // GC_MATCHERUTILS_H

0 commit comments

Comments
 (0)