Skip to content

Commit 2b277e2

Browse files
committed
Add target description query and verifier pass
1 parent f5bde39 commit 2b277e2

File tree

13 files changed

+440
-25
lines changed

13 files changed

+440
-25
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
//===-- TargetDescriptionAnalysis.h - target description class --*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_ANALYSIS_TARGETDESCRIPTIONANALYSIS_H
10+
#define MLIR_ANALYSIS_TARGETDESCRIPTIONANALYSIS_H
11+
12+
#include "gc/Dialect/Linalgx/LinalgxOps.h"
13+
#include "mlir/Dialect/DLTI/DLTI.h"
14+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
15+
#include "mlir/Interfaces/DataLayoutInterfaces.h"
16+
#include "llvm/ADT/StringRef.h"
17+
18+
namespace mlir {
19+
namespace gc {
20+
21+
using namespace mlir;
22+
23+
enum DeviceType { CPU = 0 };
24+
25+
class TargetDescriptionAnalysisBase {
26+
public:
27+
TargetDescriptionAnalysisBase(Operation *op, DeviceType device)
28+
: ctx(op->getContext()), device(device),
29+
layout(isa<ModuleOp>(op) ? dyn_cast<ModuleOp>(op)
30+
: op->getParentOfType<ModuleOp>()),
31+
loc(op->getLoc()) {}
32+
33+
// get the device ID
34+
DeviceType getDevice() { return device; }
35+
36+
// get the MLIR context
37+
MLIRContext *getContext() { return ctx; }
38+
39+
// get the data layout
40+
DataLayout getLayout() { return layout; }
41+
42+
// get the property value by key
43+
std::optional<Attribute> getPropertyValue(StringRef key);
44+
45+
// get the location
46+
Location getLocation() { return loc; }
47+
48+
// check if the property exists
49+
bool hasProperty(StringRef key) { return getPropertyValue(key).has_value(); }
50+
51+
// emit warning if the property is not found
52+
template <typename T>
53+
void emitNotFoundWarning(Location loc, StringRef key, T value);
54+
55+
// the map from device type to device string
56+
static llvm::DenseMap<DeviceType, std::string> DeviceKeyMap;
57+
58+
private:
59+
MLIRContext *ctx;
60+
DeviceType device;
61+
DataLayout layout;
62+
Location loc;
63+
};
64+
65+
class CPUTargetDescriptionAnalysis : public TargetDescriptionAnalysisBase {
66+
public:
67+
static constexpr StringLiteral kL1CacheSize = "L1_cache_size_in_bytes";
68+
static constexpr StringLiteral kL2CacheSize = "L2_cache_size_in_bytes";
69+
static constexpr StringLiteral kL3CacheSize = "L3_cache_size_in_bytes";
70+
static constexpr StringLiteral kMaxVectorWidth = "max_vector_width";
71+
static constexpr StringLiteral kNumThreads = "num_threads";
72+
73+
// get runtime OMP_NUM_THREADS
74+
size_t getNumThreads();
75+
76+
// get cache size by cacheLevel
77+
size_t getCacheSize(uint8_t cacheLevel);
78+
79+
// get the maximum vector length in bits
80+
size_t getMaxVectorWidth();
81+
82+
// get the default value map(attr key, default value)
83+
static llvm::DenseMap<StringRef, int64_t> CPUTargetDeafultValueMap;
84+
85+
CPUTargetDescriptionAnalysis(Operation *op)
86+
: TargetDescriptionAnalysisBase(op, DeviceType::CPU) {}
87+
};
88+
89+
} // namespace gc
90+
} // namespace mlir
91+
92+
#endif

include/gc/Transforms/Passes.td

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,46 +19,47 @@ def TileLinalgNamed : Pass<"tile-named-linalg", "func::FuncOp"> {
1919

2020
#ifdef GC_HAS_ONEDNN_DIALECT
2121
def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
22-
let summary = "Lower the operations from the oneDNN Graph dialect into Linalg";
23-
let description = [{
24-
Lowers the `onednn_graph` ops to `linalg` ops.
25-
}];
22+
let summary =
23+
"Lower the operations from the oneDNN Graph dialect into Linalg";
24+
let description = [{Lowers the `onednn_graph` ops to `linalg` ops.}];
2625
let dependentDialects = [
27-
"func::FuncDialect",
28-
"math::MathDialect",
29-
"arith::ArithDialect",
30-
"tensor::TensorDialect",
31-
"linalg::LinalgDialect",
32-
"linalgx::LinalgxDialect"
26+
"func::FuncDialect", "math::MathDialect", "arith::ArithDialect",
27+
"tensor::TensorDialect", "linalg::LinalgDialect", "linalgx::LinalgxDialect"
3328
];
3429
}
3530
#endif
3631

3732
#ifdef GC_USE_IMEX
3833
def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
3934
let summary = "Convert linalg dialect to XeGPU dialect.";
40-
let description = [{
41-
Lower linalg ops to XeGPU dialect.
42-
}];
43-
let dependentDialects = ["linalg::LinalgDialect",
44-
"gpu::GPUDialect",
45-
"xegpu::XeGPUDialect",
46-
"scf::SCFDialect",
47-
"memref::MemRefDialect",
48-
"arith::ArithDialect",
49-
"math::MathDialect",
50-
"vector::VectorDialect"];
35+
let description = [{Lower linalg ops to XeGPU dialect.}];
36+
let dependentDialects = [
37+
"linalg::LinalgDialect", "gpu::GPUDialect", "xegpu::XeGPUDialect",
38+
"scf::SCFDialect", "memref::MemRefDialect", "arith::ArithDialect",
39+
"math::MathDialect", "vector::VectorDialect"
40+
];
5141
let options = [
5242
Option<"kTile", "k-tile", "int64_t",
53-
/*default=*/"32",
54-
"GEMM tile size for reduction dimension.">,
43+
/*default=*/"32", "GEMM tile size for reduction dimension.">,
5544
Option<"stages", "stages", "int64_t",
56-
/*default=*/"1",
57-
"Number of cooperative prefetch stages.">,
45+
/*default=*/"1", "Number of cooperative prefetch stages.">,
5846
ListOption<"dpasTile", "dpas-tile", "int64_t",
5947
"DPAS register block sizes MxNxK">,
6048
];
6149
}
6250
#endif
6351

52+
def VerifyTargetDescription : Pass<"verify-target-description", "ModuleOp"> {
53+
let summary = "Verify the target description from ModuleOp DLTI attribute.";
54+
let description = [{
55+
Verify the target description from ModuleOp DLTI attribute. Raise error for unexpected input(such as a negative number of num_threads), and raise warn for missing fields, and provide a default value(such as 32K for L1_cache_size).
56+
}];
57+
let dependentDialects = ["DLTIDialect"];
58+
let options = [
59+
Option<"device", "device", "std::string",
60+
/*default=*/"\"CPU\"",
61+
"The device to verify. Supported device: CPU, ">,
62+
];
63+
}
64+
6465
#endif // GC_DIALECT_GC_PASSES

lib/gc/Analysis/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
2+
MLIRIR
3+
MLIRSupport)
4+
5+
gc_add_mlir_library(GCAnalysis
6+
TargetDescriptionAnalysis.cpp
7+
8+
DEPENDS
9+
GraphCompilerPassIncGen
10+
11+
LINK_LIBS PUBLIC
12+
${mlir_dialect_libs}
13+
${MLIR_LINK_COMPONENTS}
14+
GcInterface
15+
)
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
//===-- TargetDescriptionAnalysis.cpp - target description impl -*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "gc/Analysis/TargetDescriptionAnalysis.h"
10+
#include <limits>
11+
#include <llvm/Support/Debug.h>
12+
13+
namespace mlir {
14+
namespace gc {
15+
16+
#define DEBUG_TYPE "target-description-analysis"
17+
18+
llvm::DenseMap<DeviceType, std::string>
19+
TargetDescriptionAnalysisBase::DeviceKeyMap = {
20+
{CPU, "CPU"},
21+
};
22+
23+
// default values for properties
24+
llvm::DenseMap<StringRef, int64_t>
25+
CPUTargetDescriptionAnalysis::CPUTargetDeafultValueMap = {
26+
{CPUTargetDescriptionAnalysis::kNumThreads, 1},
27+
{CPUTargetDescriptionAnalysis::kL1CacheSize, 32 * 1024},
28+
{CPUTargetDescriptionAnalysis::kL2CacheSize, 1024 * 1024},
29+
{CPUTargetDescriptionAnalysis::kL3CacheSize, 32 * 1024 * 1024},
30+
{CPUTargetDescriptionAnalysis::kMaxVectorWidth, 512},
31+
};
32+
33+
template <typename T>
34+
void TargetDescriptionAnalysisBase::emitNotFoundWarning(Location loc,
35+
StringRef key,
36+
T value) {
37+
mlir::emitWarning(loc) << key << " not found, using default value " << value;
38+
}
39+
40+
static int64_t getIntFromAttribute(Attribute attr) {
41+
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
42+
if (intAttr.getType().isSignedInteger())
43+
return intAttr.getSInt();
44+
else if (intAttr.getType().isUnsignedInteger())
45+
return intAttr.getUInt();
46+
else
47+
return intAttr.getInt();
48+
}
49+
llvm_unreachable("Not an integer attribute");
50+
}
51+
52+
std::optional<Attribute>
53+
TargetDescriptionAnalysisBase::getPropertyValue(StringRef key) {
54+
return layout.getDevicePropertyValue(
55+
Builder(getContext())
56+
.getStringAttr(DeviceKeyMap[getDevice()] /* device ID*/),
57+
Builder(getContext()).getStringAttr(key));
58+
}
59+
60+
size_t CPUTargetDescriptionAnalysis::getNumThreads() {
61+
std::optional<Attribute> numThreads = getPropertyValue(kNumThreads);
62+
63+
if (numThreads && isa<IntegerAttr>(*numThreads))
64+
return getIntFromAttribute(*numThreads);
65+
emitNotFoundWarning(getLocation(), kNumThreads,
66+
CPUTargetDeafultValueMap[kNumThreads]);
67+
return CPUTargetDeafultValueMap[kNumThreads];
68+
}
69+
70+
size_t CPUTargetDescriptionAnalysis::getCacheSize(uint8_t cacheLevel) {
71+
assert(cacheLevel > 0 && cacheLevel < 4 && "Invalid cache level");
72+
StringLiteral key = "";
73+
if (cacheLevel == 1)
74+
key = kL1CacheSize;
75+
else if (cacheLevel == 2)
76+
key = kL2CacheSize;
77+
else if (cacheLevel == 3)
78+
key = kL3CacheSize;
79+
80+
std::optional<Attribute> cacheSize = getPropertyValue(key);
81+
if (cacheSize && isa<IntegerAttr>(*cacheSize))
82+
return getIntFromAttribute(*cacheSize);
83+
84+
emitNotFoundWarning(getLocation(), key, CPUTargetDeafultValueMap[key]);
85+
return CPUTargetDeafultValueMap[key];
86+
}
87+
88+
size_t CPUTargetDescriptionAnalysis::getMaxVectorWidth() {
89+
std::optional<Attribute> maxVectorWidth = getPropertyValue(kMaxVectorWidth);
90+
if (maxVectorWidth && isa<IntegerAttr>(*maxVectorWidth))
91+
return getIntFromAttribute(*maxVectorWidth);
92+
emitNotFoundWarning(getLocation(), kMaxVectorWidth,
93+
CPUTargetDeafultValueMap[kMaxVectorWidth]);
94+
return CPUTargetDeafultValueMap[kMaxVectorWidth];
95+
}
96+
97+
} // namespace gc
98+
} // namespace mlir

lib/gc/CAPI/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
set(GC_ALL_LIBS
22
${GC_ONEDNN_DIALECT_LIB_NAME}
33
GcPasses
4+
GCAnalysis
45
MLIRCPURuntimeTransforms)
56

67
if(GC_ENABLE_IMEX)

lib/gc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
add_subdirectory(Analysis)
12
add_subdirectory(CAPI)
23
add_subdirectory(Dialect)
34
add_subdirectory(Transforms)

lib/gc/ExecutionEngine/Driver/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,5 @@ gc_add_mlir_library(GcJitWrapper
3939
${dialect_libs}
4040
${conversion_libs}
4141
${GC_PASSES}
42+
GCAnalysis
4243
)

lib/gc/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ gc_add_mlir_library(GcPasses
1313
OneDNNGraphToLinalg.cpp
1414
Pipeline.cpp
1515
TileNamed.cpp
16+
VerifyTargetDescription.cpp
1617

1718
DEPENDS
1819
GraphCompilerPassIncGen

0 commit comments

Comments
 (0)