Skip to content

[MLIR][AMDGPU] Fix bug in GatherToLDSOpLowering, get the correct MemRefType for destination #142915

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

dhernandez0
Copy link
Contributor

@dhernandez0 dhernandez0 commented Jun 5, 2025

This PR fixes a bug in GatherToLDSOpLowering, we were getting the MemRefType of source for the destination. Additionally, some related typos are corrected.

CC: @krzysz00 @umangyadav @lialan

…efType for destination

This PR fixes a bug in GatherToLDSOpLowering, we were getting the MemRefType of source
for the destionation. Additionally, some related typos are corrected.
@llvmbot
Copy link
Member

llvmbot commented Jun 5, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Daniel Hernandez-Juarez (dhernandez0)

Changes

This PR fixes a bug in GatherToLDSOpLowering, we were getting the MemRefType of source for the destination. Additionally, some related typos are corrected.


Full diff: https://github.com/llvm/llvm-project/pull/142915.diff

3 Files Affected:

  • (modified) llvm/docs/AMDGPUUsage.rst (+2-2)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+1-1)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir (+11-9)
diff --git a/llvm/docs/AMDGPUUsage.rst b/llvm/docs/AMDGPUUsage.rst
index 174a497c51b26..003454d274044 100644
--- a/llvm/docs/AMDGPUUsage.rst
+++ b/llvm/docs/AMDGPUUsage.rst
@@ -1215,12 +1215,12 @@ The AMDGPU backend implements the following LLVM IR intrinsics.
                                                    denormalization mode, enabled traps, and floating point exceptions.
                                                    The format is a 64-bit concatenation of the MODE and TRAPSTS registers.
 
-  :ref:`llvm.set.fpenv<int_set_fpenv>`             Sets the floating point environment to the specifies state.
+  :ref:`llvm.set.fpenv<int_set_fpenv>`             Sets the floating point environment to the specified state.
   llvm.amdgcn.load.to.lds.p<1/7>                   Loads values from global memory (either in the form of a global
                                                    a raw fat buffer pointer) to LDS. The size of the data copied can be 1, 2,
                                                    or 4 bytes (and gfx950 also allows 12 or 16 bytes). The LDS pointer
                                                    argument should be wavefront-uniform; the global pointer need not be.
-                                                   The LDS pointer is implicitly offset by 4 * lane_id bytes for sies <= 4 bytes
+                                                   The LDS pointer is implicitly offset by 4 * lane_id bytes for size <= 4 bytes
                                                    and 16 * lane_id bytes for larger sizes. This lowers to `global_load_lds`,
                                                    `buffer_load_* ... lds`, or `global_load__* ... lds` depending on address
                                                    space and architecture. `amdgcn.global.load.lds` has the same semantics as
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index c5094799bbef7..8290130933db3 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1100,7 +1100,7 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
     Location loc = op.getLoc();
 
     auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
-    auto dstMemRefType = cast<MemRefType>(op.getSrc().getType());
+    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
 
     // TODO: instead of only transfering one element per thread, we could
     // augment it to transfer multiple elements per thread by issuing multiple
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
index cb3539dd11be3..581346e03b893 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
@@ -31,8 +31,8 @@ func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_add
   // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
   // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
 
-  // CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
-  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
+  // CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
   // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
 
   // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
@@ -65,8 +65,8 @@ func.func @global_load_to_rocdl_i8(%global : memref<128x72xi8, #gpu_global_addrs
   // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
   // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
 
-  // CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
-  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
+  // CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
   // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
 
   // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
@@ -103,8 +103,8 @@ func.func @global_load_to_rocdl_vec(%global : memref<128x72xi16, #gpu_global_add
   // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
   // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
 
-  // CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
-  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
+  // CHECK: %[[C128:.*]] = llvm.mlir.constant(128 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C128]] : i64
   // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
 
   // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
@@ -130,7 +130,9 @@ func.func @global_load_to_rocdl_dynamic_indices(%global : memref<512xi32, #gpu_g
   // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
   // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRCIDX_CAST]]]
   // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
-  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DSTIDX_CAST]]]
+  // CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
+  // CHECK: %[[DSTIDX:.*]] = llvm.mul %[[DSTIDX_CAST]], %[[C64]] : i64
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DSTIDX]]]
   // CHECK: rocdl.load.to.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], 4
   %alloc = memref.alloc() : memref<4x64xi32, #gpu_lds_addrspace>
   %c0 = arith.constant 0 : index
@@ -166,8 +168,8 @@ func.func @fat_buffer_load_to_rocdl_f32(%global : memref<128x72xf32, #amdgpu_fat
   // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
   // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
 
-  // CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
-  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
+  // CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
   // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
 
   // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]

@llvmbot
Copy link
Member

llvmbot commented Jun 5, 2025

@llvm/pr-subscribers-backend-amdgpu

Author: Daniel Hernandez-Juarez (dhernandez0)

Changes

This PR fixes a bug in GatherToLDSOpLowering, we were getting the MemRefType of source for the destination. Additionally, some related typos are corrected.


Full diff: https://github.com/llvm/llvm-project/pull/142915.diff

3 Files Affected:

  • (modified) llvm/docs/AMDGPUUsage.rst (+2-2)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+1-1)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir (+11-9)
diff --git a/llvm/docs/AMDGPUUsage.rst b/llvm/docs/AMDGPUUsage.rst
index 174a497c51b26..003454d274044 100644
--- a/llvm/docs/AMDGPUUsage.rst
+++ b/llvm/docs/AMDGPUUsage.rst
@@ -1215,12 +1215,12 @@ The AMDGPU backend implements the following LLVM IR intrinsics.
                                                    denormalization mode, enabled traps, and floating point exceptions.
                                                    The format is a 64-bit concatenation of the MODE and TRAPSTS registers.
 
-  :ref:`llvm.set.fpenv<int_set_fpenv>`             Sets the floating point environment to the specifies state.
+  :ref:`llvm.set.fpenv<int_set_fpenv>`             Sets the floating point environment to the specified state.
   llvm.amdgcn.load.to.lds.p<1/7>                   Loads values from global memory (either in the form of a global
                                                    a raw fat buffer pointer) to LDS. The size of the data copied can be 1, 2,
                                                    or 4 bytes (and gfx950 also allows 12 or 16 bytes). The LDS pointer
                                                    argument should be wavefront-uniform; the global pointer need not be.
-                                                   The LDS pointer is implicitly offset by 4 * lane_id bytes for sies <= 4 bytes
+                                                   The LDS pointer is implicitly offset by 4 * lane_id bytes for size <= 4 bytes
                                                    and 16 * lane_id bytes for larger sizes. This lowers to `global_load_lds`,
                                                    `buffer_load_* ... lds`, or `global_load__* ... lds` depending on address
                                                    space and architecture. `amdgcn.global.load.lds` has the same semantics as
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index c5094799bbef7..8290130933db3 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1100,7 +1100,7 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
     Location loc = op.getLoc();
 
     auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
-    auto dstMemRefType = cast<MemRefType>(op.getSrc().getType());
+    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
 
     // TODO: instead of only transfering one element per thread, we could
     // augment it to transfer multiple elements per thread by issuing multiple
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
index cb3539dd11be3..581346e03b893 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
@@ -31,8 +31,8 @@ func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_add
   // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
   // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
 
-  // CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
-  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
+  // CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
   // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
 
   // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
@@ -65,8 +65,8 @@ func.func @global_load_to_rocdl_i8(%global : memref<128x72xi8, #gpu_global_addrs
   // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
   // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
 
-  // CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
-  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
+  // CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
   // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
 
   // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
@@ -103,8 +103,8 @@ func.func @global_load_to_rocdl_vec(%global : memref<128x72xi16, #gpu_global_add
   // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
   // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
 
-  // CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
-  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
+  // CHECK: %[[C128:.*]] = llvm.mlir.constant(128 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C128]] : i64
   // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
 
   // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
@@ -130,7 +130,9 @@ func.func @global_load_to_rocdl_dynamic_indices(%global : memref<512xi32, #gpu_g
   // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
   // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRCIDX_CAST]]]
   // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
-  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DSTIDX_CAST]]]
+  // CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
+  // CHECK: %[[DSTIDX:.*]] = llvm.mul %[[DSTIDX_CAST]], %[[C64]] : i64
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DSTIDX]]]
   // CHECK: rocdl.load.to.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], 4
   %alloc = memref.alloc() : memref<4x64xi32, #gpu_lds_addrspace>
   %c0 = arith.constant 0 : index
@@ -166,8 +168,8 @@ func.func @fat_buffer_load_to_rocdl_f32(%global : memref<128x72xf32, #amdgpu_fat
   // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
   // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
 
-  // CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
-  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
+  // CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
   // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
 
   // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]

Copy link
Contributor

@umangyadav umangyadav left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants