Skip to content

Commit

Permalink
[mlir][ROCDL][~NFC] Migrate to LLVM dialect default builders (llvm#12…
Browse files Browse the repository at this point in the history
…5609)

There were a bunch of spots in ROCDL.td where we were defining our own
llvmBuilder call which could have been generated using the default
built-in one on LLVM_IntrOpBase.

This commit cleans up such usages in the interests of potentinally
enabling ROCDL import in the future and of making best practices more
obvious.

The one breaking change is renaming WaitcntOp to SWaitcntOp, which
should have minimal impact.
  • Loading branch information
krzysz00 authored Feb 6, 2025
1 parent 5812d0b commit efd0a7f
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 151 deletions.
172 changes: 41 additions & 131 deletions mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults,
overloadedOperands, traits, numResults, requiresAccessGroup,
requiresAliasAnalysis, 0, 0, immArgPositions, immArgAttrNames>;

// Subclass to save typing and ease readibility when there aren't overloaded
// operands or memory accesses.
class ROCDL_ConcreteNonMemIntrOp<string mnemonic, list<Trait> traits,
int numResults, list<int> immArgPositions = [],
list<string> immArgNames = []>
: ROCDL_IntrOp<mnemonic, [], [], traits, numResults, 0, 0,
immArgPositions, immArgNames>;
//===----------------------------------------------------------------------===//
// ROCDL special register op definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -150,37 +157,26 @@ class ROCDL_MbcntOp<string mnemonic> :
def ROCDL_MbcntLoOp : ROCDL_MbcntOp<"lo">;
def ROCDL_MbcntHiOp : ROCDL_MbcntOp<"hi">;

def ROCDL_DsSwizzleOp :
ROCDL_Op<"ds_swizzle">,
Results<(outs I32:$res)>,
Arguments<(ins I32:$src,
I32:$offset)>
{
string llvmBuilder = [{
$res = createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_ds_swizzle, {$src, $offset});
}];
def ROCDL_DsSwizzleOp : ROCDL_ConcreteNonMemIntrOp<"ds_swizzle", [], 1>,
Arguments<(ins I32:$src,
I32:$offset)> {
let results = (outs I32:$res);
let assemblyFormat = [{
$src `,` $offset attr-dict `:` `(` type($src) `,` type($offset) `)` `->` type($res)
}];
}

def ROCDL_DsBpermuteOp :
ROCDL_Op<"ds_bpermute">,
Results<(outs I32:$res)>,
Arguments<(ins I32:$index,
I32:$src)>
{
string llvmBuilder = [{
$res = createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_ds_bpermute, {$index, $src});
}];
def ROCDL_DsBpermuteOp : ROCDL_ConcreteNonMemIntrOp<"ds_bpermute", [], 1>,
Arguments<(ins I32:$index,
I32:$src)> {
let results = (outs I32:$res);
let assemblyFormat = [{
$index `,` $src attr-dict `:` `(` type($index) `,` type($src) `)` `->` type($res)
}];
}

def ROCDL_BallotOp :
ROCDL_Op<"ballot">,
Results<(outs LLVM_Type:$res)>,
ROCDL_IntrOp<"ballot", [0], [], [], 1>,
Arguments<(ins I1:$pred)> {
let summary = "Vote across thread group";

Expand All @@ -189,11 +185,6 @@ def ROCDL_BallotOp :
The nth bit of the result contains the 1 bit contributed by the nth warp lane.
}];

string llvmBuilder = [{
$res = createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_ballot, {$pred}, {$_resultType});
}];

let assemblyFormat = "$pred attr-dict `:` type($res)";
}

Expand Down Expand Up @@ -249,18 +240,12 @@ def ROCDL_GridDimZOp : ROCDL_DimGetterFunctionOp<"grid.dim.z",

// Emits the waintcnt instruction. The bitfield's semantics depend
// on the target chipset
def ROCDL_WaitcntOp : ROCDL_Op<"waitcnt">, Arguments<(ins I32Attr:$bitfield)> {
string llvmBuilder = [{
createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_s_waitcnt,
{builder.getInt32($bitfield)});
}];
def ROCDL_SWaitcntOp : ROCDL_ConcreteNonMemIntrOp<"s.waitcnt", [], 0, [0], ["bitfield"]>,
Arguments<(ins I32Attr:$bitfield)> {
let assemblyFormat = "attr-dict $bitfield";
}

def ROCDL_SBarrierOp : ROCDL_Op<"s.barrier"> {
string llvmBuilder = [{
createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_s_barrier);
}];
def ROCDL_SBarrierOp : ROCDL_ConcreteNonMemIntrOp<"s.barrier", [], 0> {
let assemblyFormat = "attr-dict";
}

Expand All @@ -276,68 +261,51 @@ def ROCDL_BarrierOp : ROCDL_Op<"barrier"> {
let assemblyFormat = "attr-dict";
}

def ROCDL_BarrierSignalOp : ROCDL_IntrOp<"s.barrier.signal", [], [], [], 0, 0, 0, [0], ["id"]>,
def ROCDL_BarrierSignalOp : ROCDL_ConcreteNonMemIntrOp<"s.barrier.signal", [], 0, [0], ["id"]>,
Arguments<(ins I32Attr:$id)> {
let results = (outs);
let assemblyFormat = "$id attr-dict";
}

def ROCDL_BarrierWaitOp : ROCDL_IntrOp<"s.barrier.wait", [], [], [], 0, 0, 0, [0], ["id"]>,
def ROCDL_BarrierWaitOp : ROCDL_ConcreteNonMemIntrOp<"s.barrier.wait", [], 0, [0], ["id"]>,
Arguments<(ins I16Attr:$id)> {
let results = (outs);
let assemblyFormat = "$id attr-dict";
string llvmBuilder =
"createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_s_barrier_wait,builder.getInt16(op.getId()));";
}

def ROCDL_WaitDscntOp: ROCDL_IntrOp<"s.wait.dscnt", [], [], [], 0, 0, 0, [0], ["id"]>,
def ROCDL_WaitDscntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.dscnt", [], 0, [0], ["id"]>,
Arguments<(ins I16Attr:$id)> {
let results = (outs);
let assemblyFormat = "$id attr-dict";
}

def ROCDL_SetPrioOp : ROCDL_IntrOp<"s.setprio", [], [], [], 0>,
def ROCDL_SetPrioOp : ROCDL_ConcreteNonMemIntrOp<"s.setprio", [], 0, [0], ["priority"]>,
Arguments<(ins I16Attr:$priority)> {
let results = (outs);
let assemblyFormat = "$priority attr-dict";
string llvmBuilder =
"createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_s_setprio,builder.getInt16(op.getPriority()));";
}

def ROCDL_SchedBarrier : ROCDL_IntrOp<"sched.barrier", [], [], [], 0>,
def ROCDL_SchedBarrier : ROCDL_ConcreteNonMemIntrOp<"sched.barrier", [], 0, [0],["mask"]>,
Arguments<(ins I32Attr:$mask)> {
let results = (outs);
let assemblyFormat = "$mask attr-dict";
string llvmBuilder =
"createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_sched_barrier,builder.getInt32(op.getMask()));";
}

def ROCDL_SchedGroupBarrier : ROCDL_IntrOp<"sched.group.barrier", [], [], [], 0>,
Arguments<(ins I32Attr:$mask, I32Attr:$size, I32Attr:$groupId)> {
let results = (outs);
def ROCDL_SchedGroupBarrier
: ROCDL_ConcreteNonMemIntrOp<"sched.group.barrier", [], 0,
[0, 1, 2], ["mask", "size", "groupId"]>,
Arguments<(ins I32Attr:$mask, I32Attr:$size, I32Attr:$groupId)> {
let assemblyFormat = "$mask `,` $size `,` $groupId attr-dict";
string llvmBuilder = [{
createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_sched_group_barrier,
{builder.getInt32(op.getMask()), builder.getInt32(op.getSize()), builder.getInt32(op.getGroupId())});
}];
}

def ROCDL_IglpOpt : ROCDL_IntrOp<"iglp.opt", [], [], [], 0>,
def ROCDL_IglpOpt : ROCDL_ConcreteNonMemIntrOp<"iglp.opt", [], 0, [0], ["variant"]>,
Arguments<(ins I32Attr:$variant)> {
let results = (outs);
let assemblyFormat = "$variant attr-dict";
string llvmBuilder =
"createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_iglp_opt,builder.getInt32(op.getVariant()));";
}

//===---------------------------------------------------------------------===//
// Xdlops intrinsics

class ROCDL_Mfma_IntrOp<string mnemonic, list<Trait> traits = []> :
LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
"amdgcn_" # !subst(".","_", mnemonic),
[], [], traits, 1>,
ROCDL_IntrOp<mnemonic, [], [], traits, 1>,
Arguments<(ins Variadic<LLVM_Type>:$args)> {
let assemblyFormat =
"$args attr-dict `:` functional-type($args, $res)";
Expand All @@ -347,9 +315,7 @@ class ROCDL_Mfma_IntrOp<string mnemonic, list<Trait> traits = []> :
// MFMA intrinsics with overloaded operands
class ROCDL_Mfma_OO_IntrOp<string mnemonic, list<int> overloadedOperands,
list<Trait> traits = []> :
LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
"amdgcn_" # !subst(".","_", mnemonic),
[], overloadedOperands, traits, 1>,
ROCDL_IntrOp<mnemonic, [], overloadedOperands, traits, 1>,
Arguments<(ins Variadic<LLVM_Type>:$args)> {
let assemblyFormat =
"$args attr-dict `:` functional-type($args, $res)";
Expand Down Expand Up @@ -430,9 +396,7 @@ def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f
// WMMA intrinsics
class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
list<Trait> traits = []> :
LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
"amdgcn_" # !subst(".","_", mnemonic),
[0], overloadedOperands, traits, 1>,
ROCDL_IntrOp<mnemonic, [0], overloadedOperands, traits, 1>,
Arguments<(ins Variadic<LLVM_Type>:$args)> {
let assemblyFormat =
"$args attr-dict `:` functional-type($args, $res)";
Expand Down Expand Up @@ -572,50 +536,32 @@ def ROCDL_RawPtrBufferAtomicFaddOp : ROCDL_RawPtrBufferAtomicNoRet<"fadd">;
// Raw buffer load/store intrinsics

def ROCDL_RawBufferLoadOp :
ROCDL_Op<"raw.buffer.load">,
Results<(outs LLVM_Type:$res)>,
ROCDL_IntrOp<"raw.buffer.load", [0], [], [], 1>,
Arguments<(ins LLVM_Type:$rsrc,
LLVM_Type:$offset,
LLVM_Type:$soffset,
LLVM_Type:$aux)> {
string llvmBuilder = [{
$res = createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_raw_buffer_load, {$rsrc, $offset,
$soffset, $aux}, {$_resultType});
}];
let hasCustomAssemblyFormat = 1;
}

def ROCDL_RawBufferStoreOp :
ROCDL_Op<"raw.buffer.store">,
ROCDL_IntrOp<"raw.buffer.store", [], [0], [], 0>,
Arguments<(ins LLVM_Type:$vdata,
LLVM_Type:$rsrc,
LLVM_Type:$offset,
LLVM_Type:$soffset,
LLVM_Type:$aux)>{
string llvmBuilder = [{
auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_raw_buffer_store, {$vdata, $rsrc,
$offset, $soffset, $aux}, {vdataType});
}];
let hasCustomAssemblyFormat = 1;
}

def ROCDL_RawBufferAtomicCmpSwap :
ROCDL_Op<"raw.buffer.atomic.cmpswap", [AllTypesMatch<["res", "src", "cmp"]>]>,
Results<(outs LLVM_Type:$res)>,
ROCDL_IntrOp<"raw.buffer.atomic.cmpswap", [], [0], [AllTypesMatch<["res", "src", "cmp"]>], 1>,
Arguments<(ins LLVM_Type:$src,
LLVM_Type:$cmp,
LLVM_Type:$rsrc,
I32:$offset,
I32:$soffset,
I32:$aux)>{
string llvmBuilder = [{
$res = createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_raw_buffer_atomic_cmpswap, {$src, $cmp, $rsrc,
$offset, $soffset, $aux}, {$_resultType});
}];
let assemblyFormat = [{
attr-dict `(` operands `)` `:` type($res) `,` type($rsrc)
}];
Expand All @@ -625,100 +571,64 @@ def ROCDL_RawBufferAtomicCmpSwap :
// MI-100 and MI-200 buffer atomic floating point add intrinsic

def ROCDL_RawBufferAtomicFAddOp :
ROCDL_Op<"raw.buffer.atomic.fadd">,
ROCDL_IntrOp<"raw.buffer.atomic.fadd", [], [0], [], 0>,
Arguments<(ins LLVM_Type:$vdata,
LLVM_Type:$rsrc,
LLVM_Type:$offset,
LLVM_Type:$soffset,
LLVM_Type:$aux)>{
string llvmBuilder = [{
auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_raw_buffer_atomic_fadd, {$vdata, $rsrc,
$offset, $soffset, $aux}, {vdataType});
}];
let hasCustomAssemblyFormat = 1;
}

//===---------------------------------------------------------------------===//
// Buffer atomic floating point max intrinsic. GFX9 does not support fp32.

def ROCDL_RawBufferAtomicFMaxOp :
ROCDL_Op<"raw.buffer.atomic.fmax">,
ROCDL_IntrOp<"raw.buffer.atomic.fmax", [], [0], [], 0>,
Arguments<(ins LLVM_Type:$vdata,
LLVM_Type:$rsrc,
LLVM_Type:$offset,
LLVM_Type:$soffset,
LLVM_Type:$aux)>{
string llvmBuilder = [{
auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_raw_buffer_atomic_fmax, {$vdata, $rsrc,
$offset, $soffset, $aux}, {vdataType});
}];
let hasCustomAssemblyFormat = 1;
}

//===---------------------------------------------------------------------===//
// Buffer atomic signed integer max intrinsic.

def ROCDL_RawBufferAtomicSMaxOp :
ROCDL_Op<"raw.buffer.atomic.smax">,
ROCDL_IntrOp<"raw.buffer.atomic.smax", [], [0], [], 0>,
Arguments<(ins LLVM_Type:$vdata,
LLVM_Type:$rsrc,
LLVM_Type:$offset,
LLVM_Type:$soffset,
LLVM_Type:$aux)>{
string llvmBuilder = [{
auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_raw_buffer_atomic_smax, {$vdata, $rsrc,
$offset, $soffset, $aux}, {vdataType});
}];
let hasCustomAssemblyFormat = 1;
}

//===---------------------------------------------------------------------===//
// Buffer atomic unsigned integer min intrinsic.

def ROCDL_RawBufferAtomicUMinOp :
ROCDL_Op<"raw.buffer.atomic.umin">,
ROCDL_IntrOp<"raw.buffer.atomic.umin", [], [0], [], 0>,
Arguments<(ins LLVM_Type:$vdata,
LLVM_Type:$rsrc,
LLVM_Type:$offset,
LLVM_Type:$soffset,
LLVM_Type:$aux)>{
string llvmBuilder = [{
auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_raw_buffer_atomic_umin, {$vdata, $rsrc,
$offset, $soffset, $aux}, {vdataType});
}];
let hasCustomAssemblyFormat = 1;
}

// DPP Update intrinsic
def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
[AllTypesMatch<["res", "src", "old"]>], 1>,
[AllTypesMatch<["res", "src", "old"]>], 1, 0, 0,
[2, 3, 4, 5], ["dppCtrl", "rowMask", "bankMask", "boundCtrl"]>,
Arguments<(ins LLVM_Type:$old, LLVM_Type:$src, I32Attr:$dppCtrl, I32Attr:$rowMask,
I32Attr:$bankMask, I1Attr:$boundCtrl)> {
let results = (outs LLVM_Type:$res);
let assemblyFormat = [{
attr-dict $old `,` $src `with` $dppCtrl `,` $rowMask `,` $bankMask `,` $boundCtrl `:` type($src)
}];
string llvmBuilder = [{
auto vdataType = moduleTranslation.convertType(op.getSrc().getType());
llvm::Value *args[] = {
moduleTranslation.lookupValue(op.getOld()),
moduleTranslation.lookupValue(op.getSrc()),
builder.getInt32(op.getDppCtrl()),
builder.getInt32(op.getRowMask()),
builder.getInt32(op.getBankMask()),
builder.getInt1(op.getBoundCtrl())
};
$res = createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_update_dpp, args, {vdataType});
}];
}

//===---------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
<< chipset.majorVersion;

Location loc = op->getLoc();
rewriter.create<ROCDL::WaitcntOp>(loc, ldsOnlyBits);
rewriter.create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits);
rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
} else {
Location loc = op->getLoc();
Expand Down
Loading

0 comments on commit efd0a7f

Please sign in to comment.