Skip to content

[mlir][gpu] Add gpu.rotate operation #142796

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

Hsiangkai
Copy link
Contributor

Add gpu.rotate operation and a pattern to convert gpu.rotate to SPIR-V OpGroupNonUniformRotateKHR.

Add gpu.rotate operation and a pattern to convert gpu.rotate to SPIR-V
OpGroupNonUniformRotateKHR.
@llvmbot
Copy link
Member

llvmbot commented Jun 4, 2025

@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Hsiangkai Wang (Hsiangkai)

Changes

Add gpu.rotate operation and a pattern to convert gpu.rotate to SPIR-V OpGroupNonUniformRotateKHR.


Full diff: https://github.com/llvm/llvm-project/pull/142796.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/IR/GPUOps.td (+43)
  • (modified) mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp (+36-1)
  • (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+45)
  • (added) mlir/test/Conversion/GPUToSPIRV/rotate.mlir (+26)
  • (modified) mlir/test/Dialect/GPU/invalid.mlir (+78)
  • (modified) mlir/test/Dialect/GPU/ops.mlir (+4)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 15b14c767b66a..46bd6039657bd 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1364,6 +1364,49 @@ def GPU_ShuffleOp : GPU_Op<
   ];
 }
 
+def GPU_RotateOp : GPU_Op<
+    "rotate", [Pure, AllTypesMatch<["value", "rotateResult"]>]>,
+    Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, I32:$offset, I32:$width)>,
+    Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult)> {
+  let summary = "Rotate values within a subgroup.";
+  let description = [{
+    The "rotate" op moves values across lanes (a.k.a., invocations, work items)
+    within the same subgroup. The `width` argument specifies the number of lanes
+    that participate in the rotation, and must be uniform across all lanes.
+    Further, the first `width` lanes of the subgroup must be active.
+
+    `width` must be a power of two, and `offset` must be in the range
+    `[0, width)`. Both operands must be defined by `arith.constant` ops; the
+    verifier rejects non-constant operands.
+
+    `rotateResult` is the `value` of the lane whose id within the subgroup is
+    calculated as follows, where `LaneId` is the id of the current lane:
+
+    ```
+    SourceLaneId = ((LaneId + offset) & (width - 1)) + (LaneId & ~(width - 1))
+    ```
+
+    Returns the `rotateResult` if the current lane id is smaller than `width`.
+
+    Example:
+
+    ```mlir
+    %cst1 = arith.constant 1 : i32
+    %1 = gpu.rotate %0, %cst1, %width : f32
+    ```
+
+    For lane `k` with `k < width`, returns the value from lane
+    `(k + cst1) % width`.
+  }];
+
+  let assemblyFormat = [{
+    $value `,` $offset `,` $width attr-dict `:` type($value)
+  }];
+
+  let builders = [
+    // Helper function that creates a rotate with constant offset/width.
+    OpBuilder<(ins "Value":$value, "int32_t":$offset, "int32_t":$width)>
+  ];
+
+  let hasVerifier = 1;
+}
+
 def GPU_BarrierOp : GPU_Op<"barrier"> {
   let summary = "Synchronizes all work items of a workgroup.";
   let description = [{
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 78e6ebb523a46..546705244f35c 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -122,6 +122,16 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Pattern to convert a gpu.rotate op into a spirv.GroupNonUniformRotateKHROp.
+/// The rewrite (defined out of line) operates at subgroup scope and fails to
+/// match unless the rotate width is a constant that does not exceed the
+/// subgroup size given by the target environment's resource limits.
+class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::RotateOp rotateOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -458,6 +468,31 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Rotate
+//===----------------------------------------------------------------------===//
+
+LogicalResult GPURotateConversion::matchAndRewrite(
+    gpu::RotateOp rotateOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // The lowering needs a statically known width to compare against the
+  // target's subgroup size. The two failure modes are reported separately so
+  // that the match-failure message names the actual cause.
+  IntegerAttr widthAttr;
+  if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)))
+    return rewriter.notifyMatchFailure(rotateOp,
+                                       "rotate width is not a constant");
+
+  // Bind by reference so the target environment is not copied.
+  const auto &targetEnv =
+      getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
+  unsigned subgroupSize =
+      targetEnv.getAttr().getResourceLimits().getSubgroupSize();
+  if (widthAttr.getValue().getZExtValue() > subgroupSize)
+    return rewriter.notifyMatchFailure(
+        rotateOp, "rotate width is larger than target subgroup size");
+
+  // Rotate at subgroup scope, using the type-converted operands.
+  Location loc = rotateOp.getLoc();
+  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
+  Value result = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
+      loc, scope, adaptor.getValue(), adaptor.getOffset(), adaptor.getWidth());
+
+  rewriter.replaceOp(rotateOp, result);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Group ops
 //===----------------------------------------------------------------------===//
@@ -733,7 +768,7 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                       RewritePatternSet &patterns) {
   patterns.add<
       GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
-      GPUReturnOpConversion, GPUShuffleConversion,
+      GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
       LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
       LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 39f626b558294..a9a9473a1c333 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1331,6 +1331,51 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
         mode);
 }
 
+//===----------------------------------------------------------------------===//
+// RotateOp
+//===----------------------------------------------------------------------===//
+
+void RotateOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     int32_t offset, int32_t width) {
+  // Materialize the i32 constants in separate statements so that they are
+  // inserted in a fixed order (offset, then width); the evaluation order of
+  // call arguments is unspecified in C++.
+  Location loc = result.location;
+  Value offsetCst =
+      builder.create<arith::ConstantOp>(loc, builder.getI32IntegerAttr(offset));
+  Value widthCst =
+      builder.create<arith::ConstantOp>(loc, builder.getI32IntegerAttr(width));
+  build(builder, result, value, offsetCst, widthCst);
+}
+
+LogicalResult RotateOp::verify() {
+  // Extracts the integer value of `operand` when it is produced by an
+  // arith.constant holding an integer attribute. Otherwise emits an op error
+  // that starts with `name` and returns failure.
+  auto getConstantInt = [&](Value operand, StringRef name,
+                            llvm::APInt &result) -> LogicalResult {
+    auto constOp = operand.getDefiningOp<arith::ConstantOp>();
+    if (!constOp)
+      return emitOpError() << name << " is not a constant value";
+    auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue());
+    if (!intAttr)
+      return emitOpError() << name << " is not an integer value";
+    result = intAttr.getValue();
+    return success();
+  };
+
+  // The offset is checked before the width, so a diagnostic about the offset
+  // operand takes precedence when both operands are malformed.
+  llvm::APInt offsetValue;
+  llvm::APInt widthValue;
+  if (failed(getConstantInt(getOffset(), "offset", offsetValue)) ||
+      failed(getConstantInt(getWidth(), "width", widthValue)))
+    return failure();
+
+  // APInt::isPowerOf2 interprets the bits as unsigned, so a width with only
+  // the sign bit set would pass; reject non-positive widths explicitly.
+  if (!widthValue.isStrictlyPositive() || !widthValue.isPowerOf2())
+    return emitOpError() << "width must be a power of two";
+
+  if (offsetValue.sge(widthValue) || offsetValue.slt(0))
+    return emitOpError() << "offset must be in the range [0, width)";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // BarrierOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
new file mode 100644
index 0000000000000..e0dd14d87d42f
--- /dev/null
+++ b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+
+// Verifies that gpu.rotate with constant offset and width is lowered to
+// spirv.GroupNonUniformRotateKHR at Subgroup scope. The width used below (16)
+// equals the subgroup_size resource limit declared in the target environment.
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
+    #spirv.resource_limits<subgroup_size = 16>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @rotate()
+  gpu.func @rotate() kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    %offset = arith.constant 4 : i32
+    %width = arith.constant 16 : i32
+    %val = arith.constant 42.0 : f32
+
+    // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
+    // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
+    // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
+    // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
+    %result = gpu.rotate %val, %offset, %width : f32
+    gpu.return
+  }
+}
+
+}
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index ce1be7b5618fe..0ad5690f5cf70 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -478,6 +478,84 @@ func.func @shuffle_unsupported_type_vec(%arg0 : vector<[4]xf32>, %arg1 : i32, %a
 
 // -----
 
+// The result type (i32) differs from the type of the rotated value (f32).
+func.func @rotate_mismatching_type(%arg0 : f32) {
+  %offset = arith.constant 4 : i32
+  %width = arith.constant 16 : i32
+  // expected-error@+1 {{op failed to verify that all of {value, rotateResult} have same type}}
+  %shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> i32
+  return
+}
+
+// -----
+
+// `index` is not an accepted type for the rotated value.
+func.func @rotate_unsupported_type(%arg0 : index) {
+  %offset = arith.constant 4 : i32
+  %width = arith.constant 16 : i32
+  // expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'index'}}
+  %shfl = gpu.rotate %arg0, %offset, %width : index
+  return
+}
+
+// -----
+
+// A scalable vector is rejected; the diagnostic asks for a fixed-length one.
+func.func @rotate_unsupported_type_vec(%arg0 : vector<[4]xf32>) {
+  %offset = arith.constant 4 : i32
+  %width = arith.constant 16 : i32
+  // expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'vector<[4]xf32>'}}
+  %shfl = gpu.rotate %arg0, %offset, %width : vector<[4]xf32>
+  return
+}
+
+// -----
+
+// A width of 15 is not a power of two.
+func.func @rotate_unsupported_width(%arg0 : f32) {
+  %offset = arith.constant 4 : i32
+  %width = arith.constant 15 : i32
+  // expected-error@+1 {{op width must be a power of two}}
+  %shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> f32
+  return
+}
+
+// -----
+
+// An offset equal to the width (16) lies just outside the range [0, width).
+func.func @rotate_unsupported_offset(%arg0 : f32) {
+  %offset = arith.constant 16 : i32
+  %width = arith.constant 16 : i32
+  // expected-error@+1 {{op offset must be in the range [0, width)}}
+  %shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> f32
+  return
+}
+
+// -----
+
+// A negative offset lies outside the range [0, width).
+func.func @rotate_unsupported_offset_minus(%arg0 : f32) {
+  %offset = arith.constant -1 : i32
+  %width = arith.constant 16 : i32
+  // expected-error@+1 {{op offset must be in the range [0, width)}}
+  %shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> f32
+  return
+}
+
+// -----
+
+// The offset is a function argument rather than the result of arith.constant.
+func.func @rotate_offset_non_constant(%arg0 : f32, %offset : i32) {
+  %width = arith.constant 16 : i32
+  // expected-error@+1 {{op offset is not a constant value}}
+  %shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> f32
+  return
+}
+
+// -----
+
+// The width is a function argument rather than the result of arith.constant.
+func.func @rotate_width_non_constant(%arg0 : f32, %width : i32) {
+  %offset = arith.constant 0 : i32
+  // expected-error@+1 {{op width is not a constant value}}
+  %shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> f32
+  return
+}
+
+// -----
+
 module {
   gpu.module @gpu_funcs {
     // expected-error @+1 {{custom op 'gpu.func' gpu.func requires named arguments}}
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 9dbe16774f517..4beb8ffa09ac6 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -140,6 +140,10 @@ module attributes {gpu.container_module} {
       // CHECK: gpu.shuffle idx %{{.*}}, %{{.*}}, %{{.*}} : f32
       %shfl3, %pred3 = gpu.shuffle idx %arg0, %offset, %width : f32
 
+      // CHECK: gpu.rotate %{{.*}}, %{{.*}}, %{{.*}} : f32
+      %rotate_width = arith.constant 16 : i32
+      %rotate = gpu.rotate %arg0, %offset, %rotate_width : f32
+
       "gpu.barrier"() : () -> ()
 
       "some_op"(%bIdX, %tIdX) : (index, index) -> ()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants