Skip to content

Centralize target description query through DLTI and add verifier pass #210

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions include/gc/Analysis/TargetDescriptionAnalysis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
//===-- TargetDescriptionAnalysis.h - target description class --*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_ANALYSIS_TARGETDESCRIPTIONANALYSIS_H
#define MLIR_ANALYSIS_TARGETDESCRIPTIONANALYSIS_H

#include "gc/Dialect/Linalgx/LinalgxOps.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "llvm/ADT/StringRef.h"

namespace mlir {
namespace gc {

using namespace mlir;

// Kinds of device whose target description can be queried. CPU is the only
// enumerator defined here.
enum DeviceType { CPU = 0 };

// Common base for per-device target description queries. It captures the
// MLIR context, the device kind, a DataLayout and the location of the
// operation it was built from, and exposes device property lookup by key.
class TargetDescriptionAnalysisBase {
public:
// Builds the analysis for `op`. The data layout is taken from `op` itself
// when `op` is a ModuleOp, otherwise from the ModuleOp enclosing `op`.
TargetDescriptionAnalysisBase(Operation *op, DeviceType device)
: ctx(op->getContext()), device(device),
layout(isa<ModuleOp>(op) ? dyn_cast<ModuleOp>(op)
: op->getParentOfType<ModuleOp>()),
loc(op->getLoc()) {}

// Returns the device kind this analysis was constructed for.
DeviceType getDevice() { return device; }

// Returns the MLIR context of the operation passed to the constructor.
MLIRContext *getContext() { return ctx; }

// Returns the data layout captured at construction (returned by value).
DataLayout getLayout() { return layout; }

// Looks up the device property named `key` in the data layout, using the
// DeviceKeyMap string of this device as the device ID. Callers treat an empty
// optional as "property not present".
std::optional<Attribute> getPropertyValue(StringRef key);

// Returns the location of the operation passed to the constructor.
Location getLocation() { return loc; }

// Returns true when getPropertyValue(key) yields a value.
bool hasProperty(StringRef key) { return getPropertyValue(key).has_value(); }

// Emits a warning at `loc` stating that `key` was not found and that `value`
// is used as the default.
// NOTE(review): the template is defined in TargetDescriptionAnalysis.cpp, so
// it can presumably only be instantiated from that file - confirm before
// calling it from another translation unit.
template <typename T>
void emitNotFoundWarning(Location loc, StringRef key, T value);

// Maps each device kind to the device ID string used for property lookup.
static llvm::DenseMap<DeviceType, std::string> DeviceKeyMap;

private:
MLIRContext *ctx;
DeviceType device;
DataLayout layout;
Location loc;
};

// Target description queries for the CPU device. Each getter reads one device
// property and, when the property is absent, emits a warning and returns a
// built-in default.
class CPUTargetDescriptionAnalysis : public TargetDescriptionAnalysisBase {
public:
// Property keys looked up through getPropertyValue.
static constexpr StringLiteral kL1CacheSize = "L1_cache_size_in_bytes";
static constexpr StringLiteral kL2CacheSize = "L2_cache_size_in_bytes";
static constexpr StringLiteral kL3CacheSize = "L3_cache_size_in_bytes";
static constexpr StringLiteral kMaxVectorWidth = "max_vector_width";
static constexpr StringLiteral kNumThreads = "num_threads";

// Returns the value of the `num_threads` property; falls back to 1 with a
// warning when the property is absent.
unsigned getNumThreads();

// Returns the cache size property for `cacheLevel`, which is asserted to be
// 1, 2 or 3. When the property is absent, warns and returns the default:
// 32 KiB (L1), 1 MiB (L2) or 32 MiB (L3).
unsigned getCacheSize(uint8_t cacheLevel);

// Returns the value of the `max_vector_width` property (maximum vector length
// in bits); falls back to 512 with a warning when the property is absent.
unsigned getMaxVectorWidth();

// Builds the analysis for `op` with the device fixed to DeviceType::CPU.
CPUTargetDescriptionAnalysis(Operation *op)
: TargetDescriptionAnalysisBase(op, DeviceType::CPU) {}
};

} // namespace gc
} // namespace mlir

#endif
51 changes: 26 additions & 25 deletions include/gc/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,42 +19,30 @@ def TileLinalgNamed : Pass<"tile-named-linalg", "func::FuncOp"> {

#ifdef GC_HAS_ONEDNN_DIALECT
def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
let summary = "Lower the operations from the oneDNN Graph dialect into Linalg";
let description = [{
Lowers the `onednn_graph` ops to `linalg` ops.
}];
let summary =
"Lower the operations from the oneDNN Graph dialect into Linalg";
let description = [{Lowers the `onednn_graph` ops to `linalg` ops.}];
let dependentDialects = [
"func::FuncDialect",
"math::MathDialect",
"arith::ArithDialect",
"tensor::TensorDialect",
"linalg::LinalgDialect",
"linalgx::LinalgxDialect"
"func::FuncDialect", "math::MathDialect", "arith::ArithDialect",
"tensor::TensorDialect", "linalg::LinalgDialect", "linalgx::LinalgxDialect"
];
}
#endif

#ifdef GC_USE_IMEX
def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
let summary = "Convert linalg dialect to XeGPU dialect.";
let description = [{
Lower linalg ops to XeGPU dialect.
}];
let dependentDialects = ["linalg::LinalgDialect",
"gpu::GPUDialect",
"xegpu::XeGPUDialect",
"scf::SCFDialect",
"memref::MemRefDialect",
"arith::ArithDialect",
"math::MathDialect",
"vector::VectorDialect"];
let description = [{Lower linalg ops to XeGPU dialect.}];
let dependentDialects = [
"linalg::LinalgDialect", "gpu::GPUDialect", "xegpu::XeGPUDialect",
"scf::SCFDialect", "memref::MemRefDialect", "arith::ArithDialect",
"math::MathDialect", "vector::VectorDialect"
];
let options = [
Option<"kTile", "k-tile", "int64_t",
/*default=*/"32",
"GEMM tile size for reduction dimension.">,
/*default=*/"32", "GEMM tile size for reduction dimension.">,
Option<"stages", "stages", "int64_t",
/*default=*/"1",
"Number of cooperative prefetch stages.">,
/*default=*/"1", "Number of cooperative prefetch stages.">,
ListOption<"dpasTile", "dpas-tile", "int64_t",
"DPAS register block sizes MxNxK">,
];
Expand Down Expand Up @@ -93,4 +81,17 @@ def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion",
];
}

// Module-level pass registered as `verify-target-description`. It declares
// DLTIDialect as a dependent dialect and takes a string option `device` whose
// default is "CPU".
def VerifyTargetDescription : Pass<"verify-target-description", "ModuleOp"> {
let summary = "Verify the target description from ModuleOp DLTI attribute.";
let description = [{
Verify the target description from ModuleOp DLTI attribute. Raise error for unexpected input(such as a negative number of num_threads), and raise warn for missing fields, and provide a default value(such as 32K for L1_cache_size).
}];
let dependentDialects = ["DLTIDialect"];
let options = [
Option<"device", "device", "std::string",
/*default=*/"\"CPU\"",
"The device to verify. Supported device: CPU, ">,
];
}

#endif // GC_DIALECT_GC_PASSES
15 changes: 15 additions & 0 deletions lib/gc/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# MLIR components the analysis library links against.
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
MLIRIR
MLIRSupport)

# GcAnalysis library: built from TargetDescriptionAnalysis.cpp.
# NOTE(review): the dependency on GraphCompilerPassIncGen is presumably needed
# for generated pass headers - confirm it is still required here.
gc_add_mlir_library(GcAnalysis
TargetDescriptionAnalysis.cpp

DEPENDS
GraphCompilerPassIncGen

LINK_LIBS PUBLIC
${mlir_dialect_libs}
${MLIR_LINK_COMPONENTS}
GcInterface
)
102 changes: 102 additions & 0 deletions lib/gc/Analysis/TargetDescriptionAnalysis.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
//===-- TargetDescriptionAnalysis.cpp - target description impl -*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "gc/Analysis/TargetDescriptionAnalysis.h"
#include <limits>
#include <llvm/Support/Debug.h>
#include <regex>

namespace mlir {
namespace gc {

#define DEBUG_TYPE "target-description-analysis"

// Definition of the static device-kind -> device-ID string table. The mapped
// string is what getPropertyValue passes to the data layout as the device ID.
llvm::DenseMap<DeviceType, std::string>
TargetDescriptionAnalysisBase::DeviceKeyMap = {
{CPU, "CPU"},
};

// Reports, as a warning attached to `loc`, that the property `key` is missing
// and that `value` is being substituted for it.
template <typename T>
void TargetDescriptionAnalysisBase::emitNotFoundWarning(Location loc,
                                                        StringRef key,
                                                        T value) {
  // The diagnostic is reported when `warning` goes out of scope.
  auto warning = mlir::emitWarning(loc);
  warning << key;
  warning << " not found, using default value ";
  warning << value;
}

// Returns true when `token` is a base-10 integer literal: an optional leading
// '+' or '-' followed by one or more ASCII digits and nothing else. Empty
// strings, a lone sign, surrounding whitespace and decimal points are rejected.
//
// This is a plain linear scan. The previous implementation compiled a
// std::regex on every call, which is far more expensive than the check itself.
static bool isIntegerNumber(const std::string &token) {
  std::string::size_type pos = 0;
  if (!token.empty() && (token[0] == '+' || token[0] == '-'))
    pos = 1;

  // At least one digit must follow the optional sign.
  if (pos == token.size())
    return false;

  for (; pos < token.size(); ++pos) {
    const char c = token[pos];
    if (c < '0' || c > '9')
      return false;
  }
  return true;
}

// Extracts an integer from `attr`.
//
// An IntegerAttr is read according to its signedness (signed, unsigned, or
// signless). A StringAttr is accepted when its text passes isIntegerNumber and
// is then parsed with std::stoll. Any other input reaches llvm_unreachable.
static int64_t getIntFromAttribute(Attribute attr) {
  if (auto integer = dyn_cast<IntegerAttr>(attr)) {
    const auto type = integer.getType();
    if (type.isSignedInteger())
      return integer.getSInt();
    if (type.isUnsignedInteger())
      return integer.getUInt();
    // Signless integer.
    return integer.getInt();
  }

  if (auto text = dyn_cast<StringAttr>(attr)) {
    std::string digits = text.getValue().str();
    if (isIntegerNumber(digits))
      return std::stoll(digits);
  }

  llvm_unreachable("Not an integer attribute or integer like string attribute");
}

// Queries the data layout for the property `key` of this analysis' device.
// The device ID is the DeviceKeyMap string registered for getDevice().
std::optional<Attribute>
TargetDescriptionAnalysisBase::getPropertyValue(StringRef key) {
  Builder builder(getContext());
  auto deviceID = builder.getStringAttr(DeviceKeyMap[getDevice()]);
  auto propertyName = builder.getStringAttr(key);
  return layout.getDevicePropertyValue(deviceID, propertyName);
}

unsigned CPUTargetDescriptionAnalysis::getNumThreads() {
static const unsigned defaultNumThreads = 1;
std::optional<Attribute> numThreads = getPropertyValue(kNumThreads);

if (numThreads)
return getIntFromAttribute(*numThreads);
emitNotFoundWarning(getLocation(), kNumThreads, defaultNumThreads);
return defaultNumThreads;
}

// Returns the cache size property for `cacheLevel` (1, 2 or 3; asserted).
//
// When the property is absent, a warning is emitted at this analysis' location
// and the built-in default for that level is returned: 32 KiB for L1, 1 MiB
// for L2 and 32 MiB for L3.
//
// The key and its default are chosen with a switch. The previous version
// built a DenseMap on every call just to pick one of three constants.
// With assertions disabled, an out-of-range level keeps the earlier behaviour:
// the empty key is looked up and the default is 0.
unsigned CPUTargetDescriptionAnalysis::getCacheSize(uint8_t cacheLevel) {
  assert(cacheLevel > 0 && cacheLevel < 4 && "Invalid cache level");

  StringRef key = "";
  unsigned defaultSize = 0;
  switch (cacheLevel) {
  case 1:
    key = kL1CacheSize;
    defaultSize = 32 * 1024;
    break;
  case 2:
    key = kL2CacheSize;
    defaultSize = 1024 * 1024;
    break;
  case 3:
    key = kL3CacheSize;
    defaultSize = 32 * 1024 * 1024;
    break;
  default:
    break;
  }

  if (std::optional<Attribute> cacheSize = getPropertyValue(key))
    return getIntFromAttribute(*cacheSize);

  emitNotFoundWarning(getLocation(), key, defaultSize);
  return defaultSize;
}

unsigned CPUTargetDescriptionAnalysis::getMaxVectorWidth() {
static const unsigned defaultMaxVectorWidth = 512;
std::optional<Attribute> maxVectorWidth = getPropertyValue(kMaxVectorWidth);
if (maxVectorWidth)
return getIntFromAttribute(*maxVectorWidth);
emitNotFoundWarning(getLocation(), kMaxVectorWidth, defaultMaxVectorWidth);
return defaultMaxVectorWidth;
}

} // namespace gc
} // namespace mlir
1 change: 1 addition & 0 deletions lib/gc/CAPI/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
set(GC_ALL_LIBS
${GC_ONEDNN_DIALECT_LIB_NAME}
GcPasses
GcAnalysis
MLIRCPURuntimeTransforms)

if(GC_ENABLE_IMEX)
Expand Down
1 change: 1 addition & 0 deletions lib/gc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(Analysis)
add_subdirectory(CAPI)
add_subdirectory(Dialect)
add_subdirectory(Transforms)
Expand Down
1 change: 1 addition & 0 deletions lib/gc/ExecutionEngine/Driver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@ gc_add_mlir_library(GcJitWrapper
${dialect_libs}
${conversion_libs}
${GC_PASSES}
GcAnalysis
)
1 change: 1 addition & 0 deletions lib/gc/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ gc_add_mlir_library(GcPasses
TileNamed.cpp
IterativeTilingAndFusion.cpp
TilingUsingInterfaceX.cpp
VerifyTargetDescription.cpp

DEPENDS
GraphCompilerPassIncGen
Expand Down
60 changes: 3 additions & 57 deletions lib/gc/Transforms/IterativeTilingAndFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "gc/Analysis/TargetDescriptionAnalysis.h"
#include "gc/Transforms/Passes.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/DLTI/Traits.h"
Expand Down Expand Up @@ -579,62 +580,6 @@ static LogicalResult isSelfTiledOp(Operation *targetOp) {
return success(walkResult.wasInterrupted());
}

struct SystemDesc {
// get runtime OMP_NUM_THREADS
uint32_t getNumThreads() {
std::optional<Attribute> numThreads = layout.getDevicePropertyValue(
Builder(ctx).getStringAttr("CPU" /* device ID*/),
Builder(ctx).getStringAttr("num_threads"));
if (numThreads && isa<IntegerAttr>(*numThreads)) {
return dyn_cast<IntegerAttr>(*numThreads).getInt();
}
return 1;
}
// get cache size by cacheLevel
size_t getCacheSize(uint8_t cacheLevel) {
if (cacheLevel == 1) {
std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
Builder(ctx).getStringAttr("CPU" /* device ID*/),
Builder(ctx).getStringAttr("L1_cache_size_in_bytes"));
if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
return dyn_cast<IntegerAttr>(*cacheSize).getInt();
}
} else if (cacheLevel == 2) {
std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
Builder(ctx).getStringAttr("CPU" /* device ID*/),
Builder(ctx).getStringAttr("L2_cache_size_in_bytes"));
if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
return dyn_cast<IntegerAttr>(*cacheSize).getInt();
}
} else if (cacheLevel == 3) {
std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
Builder(ctx).getStringAttr("CPU" /* device ID*/),
Builder(ctx).getStringAttr("L3_cache_size_in_bytes"));
if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
return dyn_cast<IntegerAttr>(*cacheSize).getInt();
}
}
return 0;
}

// get the maximum vector length in bits
size_t getMaxVectorLength() {
std::optional<Attribute> maxVectorLength = layout.getDevicePropertyValue(
Builder(ctx).getStringAttr("CPU" /* device ID*/),
Builder(ctx).getStringAttr("max_vector_width"));
if (maxVectorLength && isa<IntegerAttr>(*maxVectorLength)) {
return dyn_cast<IntegerAttr>(*maxVectorLength).getInt();
}
return 512;
}

SystemDesc(ModuleOp m) : layout(m), ctx(m->getContext()) {}

private:
DataLayout layout;
MLIRContext *ctx;
};

using OpTileSizeMap = std::unordered_map<std::string, SmallVector<int64_t>>;

template <typename OpTy>
Expand Down Expand Up @@ -806,7 +751,8 @@ struct IterativeTilingAndFusion
// Get funcOp
func::FuncOp func = getOperation();
// Get system descriptor
SystemDesc sysDesc(func->getParentOfType<ModuleOp>());
CPUTargetDescriptionAnalysis sysDesc =
getAnalysis<CPUTargetDescriptionAnalysis>();
// Flexible options to control which candidate slice would be selected from
// the view of both validity and performance.
CandidateSliceOptions sliceOptions;
Expand Down
Loading