Skip to content

Commit 748b698

Browse files
author
Longsheng Du
authored
[GPU] Add MLP test and linalg.fill lowering in 'linalg-to-xegpu' (#220)
1 parent 3c668aa commit 748b698

File tree

6 files changed

+158
-4
lines changed

6 files changed

+158
-4
lines changed

.github/workflows/build-llvm.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
3333
- uses: actions/checkout@v4
3434
with:
35-
repository: Menooker/mlir-extensions
35+
repository: intel/mlir-extensions
3636
ref: ${{ env.IMEX_HASH }}
3737
path: mlir-extensions
3838
if: ${{ matrix.build-type == 'IMEX' }}

cmake/imex-version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
ee459724294e165e360e1de72ad3b217eb9b6206
1+
6c2e414a953b9a118bce6adac21cf9d42630e674

cmake/imex.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ if (NOT DEFINED IMEX_INCLUDES)
1414

1515
# TODO: Change to main https://github.com/intel/mlir-extensions when all the
1616
# required functionality is merged.
17-
gc_fetch_content(imex "${IMEX_HASH}" https://github.com/Menooker/mlir-extensions
17+
gc_fetch_content(imex "${IMEX_HASH}" https://github.com/intel/mlir-extensions
1818
SET IMEX_CHECK_LLVM_VERSION=ON IMEX_ENABLE_L0_RUNTIME=0
1919
)
2020

lib/gc/Transforms/GPU/LinalgToXeGPU.cpp

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,6 +1394,92 @@ struct ConvertNamedEltwiseToXeGPU : public OpRewritePattern<LinalgOpTy> {
13941394
LinalgToXeGPUOptions options;
13951395
};
13961396

1397+
// Create XeGPU kernel out of memory fill operation.
1398+
LogicalResult createMemoryFillKernel(linalg::LinalgOp linalgOp,
1399+
PatternRewriter &rewriter) {
1400+
Location loc = linalgOp.getLoc();
1401+
auto ctx = linalgOp.getContext();
1402+
1403+
auto scalar = linalgOp.getDpsInputs()[0];
1404+
auto output = linalgOp.getDpsInits()[0];
1405+
auto outputType = cast<ShapedType>(output.getType());
1406+
auto outputShape = outputType.getShape();
1407+
1408+
// Extract SIMD sized sub-tiles
1409+
int maxSizeSIMD = 256;
1410+
int64_t subTileCols = outputShape[1];
1411+
int64_t subTileRows = std::min(outputShape[0], maxSizeSIMD / subTileCols);
1412+
1413+
// Output descriptors for later stores.
1414+
SmallVector<Value> outputTiles = createDescriptorTiles(
1415+
rewriter, loc, output, outputShape, {0, 0}, {subTileRows, subTileCols});
1416+
1417+
SmallVector<Value> results;
1418+
for (size_t i = 0; i < outputTiles.size(); i++) {
1419+
// Operands are sub-tiles at the same location.
1420+
auto flatType = VectorType::get({subTileRows * subTileCols},
1421+
outputType.getElementType());
1422+
auto tileType = VectorType::get({subTileRows, subTileCols},
1423+
outputType.getElementType());
1424+
Value vec = rewriter.create<vector::BroadcastOp>(loc, flatType, scalar);
1425+
Value res = rewriter.create<vector::ShapeCastOp>(loc, tileType, vec);
1426+
1427+
if (!res)
1428+
return failure();
1429+
1430+
results.push_back(res);
1431+
}
1432+
1433+
// Store results.
1434+
auto writeCacheHint =
1435+
xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::WRITE_BACK);
1436+
for (size_t i = 0; i < outputTiles.size(); i++) {
1437+
rewriter.create<xegpu::StoreNdOp>(loc, results[i], outputTiles[i],
1438+
/*l1_hint=*/writeCacheHint,
1439+
/*l2_hint=*/writeCacheHint,
1440+
/*l3_hint=*/writeCacheHint);
1441+
}
1442+
1443+
rewriter.eraseOp(linalgOp);
1444+
1445+
return success();
1446+
}
1447+
1448+
// Convert a named fill operation to an XeGPU kernel.
1449+
template <typename LinalgOpTy>
1450+
struct ConvertMemoryFillToXeGPU : public OpRewritePattern<LinalgOpTy> {
1451+
using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
1452+
1453+
ConvertMemoryFillToXeGPU(MLIRContext *ctx, LinalgToXeGPUOptions options)
1454+
: OpRewritePattern<LinalgOpTy>(ctx), options(options) {}
1455+
1456+
LogicalResult matchAndRewrite(LinalgOpTy linalgOp,
1457+
PatternRewriter &rewriter) const override {
1458+
if (!linalgOp.hasPureBufferSemantics()) {
1459+
return rewriter.notifyMatchFailure(
1460+
linalgOp, "Linalg eltwise to GPU expects memref type");
1461+
}
1462+
if (linalgOp.hasDynamicShape()) {
1463+
return rewriter.notifyMatchFailure(
1464+
linalgOp, "Expect static shape when mapping to GPU");
1465+
}
1466+
auto isInputValid =
1467+
success(linalgOp.isScalar(linalgOp.getDpsInputOperand(0)));
1468+
if (failed(isInputValid))
1469+
return isInputValid;
1470+
1471+
auto isOutputValid =
1472+
isValidMemrefOperand(linalgOp, linalgOp.getDpsInits()[0], rewriter);
1473+
if (failed(isOutputValid))
1474+
return isOutputValid;
1475+
1476+
return createMemoryFillKernel(linalgOp, rewriter);
1477+
}
1478+
1479+
private:
1480+
LinalgToXeGPUOptions options;
1481+
};
1482+
13971483
// TODO: Finalize BRGEMM support and register the pattern.
13981484
void populateLinalgGemmToXeGPUPatterns(RewritePatternSet &patterns,
13991485
LinalgToXeGPUOptions options) {
@@ -1418,6 +1504,12 @@ void populateLinalgEltwiseToXeGPUPatterns(RewritePatternSet &patterns,
14181504
options);
14191505
}
14201506

1507+
// Register the XeGPU lowering pattern for linalg fill operations.
void populateLinalgMemoryFillToXeGPUPatterns(RewritePatternSet &patterns,
                                             LinalgToXeGPUOptions options) {
  auto *ctx = patterns.getContext();
  patterns.add<ConvertMemoryFillToXeGPU<linalg::FillOp>>(ctx, options);
}
1512+
14211513
struct LinalgToXeGPU : public gc::impl::LinalgToXeGPUBase<LinalgToXeGPU> {
14221514
using LinalgToXeGPUBase::LinalgToXeGPUBase;
14231515

@@ -1429,6 +1521,11 @@ struct LinalgToXeGPU : public gc::impl::LinalgToXeGPUBase<LinalgToXeGPU> {
14291521
populateLinalgGemmToXeGPUPatterns(gemmPatterns, options);
14301522
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(gemmPatterns));
14311523

1524+
// Convert memory fill ops.
1525+
RewritePatternSet fillPatterns(&getContext());
1526+
populateLinalgMemoryFillToXeGPUPatterns(fillPatterns, options);
1527+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(fillPatterns));
1528+
14321529
// Convert other remaining ops.
14331530
RewritePatternSet patterns(&getContext());
14341531
populateLinalgEltwiseToXeGPUPatterns(patterns, options);

scripts/compile.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ build_llvm() {
120120
local mlir_ext_dir="$EXTERNALS_DIR/mlir-extensions"
121121
if ! [ -d "$mlir_ext_dir" ]; then
122122
cd "$EXTERNALS_DIR"
123-
git clone https://github.com/Menooker/mlir-extensions.git
123+
git clone https://github.com/intel/mlir-extensions.git
124124
cd "$mlir_ext_dir"
125125
else
126126
cd "$mlir_ext_dir"
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// RUN: gc-opt %s --pass-pipeline='builtin.module(func.func(iterative-tiling-and-fusion{use-cost-model=0 default-tile-size=matmul:{16,16}}),eliminate-empty-tensors,empty-tensor-to-alloc-tensor,one-shot-bufferize{bufferize-function-boundaries=1 function-boundary-type-conversion=identity-layout-map},drop-equivalent-buffer-results,func.func(finalizing-bufferize),canonicalize,cse,drop-equivalent-buffer-results,expand-realloc,canonicalize,ownership-based-buffer-deallocation,canonicalize,buffer-deallocation-simplification,bufferization-lower-deallocations,cse,canonicalize,convert-bufferization-to-memref,func.func(scf-forall-to-parallel),func.func(linalg-to-xegpu{stages=1 dpas-tile=8,16,16 k-tile=16}),xegpu-fold-alias-ops,func.func(convert-linalg-to-parallel-loops),func.func(gpu-map-parallel-loops),func.func(convert-parallel-loops-to-gpu),func.func(insert-gpu-allocs),gpu-kernel-outlining,canonicalize,set-spirv-capabilities{client-api=opencl},gpu.module(set-spirv-abi-attrs{client-api=opencl}),lower-affine,imex-vector-linearize,gpu.module(convert-xegpu-to-vc),reconcile-unrealized-casts,bf16-to-gpu,gpu.module(convert-func-to-spirv),gpu.module(convert-vector-to-spirv),imex-convert-gpu-to-spirv,spirv.module(spirv-lower-abi-attrs,spirv-update-vce),func.func(llvm-request-c-wrappers),serialize-spirv,convert-vector-to-scf,convert-gpu-to-gpux,convert-scf-to-cf,convert-cf-to-llvm,convert-vector-to-llvm,convert-index-to-llvm,convert-arith-to-llvm,convert-func-to-llvm,convert-math-to-llvm,convert-gpux-to-llvm,convert-index-to-llvm,expand-strided-metadata,lower-affine,finalize-memref-to-llvm,reconcile-unrealized-casts)' \
// RUN: | gc-cpu-runner -e main --entry-point-result=void \
// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime | FileCheck %s

module {
  // Two-layer MLP: each layer zero-initializes an accumulator with
  // linalg.fill (exercising the new fill lowering), runs a matmul,
  // adds a bias tensor, and applies ReLU via linalg.max against zero.
  func.func @linalg_mlp(%arg0: tensor<32x4096xf16>, %arg1: tensor<4096x4096xf16>, %arg2 : tensor<32x4096xf16>,
                        %arg3: tensor<4096x4096xf16>, %arg4 : tensor<32x4096xf16>) {
    %cst = arith.constant 0.000000e+00 : f16
    // Layer 1: fill-initialized matmul accumulator, bias add, ReLU.
    %0 = tensor.empty() : tensor<32x4096xf16>
    %1 = linalg.fill ins(%cst : f16) outs(%0 : tensor<32x4096xf16>) -> tensor<32x4096xf16>
    %2 = linalg.matmul ins(%arg0, %arg1 : tensor<32x4096xf16>, tensor<4096x4096xf16>)
                       outs(%1 : tensor<32x4096xf16>) -> (tensor<32x4096xf16>)
    %3 = tensor.empty() : tensor<32x4096xf16>
    %4 = linalg.add ins(%arg2, %2 : tensor<32x4096xf16>, tensor<32x4096xf16>)
                    outs(%3 : tensor<32x4096xf16>) -> tensor<32x4096xf16>
    %5 = arith.constant dense<0.000000e+00> : tensor<32x4096xf16>
    %6 = tensor.empty() : tensor<32x4096xf16>
    %7 = linalg.max ins(%5, %4 : tensor<32x4096xf16>, tensor<32x4096xf16>)
                    outs(%6 : tensor<32x4096xf16>) -> tensor<32x4096xf16>

    // Layer 2: same structure with the second weight/bias pair.
    %8 = tensor.empty() : tensor<32x4096xf16>
    %9 = linalg.fill ins(%cst : f16) outs(%8 : tensor<32x4096xf16>) -> tensor<32x4096xf16>
    %10 = linalg.matmul ins(%7, %arg3 : tensor<32x4096xf16>, tensor<4096x4096xf16>)
                        outs(%9 : tensor<32x4096xf16>) -> (tensor<32x4096xf16>)
    %11 = tensor.empty() : tensor<32x4096xf16>
    %12 = linalg.add ins(%arg4, %10 : tensor<32x4096xf16>, tensor<32x4096xf16>)
                     outs(%11 : tensor<32x4096xf16>) -> tensor<32x4096xf16>
    %13 = arith.constant dense<0.000000e+00> : tensor<32x4096xf16>
    %14 = tensor.empty() : tensor<32x4096xf16>
    %15 = linalg.max ins(%13, %12 : tensor<32x4096xf16>, tensor<32x4096xf16>)
                     outs(%14 : tensor<32x4096xf16>) -> tensor<32x4096xf16>

    // Print the first output column so FileCheck can verify the values.
    %slice = tensor.extract_slice %15[0, 0][32, 1][1, 1] : tensor<32x4096xf16> to tensor<32xf16>
    %cast = tensor.cast %slice : tensor<32xf16> to tensor<*xf16>
    call @printMemrefF16(%cast) : (tensor<*xf16>) -> ()

    return
  }

  // Entry point: feeds constant-splat inputs so the expected output is a
  // single known value everywhere.
  func.func @main() {
    %0 = arith.constant dense<0.01> : tensor<32x4096xf16>
    %1 = arith.constant dense<0.01> : tensor<4096x4096xf16>
    %2 = arith.constant dense<0.02> : tensor<32x4096xf16>
    %3 = arith.constant dense<0.01> : tensor<4096x4096xf16>
    %4 = arith.constant dense<0.02> : tensor<32x4096xf16>

    func.call @linalg_mlp(%0, %1, %2, %3, %4) : (tensor<32x4096xf16>, tensor<4096x4096xf16>, tensor<32x4096xf16>,
                                                 tensor<4096x4096xf16>, tensor<32x4096xf16>) -> ()
    return
  }

  func.func private @printMemrefF16(%ptr : tensor<*xf16>) attributes { llvm.emit_c_interface }
}

// CHECK: Unranked Memref base@{{(0x)?[-0-9a-fA-F]*}}
// CHECK-SAME: rank = 1 offset = 0 sizes = [32] strides = [4096] data =
// CHECK-NEXT: [17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625, 17.625]

0 commit comments

Comments
 (0)