Skip to content

[mlir] GPUToROCDL: Add support for non-i32/f32 shuffle types #136320

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 18, 2025

Conversation

Hardcode84
Copy link
Contributor

Use recently added repacking utilities to support other datatypes.

Also, tighten gpu.shuffle verification to reject scalable vectors.

Use recently added repacking utilities to support other datatypes.

Also, tighten `gpu.shuffle` verification to reject scalable vectors.
@llvmbot
Copy link
Member

llvmbot commented Apr 18, 2025

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

Use recently added repacking utilities to support other datatypes.

Also, tighten gpu.shuffle verification to reject scalable vectors.


Full diff: https://github.com/llvm/llvm-project/pull/136320.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/IR/GPUOps.td (+2-2)
  • (modified) mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp (+12-14)
  • (removed) mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir (-13)
  • (modified) mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir (+21)
  • (modified) mlir/test/Dialect/GPU/invalid.mlir (+11-3)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 3241aff4b683c..68095b7bf5c59 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -840,7 +840,7 @@ def GPU_LaunchOp : GPU_Op<"launch", [
     -   a variadic number of Private memory attributions.
 
     The `kernelFunc` and `kernelModule` attributes are optional and specifies
-    the kernel name and a module in which the kernel should be outlined. 
+    the kernel name and a module in which the kernel should be outlined.
 
     Syntax:
 
@@ -1201,7 +1201,7 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
 }
 
 def AnyIntegerOrFloatOr1DVector :
-  AnyTypeOf<[AnyIntegerOrFloat, VectorOfRankAndType<[1], [AnyIntegerOrFloat]>]>;
+  AnyTypeOf<[AnyIntegerOrFloat, FixedVectorOfRankAndType<[1], [AnyIntegerOrFloat]>]>;
 
 def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]> {
   let summary = "Reduce values among subgroup.";
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index c6c695b442b4f..34129989049d0 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -62,8 +62,8 @@ static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
   return canBeBare;
 }
 
-Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
-                const unsigned indexBitwidth) {
+static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
+                       const unsigned indexBitwidth) {
   auto int32Type = IntegerType::get(rewriter.getContext(), 32);
   Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
   Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
@@ -138,10 +138,6 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
     Location loc = op->getLoc();
     Value initShflValue = adaptor.getValue();
     Type shflType = initShflValue.getType();
-    // TODO: Add support for non 32-bit shuffle values.
-    if (!shflType.isIntOrFloat() || shflType.getIntOrFloatBitWidth() != 32)
-      return rewriter.notifyMatchFailure(
-          op, "only 32-bit int/float types are supported");
 
     const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
     Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);
@@ -179,15 +175,17 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
     Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
     Value dwordAlignedDstLane =
         rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
-    if (shflType.isF32()) {
-      initShflValue =
-          rewriter.create<LLVM::BitcastOp>(loc, int32Type, initShflValue);
-    }
-    Value shflValue = rewriter.create<ROCDL::DsBpermuteOp>(
-        loc, int32Type, dwordAlignedDstLane, initShflValue);
-    if (shflType.isF32()) {
-      shflValue = rewriter.create<LLVM::BitcastOp>(loc, shflType, shflValue);
+
+    SmallVector<Value> decomposed =
+        LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type);
+    SmallVector<Value> swizzled;
+    for (Value v : decomposed) {
+      Value res = rewriter.create<ROCDL::DsBpermuteOp>(loc, int32Type,
+                                                       dwordAlignedDstLane, v);
+      swizzled.emplace_back(res);
     }
+    Value shflValue =
+        LLVM::composeValue(rewriter, loc, swizzled, initShflValue.getType());
     rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
     return success();
   }
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir
deleted file mode 100644
index 90f2e5f047cd9..0000000000000
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir
+++ /dev/null
@@ -1,13 +0,0 @@
-// RUN: mlir-opt %s -convert-gpu-to-rocdl -verify-diagnostics
-
-gpu.module @test_module {
-  // ROCDL lowering only suport shuffles for 32bit ints/floats, but they
-  // shouldn't crash on unsupported types.
-  func.func @gpu_shuffle_unsupported(%arg0 : vector<4xf16>) -> vector<4xf16> {
-    %offset = arith.constant 4 : i32
-    %width = arith.constant 64 : i32
-    // expected-error @+1 {{failed to legalize operation 'gpu.shuffle'}}
-    %shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : vector<4xf16>
-    return %shfl : vector<4xf16>
-  }
-}
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index e23ab16ccd94b..071cae9d5789f 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -710,6 +710,27 @@ gpu.module @test_module {
     %shfld, %predd = gpu.shuffle down %arg0, %arg1, %arg2 : f32
     func.return %shfl, %shfli, %shfld : f32, f32, f32
   }
+
+  // CHECK-LABEL: func @gpu_shuffle_vec
+  //  CHECK-SAME: (%[[ARG:.*]]: vector<4xf16>, %{{.*}}: i32, %{{.*}}: i32)
+  func.func @gpu_shuffle_vec(%arg0: vector<4xf16>, %arg1: i32, %arg2: i32) -> vector<4xf16> {
+    // CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG]] : vector<4xf16> to vector<2xi32>
+    // CHECK: %[[IDX0:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[IDX0]] : i32] : vector<2xi32>
+    // CHECK: %[[IDX1:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[IDX1]] : i32] : vector<2xi32>
+    // CHECK: %[[PERM0:.*]] = rocdl.ds_bpermute %{{.*}}, %[[ELEM0]] : (i32, i32) -> i32
+    // CHECK: %[[PERM1:.*]] = rocdl.ds_bpermute %{{.*}}, %[[ELEM1]] : (i32, i32) -> i32
+    // CHECK: %[[V0:.*]] = llvm.mlir.poison : vector<2xi32>
+    // CHECK: %[[IDX0:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[V1:.*]] = llvm.insertelement %[[PERM0]], %[[V0]][%[[IDX0]] : i32] : vector<2xi32>
+    // CHECK: %[[IDX1:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK: %[[V2:.*]] = llvm.insertelement %[[PERM1]], %[[V1]][%[[IDX1]] : i32] : vector<2xi32>
+    // CHECK: %[[RES:.*]] = llvm.bitcast %[[V2]] : vector<2xi32> to vector<4xf16>
+    // CHECK: llvm.return %[[RES]] : vector<4xf16>
+    %shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : vector<4xf16>
+    func.return %shfl : vector<4xf16>
+  }
 }
 
 // -----
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 16148a493ce6e..ce1be7b5618fe 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -367,7 +367,7 @@ func.func @subgroup_reduce_cluster_stride_without_size(%arg0 : vector<4xf32>) {
 // -----
 
 func.func @subgroup_reduce_bad_type(%arg0 : vector<2x2xf32>) {
-  // expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float or vector of}}
+  // expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float or fixed-length vector of}}
   %res = gpu.subgroup_reduce add %arg0 : (vector<2x2xf32>) -> vector<2x2xf32>
   return
 }
@@ -375,7 +375,7 @@ func.func @subgroup_reduce_bad_type(%arg0 : vector<2x2xf32>) {
 // -----
 
 func.func @subgroup_reduce_bad_type_scalable(%arg0 : vector<[2]xf32>) {
-  // expected-error@+1 {{is not compatible with scalable vector types}}
+  // expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float or fixed-length vector of}}
   %res = gpu.subgroup_reduce add %arg0 : (vector<[2]xf32>) -> vector<[2]xf32>
   return
 }
@@ -463,13 +463,21 @@ func.func @shuffle_mismatching_type(%arg0 : f32, %arg1 : i32, %arg2 : i32) {
 // -----
 
 func.func @shuffle_unsupported_type(%arg0 : index, %arg1 : i32, %arg2 : i32) {
-  // expected-error@+1 {{op operand #0 must be Integer or Float or vector of Integer or Float values of ranks 1, but got 'index'}}
+  // expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'index'}}
   %shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : index
   return
 }
 
 // -----
 
+func.func @shuffle_unsupported_type_vec(%arg0 : vector<[4]xf32>, %arg1 : i32, %arg2 : i32) {
+  // expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'vector<[4]xf32>'}}
+  %shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : vector<[4]xf32>
+  return
+}
+
+// -----
+
 module {
   gpu.module @gpu_funcs {
     // expected-error @+1 {{custom op 'gpu.func' gpu.func requires named arguments}}

@llvmbot
Copy link
Member

llvmbot commented Apr 18, 2025

@llvm/pr-subscribers-mlir-gpu

Author: Ivan Butygin (Hardcode84)

Changes

Use recently added repacking utilities to support other datatypes.

Also, tighten gpu.shuffle verification to reject scalable vectors.


Full diff: https://github.com/llvm/llvm-project/pull/136320.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/IR/GPUOps.td (+2-2)
  • (modified) mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp (+12-14)
  • (removed) mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir (-13)
  • (modified) mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir (+21)
  • (modified) mlir/test/Dialect/GPU/invalid.mlir (+11-3)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 3241aff4b683c..68095b7bf5c59 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -840,7 +840,7 @@ def GPU_LaunchOp : GPU_Op<"launch", [
     -   a variadic number of Private memory attributions.
 
     The `kernelFunc` and `kernelModule` attributes are optional and specifies
-    the kernel name and a module in which the kernel should be outlined. 
+    the kernel name and a module in which the kernel should be outlined.
 
     Syntax:
 
@@ -1201,7 +1201,7 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
 }
 
 def AnyIntegerOrFloatOr1DVector :
-  AnyTypeOf<[AnyIntegerOrFloat, VectorOfRankAndType<[1], [AnyIntegerOrFloat]>]>;
+  AnyTypeOf<[AnyIntegerOrFloat, FixedVectorOfRankAndType<[1], [AnyIntegerOrFloat]>]>;
 
 def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]> {
   let summary = "Reduce values among subgroup.";
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index c6c695b442b4f..34129989049d0 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -62,8 +62,8 @@ static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
   return canBeBare;
 }
 
-Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
-                const unsigned indexBitwidth) {
+static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
+                       const unsigned indexBitwidth) {
   auto int32Type = IntegerType::get(rewriter.getContext(), 32);
   Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
   Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
@@ -138,10 +138,6 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
     Location loc = op->getLoc();
     Value initShflValue = adaptor.getValue();
     Type shflType = initShflValue.getType();
-    // TODO: Add support for non 32-bit shuffle values.
-    if (!shflType.isIntOrFloat() || shflType.getIntOrFloatBitWidth() != 32)
-      return rewriter.notifyMatchFailure(
-          op, "only 32-bit int/float types are supported");
 
     const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
     Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);
@@ -179,15 +175,17 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
     Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
     Value dwordAlignedDstLane =
         rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
-    if (shflType.isF32()) {
-      initShflValue =
-          rewriter.create<LLVM::BitcastOp>(loc, int32Type, initShflValue);
-    }
-    Value shflValue = rewriter.create<ROCDL::DsBpermuteOp>(
-        loc, int32Type, dwordAlignedDstLane, initShflValue);
-    if (shflType.isF32()) {
-      shflValue = rewriter.create<LLVM::BitcastOp>(loc, shflType, shflValue);
+
+    SmallVector<Value> decomposed =
+        LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type);
+    SmallVector<Value> swizzled;
+    for (Value v : decomposed) {
+      Value res = rewriter.create<ROCDL::DsBpermuteOp>(loc, int32Type,
+                                                       dwordAlignedDstLane, v);
+      swizzled.emplace_back(res);
     }
+    Value shflValue =
+        LLVM::composeValue(rewriter, loc, swizzled, initShflValue.getType());
     rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
     return success();
   }
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir
deleted file mode 100644
index 90f2e5f047cd9..0000000000000
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir
+++ /dev/null
@@ -1,13 +0,0 @@
-// RUN: mlir-opt %s -convert-gpu-to-rocdl -verify-diagnostics
-
-gpu.module @test_module {
-  // ROCDL lowering only suport shuffles for 32bit ints/floats, but they
-  // shouldn't crash on unsupported types.
-  func.func @gpu_shuffle_unsupported(%arg0 : vector<4xf16>) -> vector<4xf16> {
-    %offset = arith.constant 4 : i32
-    %width = arith.constant 64 : i32
-    // expected-error @+1 {{failed to legalize operation 'gpu.shuffle'}}
-    %shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : vector<4xf16>
-    return %shfl : vector<4xf16>
-  }
-}
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index e23ab16ccd94b..071cae9d5789f 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -710,6 +710,27 @@ gpu.module @test_module {
     %shfld, %predd = gpu.shuffle down %arg0, %arg1, %arg2 : f32
     func.return %shfl, %shfli, %shfld : f32, f32, f32
   }
+
+  // CHECK-LABEL: func @gpu_shuffle_vec
+  //  CHECK-SAME: (%[[ARG:.*]]: vector<4xf16>, %{{.*}}: i32, %{{.*}}: i32)
+  func.func @gpu_shuffle_vec(%arg0: vector<4xf16>, %arg1: i32, %arg2: i32) -> vector<4xf16> {
+    // CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG]] : vector<4xf16> to vector<2xi32>
+    // CHECK: %[[IDX0:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[IDX0]] : i32] : vector<2xi32>
+    // CHECK: %[[IDX1:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[IDX1]] : i32] : vector<2xi32>
+    // CHECK: %[[PERM0:.*]] = rocdl.ds_bpermute %{{.*}}, %[[ELEM0]] : (i32, i32) -> i32
+    // CHECK: %[[PERM1:.*]] = rocdl.ds_bpermute %{{.*}}, %[[ELEM1]] : (i32, i32) -> i32
+    // CHECK: %[[V0:.*]] = llvm.mlir.poison : vector<2xi32>
+    // CHECK: %[[IDX0:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[V1:.*]] = llvm.insertelement %[[PERM0]], %[[V0]][%[[IDX0]] : i32] : vector<2xi32>
+    // CHECK: %[[IDX1:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK: %[[V2:.*]] = llvm.insertelement %[[PERM1]], %[[V1]][%[[IDX1]] : i32] : vector<2xi32>
+    // CHECK: %[[RES:.*]] = llvm.bitcast %[[V2]] : vector<2xi32> to vector<4xf16>
+    // CHECK: llvm.return %[[RES]] : vector<4xf16>
+    %shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : vector<4xf16>
+    func.return %shfl : vector<4xf16>
+  }
 }
 
 // -----
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 16148a493ce6e..ce1be7b5618fe 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -367,7 +367,7 @@ func.func @subgroup_reduce_cluster_stride_without_size(%arg0 : vector<4xf32>) {
 // -----
 
 func.func @subgroup_reduce_bad_type(%arg0 : vector<2x2xf32>) {
-  // expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float or vector of}}
+  // expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float or fixed-length vector of}}
   %res = gpu.subgroup_reduce add %arg0 : (vector<2x2xf32>) -> vector<2x2xf32>
   return
 }
@@ -375,7 +375,7 @@ func.func @subgroup_reduce_bad_type(%arg0 : vector<2x2xf32>) {
 // -----
 
 func.func @subgroup_reduce_bad_type_scalable(%arg0 : vector<[2]xf32>) {
-  // expected-error@+1 {{is not compatible with scalable vector types}}
+  // expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float or fixed-length vector of}}
   %res = gpu.subgroup_reduce add %arg0 : (vector<[2]xf32>) -> vector<[2]xf32>
   return
 }
@@ -463,13 +463,21 @@ func.func @shuffle_mismatching_type(%arg0 : f32, %arg1 : i32, %arg2 : i32) {
 // -----
 
 func.func @shuffle_unsupported_type(%arg0 : index, %arg1 : i32, %arg2 : i32) {
-  // expected-error@+1 {{op operand #0 must be Integer or Float or vector of Integer or Float values of ranks 1, but got 'index'}}
+  // expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'index'}}
   %shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : index
   return
 }
 
 // -----
 
+func.func @shuffle_unsupported_type_vec(%arg0 : vector<[4]xf32>, %arg1 : i32, %arg2 : i32) {
+  // expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'vector<[4]xf32>'}}
+  %shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : vector<[4]xf32>
+  return
+}
+
+// -----
+
 module {
   gpu.module @gpu_funcs {
     // expected-error @+1 {{custom op 'gpu.func' gpu.func requires named arguments}}

Copy link
Member

@raikonenfnu raikonenfnu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice one! :)

@Hardcode84 Hardcode84 merged commit 46d1cb8 into llvm:main Apr 18, 2025
14 checks passed
@Hardcode84 Hardcode84 deleted the shuffle_repack branch April 18, 2025 16:53
@kazutakahirata
Copy link
Contributor

@Hardcode84 I've landed 4c17a5c to fix a warning from this patch. Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants