Skip to content

[mlir][gpu] Add gpu.rotate operation #142796

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,49 @@ def GPU_ShuffleOp : GPU_Op<
];
}

// Subgroup rotate: every lane reads `value` from another lane of its
// `width`-sized cluster, `offset` lanes further along (wrapping around).
def GPU_RotateOp : GPU_Op<
    "rotate", [Pure, AllTypesMatch<["value", "rotateResult"]>]>,
    Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, I32:$offset, I32:$width)>,
    Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult)> {
  let summary = "Rotate values within a subgroup.";
  let description = [{
    The "rotate" op moves values across lanes (a.k.a., invocations, work items)
    within the same subgroup. The `width` argument specifies the number of lanes
    that participate in the rotation, and must be uniform across all lanes.
    Further, the first `width` lanes of the subgroup must be active.

    `width` must be a power of two, and `offset` must be in the range
    `[0, width)`. Both `offset` and `width` must be defined by `arith.constant`
    ops; the verifier rejects any other producer.

    Return the `rotateResult` of the invocation whose id within the group is
    calculated as follows, where `LaneId` is the id of the current lane within
    the subgroup:

    ```
    Invocation ID = ((LaneId + offset) & (width - 1)) + (LaneId & ~(width - 1))
    ```

    Returns the `rotateResult` if the current lane id is smaller than `width`.

    Example:

    ```mlir
    %offset = arith.constant 1 : i32
    %width = arith.constant 16 : i32
    %1 = gpu.rotate %0, %offset, %width : f32
    ```

    For lane `k` in `[0, width)`, returns the value from lane
    `(k + 1) % width`.
  }];

  let assemblyFormat = [{
    $value `,` $offset `,` $width attr-dict `:` type($value)
  }];

  let builders = [
    // Helper function that creates a rotate with constant offset/width.
    OpBuilder<(ins "Value":$value, "int32_t":$offset, "int32_t":$width)>
  ];

  let hasVerifier = 1;
}

def GPU_BarrierOp : GPU_Op<"barrier"> {
let summary = "Synchronizes all work items of a workgroup.";
let description = [{
Expand Down
37 changes: 36 additions & 1 deletion mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,16 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
ConversionPatternRewriter &rewriter) const override;
};

/// Conversion pattern that lowers `gpu.rotate` to
/// `spirv.GroupNonUniformRotateKHR`. The rewrite logic lives in the
/// out-of-line definition of `matchAndRewrite`.
class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
public:
  // Inherit the base pattern's constructors unchanged.
  using OpConversionPattern<gpu::RotateOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::RotateOp rotateOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
using OpConversionPattern::OpConversionPattern;
Expand Down Expand Up @@ -458,6 +468,31 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
return success();
}

//===----------------------------------------------------------------------===//
// Rotate
//===----------------------------------------------------------------------===//

/// Rewrites `gpu.rotate` into a subgroup-scoped
/// `spirv::GroupNonUniformRotateKHROp`, forwarding the converted value and
/// offset and passing the width as the cluster size.
///
/// The match fails (leaving the op untouched) when the width operand is not a
/// constant integer, or when that constant exceeds the subgroup size declared
/// in the target environment's resource limits.
LogicalResult GPURotateConversion::matchAndRewrite(
    gpu::RotateOp rotateOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Bind by const reference so the target environment is not copied just to
  // read a single resource limit.
  const auto &targetEnv =
      getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
  unsigned subgroupSize =
      targetEnv.getAttr().getResourceLimits().getSubgroupSize();

  // Report the two failure causes separately so that pattern-application
  // debug output names the actual reason.
  IntegerAttr widthAttr;
  if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)))
    return rewriter.notifyMatchFailure(
        rotateOp, "rotate width is not a constant integer");
  if (widthAttr.getValue().getZExtValue() > subgroupSize)
    return rewriter.notifyMatchFailure(
        rotateOp, "rotate width is larger than target subgroup size");

  Location loc = rotateOp.getLoc();
  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
  Value result = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
      loc, scope, adaptor.getValue(), adaptor.getOffset(), adaptor.getWidth());

  rewriter.replaceOp(rotateOp, result);
  return success();
}

//===----------------------------------------------------------------------===//
// Group ops
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -733,7 +768,7 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<
GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
GPUReturnOpConversion, GPUShuffleConversion,
GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
Expand Down
45 changes: 45 additions & 0 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1331,6 +1331,51 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
mode);
}

//===----------------------------------------------------------------------===//
// RotateOp
//===----------------------------------------------------------------------===//

/// Convenience builder: materializes `offset` and `width` as i32
/// `arith.constant` ops at the rotate's location, then delegates to the
/// builder that takes them as values.
void RotateOp::build(OpBuilder &builder, OperationState &result, Value value,
                     int32_t offset, int32_t width) {
  // Create the constants in separate statements. When both `create` calls sit
  // inside one argument list their evaluation order is unspecified, so the
  // order of the emitted constants could differ between host compilers.
  Value offsetCst = builder.create<arith::ConstantOp>(
      result.location, builder.getI32IntegerAttr(offset));
  Value widthCst = builder.create<arith::ConstantOp>(
      result.location, builder.getI32IntegerAttr(width));
  build(builder, result, value, offsetCst, widthCst);
}

/// Verifies the `offset` and `width` operands of a rotate:
///  - each must be the result of an `arith.constant` holding an integer
///    attribute (`offset` is checked first, then `width`);
///  - `width` must be a power of two;
///  - `offset` must satisfy `0 <= offset < width` (signed comparison).
LogicalResult RotateOp::verify() {
  // Shared extraction for both operands. On failure the diagnostic has
  // already been emitted with `name` as the leading word of the message.
  auto readConstant = [&](Value operand, llvm::StringRef name,
                          llvm::APInt &out) -> LogicalResult {
    auto constOp = operand.getDefiningOp<arith::ConstantOp>();
    if (!constOp)
      return emitOpError() << name << " is not a constant value";
    auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue());
    if (!intAttr)
      return emitOpError() << name << " is not an integer value";
    out = intAttr.getValue();
    return success();
  };

  llvm::APInt offsetValue;
  if (failed(readConstant(getOffset(), "offset", offsetValue)))
    return failure();

  llvm::APInt widthValue;
  if (failed(readConstant(getWidth(), "width", widthValue)))
    return failure();

  if (!widthValue.isPowerOf2())
    return emitOpError() << "width must be a power of two";

  if (offsetValue.sge(widthValue) || offsetValue.slt(0))
    return emitOpError() << "offset must be in the range [0, width)";

  return success();
}

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//
Expand Down
26 changes: 26 additions & 0 deletions mlir/test/Conversion/GPUToSPIRV/rotate.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s

// Lowering test: a gpu.rotate with constant offset and width is expected to
// become a Subgroup-scoped spirv.GroupNonUniformRotateKHR, with the width
// passed as the cluster size.
module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
#spirv.resource_limits<subgroup_size = 16>>
} {

gpu.module @kernels {
// CHECK-LABEL: spirv.func @rotate()
gpu.func @rotate() kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
%offset = arith.constant 4 : i32
// The width equals the subgroup_size declared in the target environment above.
%width = arith.constant 16 : i32
%val = arith.constant 42.0 : f32

// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
// CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
%result = gpu.rotate %val, %offset, %width : f32
gpu.return
}
}

}
78 changes: 78 additions & 0 deletions mlir/test/Dialect/GPU/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,84 @@ func.func @shuffle_unsupported_type_vec(%arg0 : vector<[4]xf32>, %arg1 : i32, %a

// -----

// Generic-form rotate whose result type (i32) differs from its operand type
// (f32); the same-type constraint on value/rotateResult must reject it.
func.func @rotate_mismatching_type(%arg0 : f32) {
%offset = arith.constant 4 : i32
%width = arith.constant 16 : i32
// expected-error@+1 {{op failed to verify that all of {value, rotateResult} have same type}}
%shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> i32
return
}

// -----

// An `index` value is not an accepted operand type for rotate.
func.func @rotate_unsupported_type(%arg0 : index) {
%offset = arith.constant 4 : i32
%width = arith.constant 16 : i32
// expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'index'}}
%shfl = gpu.rotate %arg0, %offset, %width : index
return
}

// -----

// A scalable vector (`vector<[4]xf32>`) is rejected; only fixed-length 1-D
// vectors are accepted as the rotated value.
func.func @rotate_unsupported_type_vec(%arg0 : vector<[4]xf32>) {
%offset = arith.constant 4 : i32
%width = arith.constant 16 : i32
// expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'vector<[4]xf32>'}}
%shfl = gpu.rotate %arg0, %offset, %width : vector<[4]xf32>
return
}

// -----

// A width of 15 is not a power of two and must be rejected by the verifier.
func.func @rotate_unsupported_width(%arg0 : f32) {
%offset = arith.constant 4 : i32
%width = arith.constant 15 : i32
// expected-error@+1 {{op width must be a power of two}}
%shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> f32
return
}

// -----

// Upper-bound case: an offset equal to the width (16) lies outside the
// half-open range [0, width).
func.func @rotate_unsupported_offset(%arg0 : f32) {
%offset = arith.constant 16 : i32
%width = arith.constant 16 : i32
// expected-error@+1 {{op offset must be in the range [0, width)}}
%shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> f32
return
}

// -----

// Lower-bound case: a negative offset (-1) lies outside [0, width).
func.func @rotate_unsupported_offset_minus(%arg0 : f32) {
%offset = arith.constant -1 : i32
%width = arith.constant 16 : i32
// expected-error@+1 {{op offset must be in the range [0, width)}}
%shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> f32
return
}

// -----

// The offset comes from a function argument rather than a constant op, which
// the verifier does not accept.
func.func @rotate_offset_non_constant(%arg0 : f32, %offset : i32) {
%width = arith.constant 16 : i32
// expected-error@+1 {{op offset is not a constant value}}
%shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> f32
return
}

// -----

// The width comes from a function argument rather than a constant op, which
// the verifier does not accept.
func.func @rotate_width_non_constant(%arg0 : f32, %width : i32) {
%offset = arith.constant 0 : i32
// expected-error@+1 {{op width is not a constant value}}
%shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> f32
return
}

// -----

module {
gpu.module @gpu_funcs {
// expected-error @+1 {{custom op 'gpu.func' gpu.func requires named arguments}}
Expand Down
4 changes: 4 additions & 0 deletions mlir/test/Dialect/GPU/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ module attributes {gpu.container_module} {
// CHECK: gpu.shuffle idx %{{.*}}, %{{.*}}, %{{.*}} : f32
%shfl3, %pred3 = gpu.shuffle idx %arg0, %offset, %width : f32

// CHECK: gpu.rotate %{{.*}}, %{{.*}}, %{{.*}} : f32
%rotate_width = arith.constant 16 : i32
%rotate = gpu.rotate %arg0, %offset, %rotate_width : f32

"gpu.barrier"() : () -> ()

"some_op"(%bIdX, %tIdX) : (index, index) -> ()
Expand Down
Loading