Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 25 additions & 25 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2368,6 +2368,23 @@ def TMAStoreModeAttr : EnumAttr<NVVM_Dialect, TMAStoreMode, "tma_store_mode"> {
let assemblyFormat = "`<` $value `>`";
}

// Num CTAs in a group participating in the TMA/MMA operations.
// This corresponds to the "cta_group::1", "cta_group::2"
// modifiers in the PTX instructions.
def CTAGroup_1 : I32EnumAttrCase<"CTA_1", 0, "cta_1">;
def CTAGroup_2 : I32EnumAttrCase<"CTA_2", 1, "cta_2">;

def CTAGroupKind : I32EnumAttr<"CTAGroupKind",
"NVVM CTA group kind",
[CTAGroup_1, CTAGroup_2]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
def CTAGroupKindAttr :
EnumAttr<NVVM_Dialect, CTAGroupKind, "cta_group"> {
let assemblyFormat = "`<` $value `>`";
}

def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
Arguments<(ins )> {
let assemblyFormat = "attr-dict";
Expand Down Expand Up @@ -3333,23 +3350,6 @@ def NVVM_Breakpoint : NVVM_Op<"breakpoint"> {
//===----------------------------------------------------------------------===//
// NVVM TCGEN05 Ops
//===----------------------------------------------------------------------===//
// Num CTAs in a group participating in the TCGEN05 operation.
// This corresponds to the "cta_group::1", "cta_group::2"
// modifiers in the PTX instructions.
def Tcgen05GroupCTA_1 : I32EnumAttrCase<"CTA_1", 0, "cta_1">;
def Tcgen05GroupCTA_2 : I32EnumAttrCase<"CTA_2", 1, "cta_2">;

def Tcgen05GroupKind : I32EnumAttr<"Tcgen05GroupKind",
"NVVM Tcgen05 group kind",
[Tcgen05GroupCTA_1, Tcgen05GroupCTA_2]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
def Tcgen05GroupKindAttr :
EnumAttr<NVVM_Dialect, Tcgen05GroupKind, "tcgen05_group"> {
let assemblyFormat = "`<` $value `>`";
}

def Tcgen05FenceBefore : I32EnumAttrCase<"BEFORE_THREAD_SYNC", 0, "before">;
def Tcgen05FenceAfter : I32EnumAttrCase<"AFTER_THREAD_SYNC", 1, "after">;
def Tcgen05FenceKind : I32EnumAttr<"Tcgen05FenceKind", "NVVM Tcgen05 fence kind",
Expand Down Expand Up @@ -3387,7 +3387,7 @@ def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSMa<[100, 101]>]
let arguments = (ins
AnyTypeOf<[LLVM_AnyPointer, LLVM_PointerShared]>:$addr,
I32:$nCols,
DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group);
DefaultValuedAttr<CTAGroupKindAttr, "CTAGroupKind::CTA_1">:$group);

let assemblyFormat = "$addr `,` $nCols attr-dict `:` type(operands)";

Expand Down Expand Up @@ -3415,7 +3415,7 @@ def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSMa<[100, 10
}];

let arguments = (ins LLVM_PointerTensor:$taddr, I32:$nCols,
DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group);
DefaultValuedAttr<CTAGroupKindAttr, "CTAGroupKind::CTA_1">:$group);

let assemblyFormat = "$taddr `,` $nCols attr-dict `:` type(operands)";

Expand Down Expand Up @@ -3443,12 +3443,12 @@ def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_perm
}];

let arguments = (ins
DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group);
DefaultValuedAttr<CTAGroupKindAttr, "CTAGroupKind::CTA_1">:$group);

let assemblyFormat = "attr-dict";

string llvmBuilder = [{
auto id = ($group == NVVM::Tcgen05GroupKind::CTA_1) ?
auto id = ($group == NVVM::CTAGroupKind::CTA_1) ?
llvm::Intrinsic::nvvm_tcgen05_relinq_alloc_permit_cg1 :
llvm::Intrinsic::nvvm_tcgen05_relinq_alloc_permit_cg2;
createIntrinsicCall(builder, id);
Expand Down Expand Up @@ -3516,7 +3516,7 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSMa<[100, 101]
let arguments = (ins
AnyTypeOf<[LLVM_AnyPointer, LLVM_PointerShared]>:$addr,
Optional<I16>:$multicastMask,
DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group);
DefaultValuedAttr<CTAGroupKindAttr, "CTAGroupKind::CTA_1">:$group);

let assemblyFormat = [{
$addr (`,` `multicast_mask` `=` $multicastMask^)?
Expand Down Expand Up @@ -3549,12 +3549,12 @@ def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift", [NVVMRequiresSMa<[100, 101, 1
}];

let arguments = (ins LLVM_PointerTensor:$taddr,
DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group);
DefaultValuedAttr<CTAGroupKindAttr, "CTAGroupKind::CTA_1">:$group);

let assemblyFormat = "$taddr attr-dict `:` type(operands)";

string llvmBuilder = [{
auto id = ($group == NVVM::Tcgen05GroupKind::CTA_1) ?
auto id = ($group == NVVM::CTAGroupKind::CTA_1) ?
llvm::Intrinsic::nvvm_tcgen05_shift_down_cg1 :
llvm::Intrinsic::nvvm_tcgen05_shift_down_cg2;
createIntrinsicCall(builder, id, {$taddr});
Expand Down Expand Up @@ -3626,7 +3626,7 @@ def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMa<[100, 101]>]> {

let arguments = (ins
Tcgen05CpShapeAttr:$shape,
DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group,
DefaultValuedAttr<CTAGroupKindAttr, "CTAGroupKind::CTA_1">:$group,
DefaultValuedAttr<Tcgen05CpMulticastAttr, "Tcgen05CpMulticast::NONE">:$multicast,
OptionalAttr<Tcgen05CpSrcFormatAttr>:$srcFormat,
LLVM_PointerTensor:$taddr,
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1797,7 +1797,7 @@ Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
.getAddressSpace();
bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;

llvm::Intrinsic::ID id;
if (isShared) {
Expand All @@ -1819,7 +1819,7 @@ llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt,
llvm::SmallVector<llvm::Value *> &args) {
auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
auto id = (curOp.getGroup() == Tcgen05GroupKind::CTA_1)
auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
: llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;

Expand Down Expand Up @@ -1847,7 +1847,7 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
.getAddressSpace();
bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;

llvm::Intrinsic::ID id =
is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
Expand Down Expand Up @@ -1879,7 +1879,7 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,

llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
auto curOp = cast<NVVM::Tcgen05CpOp>(op);
bool is2CTA = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
auto srcFmt = curOp.getSrcFormat();
auto mc = curOp.getMulticast();

Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Target/LLVMIR/nvvm/tcgen05-alloc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ llvm.func @llvm_nvvm_tcgen05_alloc(%addr : !llvm.ptr, %ncols : i32) {
nvvm.tcgen05.alloc %addr, %ncols : !llvm.ptr, i32

// CHECK-LLVM: call void @llvm.nvvm.tcgen05.alloc.cg2(ptr %{{.*}}, i32 %{{.*}})
nvvm.tcgen05.alloc %addr, %ncols {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr, i32
nvvm.tcgen05.alloc %addr, %ncols {group = #nvvm.cta_group<cta_2>} : !llvm.ptr, i32
llvm.return
}

Expand All @@ -16,7 +16,7 @@ llvm.func @llvm_nvvm_tcgen05_alloc_shared(%addr : !llvm.ptr<3>, %ncols : i32) {
nvvm.tcgen05.alloc %addr, %ncols : !llvm.ptr<3>, i32

// CHECK-LLVM: call void @llvm.nvvm.tcgen05.alloc.shared.cg2(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
nvvm.tcgen05.alloc %addr, %ncols {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr<3>, i32
nvvm.tcgen05.alloc %addr, %ncols {group = #nvvm.cta_group<cta_2>} : !llvm.ptr<3>, i32
llvm.return
}

Expand All @@ -26,7 +26,7 @@ llvm.func @llvm_nvvm_tcgen05_dealloc(%addr : !llvm.ptr<6>, %ncols : i32) {
nvvm.tcgen05.dealloc %addr, %ncols : !llvm.ptr<6>, i32

// CHECK-LLVM: call void @llvm.nvvm.tcgen05.dealloc.cg2(ptr addrspace(6) %{{.*}}, i32 %{{.*}})
nvvm.tcgen05.dealloc %addr, %ncols {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr<6>, i32
nvvm.tcgen05.dealloc %addr, %ncols {group = #nvvm.cta_group<cta_2>} : !llvm.ptr<6>, i32
llvm.return
}

Expand All @@ -36,6 +36,6 @@ llvm.func @llvm_nvvm_tcgen05_relinquish_alloc_permit() {
nvvm.tcgen05.relinquish_alloc_permit

// CHECK-LLVM: call void @llvm.nvvm.tcgen05.relinq.alloc.permit.cg2()
nvvm.tcgen05.relinquish_alloc_permit {group = #nvvm.tcgen05_group<cta_2>}
nvvm.tcgen05.relinquish_alloc_permit {group = #nvvm.cta_group<cta_2>}
llvm.return
}
8 changes: 4 additions & 4 deletions mlir/test/Target/LLVMIR/nvvm/tcgen05-commit.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ llvm.func @llvm_nvvm_tcgen05_commit_generic(%barrier : !llvm.ptr, %cta_mask : i1
nvvm.tcgen05.commit %barrier : !llvm.ptr

// CHECK-LLVM: call void @llvm.nvvm.tcgen05.commit.cg2(ptr %{{.*}})
nvvm.tcgen05.commit %barrier {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr
nvvm.tcgen05.commit %barrier {group = #nvvm.cta_group<cta_2>} : !llvm.ptr

// CHECK-LLVM: call void @llvm.nvvm.tcgen05.commit.mc.cg1(ptr %{{.*}}, i16 %{{.*}})
nvvm.tcgen05.commit %barrier, multicast_mask = %cta_mask : !llvm.ptr, i16

// CHECK-LLVM: call void @llvm.nvvm.tcgen05.commit.mc.cg2(ptr %{{.*}}, i16 %{{.*}})
nvvm.tcgen05.commit %barrier, multicast_mask = %cta_mask {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr, i16
nvvm.tcgen05.commit %barrier, multicast_mask = %cta_mask {group = #nvvm.cta_group<cta_2>} : !llvm.ptr, i16
llvm.return
}

Expand All @@ -22,12 +22,12 @@ llvm.func @llvm_nvvm_tcgen05_commit_shared(%barrier : !llvm.ptr<3>, %cta_mask :
nvvm.tcgen05.commit %barrier : !llvm.ptr<3>

// CHECK-LLVM: call void @llvm.nvvm.tcgen05.commit.shared.cg2(ptr addrspace(3) %{{.*}})
nvvm.tcgen05.commit %barrier {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr<3>
nvvm.tcgen05.commit %barrier {group = #nvvm.cta_group<cta_2>} : !llvm.ptr<3>

// CHECK-LLVM: call void @llvm.nvvm.tcgen05.commit.mc.shared.cg1(ptr addrspace(3) %{{.*}}, i16 %{{.*}})
nvvm.tcgen05.commit %barrier, multicast_mask = %cta_mask : !llvm.ptr<3>, i16

// CHECK-LLVM: call void @llvm.nvvm.tcgen05.commit.mc.shared.cg2(ptr addrspace(3) %{{.*}}, i16 %{{.*}})
nvvm.tcgen05.commit %barrier, multicast_mask = %cta_mask {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr<3>, i16
nvvm.tcgen05.commit %barrier, multicast_mask = %cta_mask {group = #nvvm.cta_group<cta_2>} : !llvm.ptr<3>, i16
llvm.return
}
30 changes: 15 additions & 15 deletions mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ llvm.func @nvvm_tcgen05_cp_128x256b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>}

// CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>, group = #nvvm.tcgen05_group<cta_2>}
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>, group = #nvvm.cta_group<cta_2>}

// CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_128x256b>,
group = #nvvm.tcgen05_group<cta_2>,
group = #nvvm.cta_group<cta_2>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
}
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_128x256b>,
group = #nvvm.tcgen05_group<cta_2>,
group = #nvvm.cta_group<cta_2>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
}
llvm.return
Expand All @@ -29,18 +29,18 @@ llvm.func @nvvm_tcgen05_cp_4x256b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_4x256b>}

// CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_4x256b>, group = #nvvm.tcgen05_group<cta_2>}
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_4x256b>, group = #nvvm.cta_group<cta_2>}

// CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_4x256b>,
group = #nvvm.tcgen05_group<cta_2>,
group = #nvvm.cta_group<cta_2>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
}
// CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_4x256b>,
group = #nvvm.tcgen05_group<cta_2>,
group = #nvvm.cta_group<cta_2>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
}
llvm.return
Expand All @@ -52,18 +52,18 @@ llvm.func @nvvm_tcgen05_cp_128x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x128b>}

// CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x128b>, group = #nvvm.tcgen05_group<cta_2>}
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x128b>, group = #nvvm.cta_group<cta_2>}

// CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_128x128b>,
group = #nvvm.tcgen05_group<cta_2>,
group = #nvvm.cta_group<cta_2>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
}
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_128x128b>,
group = #nvvm.tcgen05_group<cta_2>,
group = #nvvm.cta_group<cta_2>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
}
llvm.return
Expand All @@ -80,21 +80,21 @@ llvm.func @nvvm_tcgen05_cp_64x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
// CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
group = #nvvm.tcgen05_group<cta_2>,
group = #nvvm.cta_group<cta_2>,
multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>
}

// CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.b4x16_p64.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
group = #nvvm.tcgen05_group<cta_1>,
group = #nvvm.cta_group<cta_1>,
multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
}
// CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_01_23.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
group = #nvvm.tcgen05_group<cta_2>,
group = #nvvm.cta_group<cta_2>,
multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
}
Expand All @@ -113,21 +113,21 @@ llvm.func @nvvm_tcgen05_cp_32x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
// CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
group = #nvvm.tcgen05_group<cta_2>,
group = #nvvm.cta_group<cta_2>,
multicast = #nvvm.tcgen05_cp_multicast<warpx4>
}

// CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
group = #nvvm.tcgen05_group<cta_2>,
group = #nvvm.cta_group<cta_2>,
multicast = #nvvm.tcgen05_cp_multicast<warpx4>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
}
// CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b6x16_p32.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
group = #nvvm.tcgen05_group<cta_1>,
group = #nvvm.cta_group<cta_1>,
multicast = #nvvm.tcgen05_cp_multicast<warpx4>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ llvm.func @llvm_nvvm_tcgen05_shift(%taddr : !llvm.ptr<6>) {
nvvm.tcgen05.shift %taddr : !llvm.ptr<6>

// CHECK: call void @llvm.nvvm.tcgen05.shift.down.cg2(ptr addrspace(6) %{{.*}})
nvvm.tcgen05.shift %taddr {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr<6>
nvvm.tcgen05.shift %taddr {group = #nvvm.cta_group<cta_2>} : !llvm.ptr<6>
llvm.return
}
Loading