Skip to content

[MLIR][AMDGPU] Fix bug in GatherToLDSOpLowering, get the correct MemRefType for destination #142915

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions llvm/docs/AMDGPUUsage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1215,12 +1215,12 @@ The AMDGPU backend implements the following LLVM IR intrinsics.
denormalization mode, enabled traps, and floating point exceptions.
The format is a 64-bit concatenation of the MODE and TRAPSTS registers.

:ref:`llvm.set.fpenv<int_set_fpenv>` Sets the floating point environment to the specifies state.
:ref:`llvm.set.fpenv<int_set_fpenv>` Sets the floating point environment to the specified state.
llvm.amdgcn.load.to.lds.p<1/7> Loads values from global memory (either in the form of a global
pointer or a raw fat buffer pointer) to LDS. The size of the data copied can be 1, 2,
or 4 bytes (and gfx950 also allows 12 or 16 bytes). The LDS pointer
argument should be wavefront-uniform; the global pointer need not be.
The LDS pointer is implicitly offset by 4 * lane_id bytes for sies <= 4 bytes
The LDS pointer is implicitly offset by 4 * lane_id bytes for size <= 4 bytes
and 16 * lane_id bytes for larger sizes. This lowers to `global_load_lds`,
`buffer_load_* ... lds`, or `global_load__* ... lds` depending on address
space and architecture. `amdgcn.global.load.lds` has the same semantics as
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1100,7 +1100,7 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
Location loc = op.getLoc();

auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
auto dstMemRefType = cast<MemRefType>(op.getSrc().getType());
auto dstMemRefType = cast<MemRefType>(op.getDst().getType());

// TODO: instead of only transferring one element per thread, we could
// augment it to transfer multiple elements per thread by issuing multiple
Expand Down
20 changes: 11 additions & 9 deletions mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_add
// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]

// CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
// CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
// CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64

// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
Expand Down Expand Up @@ -65,8 +65,8 @@ func.func @global_load_to_rocdl_i8(%global : memref<128x72xi8, #gpu_global_addrs
// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]

// CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
// CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
// CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64

// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
Expand Down Expand Up @@ -103,8 +103,8 @@ func.func @global_load_to_rocdl_vec(%global : memref<128x72xi16, #gpu_global_add
// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]

// CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
// CHECK: %[[C128:.*]] = llvm.mlir.constant(128 : index) : i64
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C128]] : i64
// CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64

// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
Expand All @@ -130,7 +130,9 @@ func.func @global_load_to_rocdl_dynamic_indices(%global : memref<512xi32, #gpu_g
// CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRCIDX_CAST]]]
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DSTIDX_CAST]]]
// CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
// CHECK: %[[DSTIDX:.*]] = llvm.mul %[[DSTIDX_CAST]], %[[C64]] : i64
// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DSTIDX]]]
// CHECK: rocdl.load.to.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], 4
%alloc = memref.alloc() : memref<4x64xi32, #gpu_lds_addrspace>
%c0 = arith.constant 0 : index
Expand Down Expand Up @@ -166,8 +168,8 @@ func.func @fat_buffer_load_to_rocdl_f32(%global : memref<128x72xf32, #amdgpu_fat
// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]

// CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
// CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
// CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64

// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
Expand Down
Loading