Transform Flatten Memref Load #7298

Merged (1 commit), Oct 31, 2024

jiahanxie353
Contributor

Caught a small issue with flattening the load.

Say we have this example:

module {
  func.func @main(%arg0: memref<3x4xi32>) {
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c3 = arith.constant 3 : index
    %c4 = arith.constant 4 : index
    %c0 = arith.constant 0 : index
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x2xi32>
    scf.for %arg1 = %c0 to %c3 step %c1 {
      scf.for %arg2 = %c0 to %c2 step %c1 {
        scf.for %arg3 = %c0 to %c4 step %c1 {
          %0 = memref.load %arg0[%arg1, %arg3] : memref<3x4xi32>
          %1 = memref.load %alloc[%arg3, %arg2] : memref<4x2xi32>
          %2 = arith.muli %0, %1 : i32
        }
      }
    }
    return
  }
}

Previously, the pass gave:

module {
  func.func @main(%arg0: memref<12xi32>) {
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c3 = arith.constant 3 : index
    %c4 = arith.constant 4 : index
    %c0 = arith.constant 0 : index
    %alloc = memref.alloc() : memref<8xi32>
    scf.for %arg1 = %c0 to %c3 step %c1 {
      scf.for %arg2 = %c0 to %c2 step %c1 {
        scf.for %arg3 = %c0 to %c4 step %c1 {
          %c3_1 = arith.constant 3 : index
          %1 = arith.muli %arg3, %c3_1 : index
          %2 = arith.addi %arg1, %1 : index
          %3 = memref.load %arg0[%2] : memref<12xi32>
          %c2_2 = arith.constant 2 : index
          %4 = arith.shli %arg2, %c2_2 : index
          %5 = arith.addi %arg3, %4 : index
          %6 = memref.load %alloc[%5] : memref<8xi32>
          %7 = arith.muli %3, %6 : i32
        }
      }
    }
    return
  }
}

This is wrong because the flattened index into %arg0 is computed as:

  %1 = arith.muli %arg3, %c3_1 : index
  %2 = arith.addi %arg1, %1 : index

So the accesses into %arg0 go:

%arg1 = 0, %arg3 = 0 -> access 0 * 3 + 0 = 0;
%arg1 = 0, %arg3 = 1 -> access 1 * 3 + 0 = 3; (oops, wrong because we should access address 1)
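
To make the mismatch concrete, here is a minimal sketch in plain Python (not the pass itself) that evaluates both index formulas for the `memref<3x4xi32>` operand. The function names `buggy_index` and `row_major_index` are made up for illustration; the arithmetic is taken from the listings in this description.

```python
ROWS, COLS = 3, 4  # shape of memref<3x4xi32>


def buggy_index(i, j):
    # Old output: %arg3 * 3 + %arg1 (the last index is scaled by the first dim).
    return j * ROWS + i


def row_major_index(i, j):
    # Row-major linearization: %arg1 * 4 + %arg3.
    return i * COLS + j


for i, j in [(0, 0), (0, 1), (1, 0), (2, 3)]:
    print(f"[{i}, {j}] buggy={buggy_index(i, j)} row_major={row_major_index(i, j)}")
```

For `[0, 1]` the old formula lands on address 3 while row-major order expects address 1. The old formula corresponds to a column-major ordering of this shape, which is why some corner elements such as `[0, 0]` and `[2, 3]` still happen to agree.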

The expected output, which is also what the pass produces after this fix, is:

module {
  func.func @main(%arg0: memref<12xi32>) {
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c3 = arith.constant 3 : index
    %c4 = arith.constant 4 : index
    %c0 = arith.constant 0 : index
    %alloc = memref.alloc() : memref<8xi32>
    scf.for %arg1 = %c0 to %c3 step %c1 {
      scf.for %arg2 = %c0 to %c2 step %c1 {
        scf.for %arg3 = %c0 to %c4 step %c1 {
          %c2_1 = arith.constant 2 : index
          %1 = arith.shli %arg1, %c2_1 : index
          %2 = arith.addi %1, %arg3 : index
          %3 = memref.load %arg0[%2] : memref<12xi32>
          %c1_2 = arith.constant 1 : index
          %4 = arith.shli %arg3, %c1_2 : index
          %5 = arith.addi %4, %arg2 : index
          %6 = memref.load %alloc[%5] : memref<8xi32>
          %7 = arith.muli %3, %6 : i32
        }
      }
    }
    return
  }
}
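
The fixed index arithmetic above follows the usual row-major pattern, which can be sketched for any rank as below. This is an illustrative Python model under my own assumptions, not the code in the pass: `flatten_index` is a hypothetical helper, and the power-of-two branch only mirrors the `arith.shli` seen in the fixed output.

```python
from itertools import product


def flatten_index(indices, shape):
    """Row-major linearization of an N-d index, computed Horner-style."""
    assert len(indices) == len(shape)
    flat = 0
    for idx, dim in zip(indices, shape):
        assert dim > 0 and 0 <= idx < dim
        if dim & (dim - 1) == 0:
            # Power-of-two dimension: scale by shifting left by log2(dim).
            flat = (flat << (dim.bit_length() - 1)) + idx
        else:
            flat = flat * dim + idx
    return flat


# Walking memref<3x4> with the last index varying fastest should visit
# flat addresses 0..11 in order.
shape = (3, 4)
visited = [flatten_index(ix, shape) for ix in product(*(range(d) for d in shape))]
assert visited == list(range(12))
```

For the second operand, `flatten_index((a3, a2), (4, 2))` reduces to `(a3 << 1) + a2`, matching the `%4`/`%5` pair in the fixed listing.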

Member

@dobios dobios left a comment

Your explanation seems solid, and the implementation is fine, but I don't know enough about this pass to guarantee that this won't break anything. Maybe @mortbopet can comment on this?

Contributor

@mortbopet mortbopet left a comment

Sorry for the late reply; LGTM... this would really benefit from some unit tests to sanity check things! Feel free to merge, but you get extra ⭐'s if you add some form of integration test :)

@jiahanxie353 jiahanxie353 merged commit 31e4f9e into llvm:main Oct 31, 2024
4 checks passed
@jiahanxie353 jiahanxie353 deleted the flatten-load branch October 31, 2024 00:42