-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][xegpu] Change index arithmetic ops to arith ops.
#170390
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu Author: Charitha Saumya (charithaintc) Changes: Index ops cause some issues during SIMT distribution because they don't have the ElementwiseMappable trait. Patch is 32.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/170390.diff 8 Files Affected:
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index fb5d1e758dbd1..dfbe20e5087b1 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -61,7 +62,7 @@ genCoordinates(OpBuilder &builder, Location loc,
// Get the offset of `subShape` within a distribution unit.
SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
- return builder.createOrFold<index::MulOp>(
+ return builder.createOrFold<arith::MulIOp>(
loc, std::get<0>(t),
builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
});
@@ -84,7 +85,7 @@ genCoordinates(OpBuilder &builder, Location loc,
// Do not go beyond `srcShape` bounds.
SmallVector<Value> mods = llvm::map_to_vector(
llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
- return builder.createOrFold<index::RemUOp>(
+ return builder.createOrFold<arith::RemUIOp>(
loc, std::get<0>(t),
arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
});
@@ -343,7 +344,7 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
/// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within
/// this dimension)
result[dimIdx] =
- builder.createOrFold<index::RemUOp>(loc, remaining, dimSizeVal);
+ builder.createOrFold<arith::RemUIOp>(loc, remaining, dimSizeVal);
/// Update remaining for the next dimension by removing what we've already
/// processed. Division tells us "how many complete groups of this dimension
@@ -352,7 +353,7 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
/// no next dimension to process
if (i < order.size() - 1) {
remaining =
- builder.createOrFold<index::DivUOp>(loc, remaining, dimSizeVal);
+ builder.createOrFold<arith::DivUIOp>(loc, remaining, dimSizeVal);
}
}
return result;
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 91432b1c11304..2d5bfca7e4481 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
@@ -527,7 +528,7 @@ SmallVector<OpFoldResult> xegpu::addElementwise(OpBuilder &builder,
for (auto [l, r] : llvm::zip_equal(lhs, rhs)) {
auto lval = getValueOrCreateConstantIndexOp(builder, loc, l);
auto rval = getValueOrCreateConstantIndexOp(builder, loc, r);
- results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval));
+ results.push_back(builder.createOrFold<arith::AddIOp>(loc, lval, rval));
}
return results;
}
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 8fd3cca5594cb..22177f8f6a15f 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -271,11 +271,11 @@ gpu.module @xevm_module{
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
-// CHECK: %[[REMU1:.*]] = index.remu %[[LANE_ID]], %[[C8]]
-// CHECK: %[[DIVU:.*]] = index.divu %[[LANE_ID]], %[[C8]]
-// CHECK: %[[REMU2:.*]] = index.remu %[[DIVU]], %[[C2]]
-// CHECK: %[[REMU3:.*]] = index.remu %[[REMU2]], %[[C2]]
-// CHECK: %[[REMU4:.*]] = index.remu %[[REMU1]], %[[C8]]
+// CHECK: %[[REMU1:.*]] = arith.remui %[[LANE_ID]], %[[C8]]
+// CHECK: %[[DIVU:.*]] = arith.divui %[[LANE_ID]], %[[C8]]
+// CHECK: %[[REMU2:.*]] = arith.remui %[[DIVU]], %[[C2]]
+// CHECK: %[[REMU3:.*]] = arith.remui %[[REMU2]], %[[C2]]
+// CHECK: %[[REMU4:.*]] = arith.remui %[[REMU1]], %[[C8]]
// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[REMU3]], %[[REMU4]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[REMU3]], %[[REMU4]]] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
gpu.module @xevm_module{
@@ -294,13 +294,13 @@ gpu.module @xevm_module{
// CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
-// CHECK: %[[REMU1:.*]] = index.remu %[[LANE_ID]], %[[C4]]
-// CHECK: %[[DIVU:.*]] = index.divu %[[LANE_ID]], %[[C4]]
-// CHECK: %[[REMU2:.*]] = index.remu %[[DIVU]], %[[C4]]
-// CHECK: %[[MUL:.*]] = index.mul %[[REMU2]], %[[C2]]
-// CHECK: %[[REMU3:.*]] = index.remu %[[MUL]], %[[C8]]
-// CHECK: %[[REMU4:.*]] = index.remu %[[REMU1]], %[[C4]]
-// CHECK: %[[ADD:.*]] = index.add %[[REMU4]], %[[C1]]
+// CHECK: %[[REMU1:.*]] = arith.remui %[[LANE_ID]], %[[C4]]
+// CHECK: %[[DIVU:.*]] = arith.divui %[[LANE_ID]], %[[C4]]
+// CHECK: %[[REMU2:.*]] = arith.remui %[[DIVU]], %[[C4]]
+// CHECK: %[[MUL:.*]] = arith.muli %[[REMU2]], %[[C2]]
+// CHECK: %[[REMU3:.*]] = arith.remui %[[MUL]], %[[C8]]
+// CHECK: %[[REMU4:.*]] = arith.remui %[[REMU1]], %[[C4]]
+// CHECK: %[[ADD:.*]] = arith.addi %[[REMU4]], %[[C1]]
// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[REMU3]], %[[ADD]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32>
// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[REMU3]], %[[ADD]]] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
gpu.module @xevm_module{
diff --git a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
index 02c5f71d5c83d..8ce6d4dfd439e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
@@ -3,10 +3,10 @@
gpu.module @test {
gpu.func @slice_attr() -> vector<128xindex> {
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: %[[DIVU:.*]] = index.divu %[[SGID]], %[[C8:.*]]
- // CHECK-DAG: %[[REMU:.*]] = index.remu %[[DIVU]], %[[C4:.*]]
- // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C32:.*]]
- // CHECK-DAG: %[[MOD:.*]] = index.remu %[[MUL]], %[[C128:.*]]
+ // CHECK-DAG: %[[DIVU:.*]] = arith.divui %[[SGID]], %[[C8:.*]]
+ // CHECK-DAG: %[[REMU:.*]] = arith.remui %[[DIVU]], %[[C4:.*]]
+ // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C32:.*]]
+ // CHECK-DAG: %[[MOD:.*]] = arith.remui %[[MUL]], %[[C128:.*]]
// CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex>
// CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex>
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex>
@@ -16,11 +16,10 @@ gpu.module @test {
gpu.func @nested_slice_attr() -> vector<128xindex> {
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: %[[DIVU1:.*]] = index.divu %[[SGID]], %[[C1:.*]]
- // CHECK-DAG: %[[DIVU2:.*]] = index.divu %[[DIVU1]], %[[C8:.*]]
- // CHECK-DAG: %[[REMU:.*]] = index.remu %[[DIVU2]], %[[C4:.*]]
- // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C32:.*]]
- // CHECK-DAG: %[[MOD:.*]] = index.remu %[[MUL]], %[[C128:.*]]
+ // CHECK-DAG: %[[DIVU2:.*]] = arith.divui %[[SGID]], %[[C8:.*]]
+ // CHECK-DAG: %[[REMU:.*]] = arith.remui %[[DIVU2]], %[[C4:.*]]
+ // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C32:.*]]
+ // CHECK-DAG: %[[MOD:.*]] = arith.remui %[[MUL]], %[[C128:.*]]
// CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex>
// CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex>
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex>
@@ -29,4 +28,3 @@ gpu.module @test {
}
}
-
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 01134d8eaabec..4829af3612de3 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -16,18 +16,18 @@ gpu.module @test_round_robin_assignment {
gpu.func @create_nd_tdesc_with_shared_data(%src: memref<256x128xf32>) {
// CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
// CHECK: %[[C4:.*]] = arith.constant 4 : index
- // CHECK: %[[IDX:.*]] = index.remu %[[SGID]], %[[C4]]
- // CHECK: %[[IDY_DIV:.*]] = index.divu %[[SGID]], %[[C4]]
+ // CHECK: %[[IDX:.*]] = arith.remui %[[SGID]], %[[C4]]
+ // CHECK: %[[IDY_DIV:.*]] = arith.divui %[[SGID]], %[[C4]]
// CHECK: %[[C8:.*]] = arith.constant 8 : index
- // CHECK: %[[IDY:.*]] = index.remu %[[IDY_DIV]], %[[C8]]
+ // CHECK: %[[IDY:.*]] = arith.remui %[[IDY_DIV]], %[[C8]]
// CHECK: %[[C16:.*]] = arith.constant 16 : index
- // CHECK: %[[LY:.*]] = index.mul %[[IDY]], %[[C16]]
+ // CHECK: %[[LY:.*]] = arith.muli %[[IDY]], %[[C16]]
// CHECK: %[[C64:.*]] = arith.constant 64 : index
- // CHECK: %[[LX:.*]] = index.mul %[[IDX]], %[[C64]]
+ // CHECK: %[[LX:.*]] = arith.muli %[[IDX]], %[[C64]]
// CHECK: %[[C128:.*]] = arith.constant 128 : index
- // CHECK: %[[OFFY:.*]] = index.remu %[[LY]], %[[C128]]
+ // CHECK: %[[OFFY:.*]] = arith.remui %[[LY]], %[[C128]]
// CHECK: %[[C64_1:.*]] = arith.constant 64 : index
- // CHECK: %[[OFFX:.*]] = index.remu %[[LX]], %[[C64_1]]
+ // CHECK: %[[OFFX:.*]] = arith.remui %[[LX]], %[[C64_1]]
// CHECK: xegpu.create_nd_tdesc %[[ARG_0]][%[[OFFY]], %[[OFFX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 1cddccb5fbbd1..eae51a16053d8 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -90,30 +90,27 @@ gpu.module @test_distribution {
gpu.return
}
+ // CHECK-LABEL: non_splat_constant
gpu.func @non_splat_constant() {
- // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}}> : vector<2x1xindex>
+ // CHECK-DAG: %[[CST:.*]] = arith.constant dense<{{.*}}0{{.*}}, {{.*}}16{{.*}}> : vector<2x1xindex>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: %[[REMU1:.*]] = index.remu %[[SGID]], %[[C1:.*]]
- // CHECK-DAG: %[[DIVU:.*]] = index.divu %[[SGID]], %[[C1:.*]]
- // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[DIVU]], %[[C8:.*]]
- // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU2]], %[[C2:.*]]
- // CHECK-DAG: %[[REMU3:.*]] = index.remu %[[MUL]], %[[C32:.*]]
- // CHECK-DAG: %[[REMU4:.*]] = index.remu %[[REMU1]], %[[C1:.*]]
- // CHECK-DAG: %[[ADD16:.*]] = arith.addi %[[MUL]], %[[C16:.*]] : index
- // CHECK-DAG: %[[REMU5:.*]] = index.remu %[[ADD16]], %[[C32:.*]]
- // CHECK-DAG: %[[REMU6:.*]] = index.remu %[[REMU1]], %[[C1:.*]]
- // CHECK-DAG: %[[STRIDE1:.*]] = arith.muli %[[REMU3]], %[[C16:.*]] : index
- // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[C0:.*]], %[[STRIDE1]] : index
- // CHECK-DAG: %[[STRIDE2:.*]] = arith.muli %[[REMU4]], %[[C0:.*]] : index
- // CHECK-DAG: %[[ADDSTRIDES1:.*]] = arith.addi %[[ADDSTRIDES]], %[[STRIDE2]] : index
- // CHECK-DAG: %[[BCAST1:.*]] = vector.broadcast %[[ADDSTRIDES1]] : index to vector<2x1xindex>
- // CHECK-DAG: %[[RESULT1:.*]] = arith.addi %[[BASECST]], %[[BCAST1]] : vector<2x1xindex>
- // CHECK-DAG: %[[STRIDE3:.*]] = arith.muli %[[REMU5]], %[[C16:.*]] : index
- // CHECK-DAG: %[[ADDSTRIDES2:.*]] = arith.addi %[[C0:.*]], %[[STRIDE3]] : index
- // CHECK-DAG: %[[STRIDE4:.*]] = arith.muli %[[REMU6]], %[[C0:.*]] : index
- // CHECK-DAG: %[[ADDSTRIDES3:.*]] = arith.addi %[[ADDSTRIDES2]], %[[STRIDE4]] : index
- // CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[ADDSTRIDES3]] : index to vector<2x1xindex>
- // CHECK-DAG: %[[RESULT2:.*]] = arith.addi %[[BASECST]], %[[BCAST2]] : vector<2x1xindex>
+ // CHECK-DAG: %[[T1:.*]] = arith.remui %[[SGID]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[T2:.*]] = arith.muli %[[T1]], %[[C2:.*]] : index
+ // CHECK-DAG: %[[T3:.*]] = arith.remui %[[T2]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[T4:.*]] = arith.addi %[[T2]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[T5:.*]] = arith.remui %[[T4]], %[[C32_6:.*]] : index
+ // CHECK-DAG: %[[T6:.*]] = arith.muli %[[T3]], %[[C16_10:.*]] : index
+ // CHECK-DAG: %[[T7:.*]] = arith.addi %[[C0_11:.*]], %[[T6]] : index
+ // CHECK-DAG: %[[T8:.*]] = arith.muli %[[C0_4:.*]], %[[C0_9:.*]] : index
+ // CHECK-DAG: %[[T9:.*]] = arith.addi %[[T7]], %[[T8]] : index
+ // CHECK-DAG: %[[T10:.*]] = vector.broadcast %[[T9]] : index to vector<2x1xindex>
+ // CHECK-DAG: %[[T11:.*]] = arith.addi %[[CST]], %[[T10]] : vector<2x1xindex>
+ // CHECK-DAG: %[[T12:.*]] = arith.muli %[[T5]], %[[C16_10:.*]] : index
+ // CHECK-DAG: %[[T13:.*]] = arith.addi %[[C0_12:.*]], %[[T12]] : index
+ // CHECK-DAG: %[[T14:.*]] = arith.muli %[[C0_8:.*]], %[[C0_9:.*]] : index
+ // CHECK-DAG: %[[T15:.*]] = arith.addi %[[T13]], %[[T14]] : index
+ // CHECK-DAG: %[[T16:.*]] = vector.broadcast %[[T15]] : index to vector<2x1xindex>
+ // CHECK-DAG: %[[T17:.*]] = arith.addi %[[CST]], %[[T16]] : vector<2x1xindex>
%cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
gpu.return
}
@@ -139,4 +136,3 @@ gpu.module @test_distribution {
gpu.return
}
}
-
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 574b365443a0a..98920d61c4f58 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -27,17 +27,17 @@ gpu.module @test_distribution {
//CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
//CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
//CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
- //CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %[[C4]]
- //CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %[[C4]]
+ //CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]], %[[C4]]
+ //CHECK-DAG: %[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %[[C4]]
//CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
- //CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %[[C8]]
+ //CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %[[C8]]
//CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
- //CHECK-DAG: %[[L_OFF_Y:.*]] = index.mul %[[SGIDY]], %[[C32]]
- //CHECK-DAG: %[[L_OFF_X:.*]] = index.mul %[[SGIDX]], %[[C32]]
+ //CHECK-DAG: %[[L_OFF_Y:.*]] = arith.muli %[[SGIDY]], %[[C32]] : index
+ //CHECK-DAG: %[[L_OFF_X:.*]] = arith.muli %[[SGIDX]], %[[C32_1:.*]] : index
//CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index
- //CHECK-DAG: %[[OFF_Y:.*]] = index.remu %[[L_OFF_Y]], %[[C256]]
+ //CHECK-DAG: %[[OFF_Y:.*]] = arith.remui %[[L_OFF_Y]], %[[C256]] : index
//CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
- //CHECK-DAG: %[[OFF_X:.*]] = index.remu %[[L_OFF_X]], %[[C128]]
+ //CHECK-DAG: %[[OFF_X:.*]] = arith.remui %[[L_OFF_X]], %[[C128]] : index
//CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]][{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
%tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -293,7 +293,7 @@ gpu.module @test_distribution {
%val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<25.5> : vector<256xf16>
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<0> : vector<256xindex>
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<1> : vector<256xi1>
- xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
+ xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
layout_operand_2 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
layout_operand_3 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
l1_hint = #xegpu.cache_hint<cached>}
@@ -321,18 +321,18 @@ gpu.module @test_distribution {
//CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
//CHECK: [[sgid:%.+]] = gpu.subgroup_id : index
//CHECK: [[c4:%.+]] = arith.constant 4 : index
- //CHECK: [[sgidx:%.+]] = index.remu [[sgid]], [[c4]]
- //CHECK: [[sgidy_tmp:%.+]] = index.divu [[sgid]], [[c4]]
+ //CHECK: [[sgidx:%.+]] = arith.remui [[sgid]], [[c4]] : index
+ //CHECK: [[sgidy_tmp:%.+]] = arith.divui [[sgid]], [[c4]] : index
//CHECK: [[c2:%.+]] = arith.constant 2 : index
- //CHECK: [[sgidy:%.+]] = index.remu [[sgidy_tmp]], [[c2]]
+ //CHECK: [[sgidy:%.+]] = arith.remui [[sgidy_tmp]], [[c2]] : index
//CHECK: [[c32:%.+]] = arith.constant 32 : index
- //CHECK: [[l_off_y:%.+]] = index.mul [[sgidy]], [[c32]]
+ //CHECK: [[l_off_y:%.+]] = arith.muli [[sgidy]], [[c32]] : index
//CHECK: [[c32_0:%.+]] = arith.constant 32 : index
- //CHECK: [[l_off_x:%.+]] = index.mul [[sgidx]], [[c32_0]]
+ //CHECK: [[l_off_x:%.+]] = arith.muli [[sgidx]], [[c32_0]] : index
//CHECK: [[c64:%.+]] = arith.constant 64 : index
- //CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]]
+ //CHECK: [[off_y:%.+]] = arith.remui [[l_off_y]], [[c64]] : index
//CHECK: [[c128:%.+]] = arith.constant 128 : index
- //CHECK: [[off_x:%.+]] = index.remu [[l_off_x]], [[c128]]
+ //CHECK: [[off_x:%.+]] = arith.remui [[l_off_x]], [[c128]] : index
//CHECK: xegpu.load_matrix [[mdesc]][[[off_y]], [[off_x]]] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32>, index, index -> vector<32x32xf32>
%0 = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
%1 = xegpu.load_matrix %0[0, 0] <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32], lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32> -> vector<64x128xf32>
@@ -346,18 +346,18 @@ gpu.module @test_distribution {
//CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
//CHECK: [[sgid:%.+]] = gpu.subgroup_id : index
//CHECK: [[c4:%.+]] = arith.constant 4 : index
- //CHECK: [[sgidx:%.+]] = index.remu [[sgid]], [[c4]]
- //CHECK: [[sgidy_tmp:%.+]] = index.divu [[sgid]], [[c4]]
+ //CHECK: [[sgidx:%.+]] = arith.remui [[sgid]], [[c4]] : index
+ //CHECK: [[sgidy_tmp:%.+]] = arith.divui [[sgid]], [[c4]] : index
//CHECK: [[c2:%.+]] = arith.constant 2 : index
- //CHECK: [[sgidy:%.+]] = index.remu [[sgidy_tmp]], [[c2]]
+ //CHECK: [[sgidy:%.+]] = arith.remui [[sgidy_tmp]], [[c2]] : index
//CHECK: [[c32:%.+]] = arith.constant 32 : index
- //CHECK: [[l_off_y:%.+]] = index.mul [[sgidy]]...
[truncated]
|
nbpatel
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Index ops cause some issues during SIMT distribution because they don't have the
ElementwiseMappable trait. This PR replaces all index arithmetic ops with matching `arith` dialect ops.