
[mlir][Memref] Add memref-merge optimization #44


Merged: 28 commits into main on Sep 9, 2024

Conversation

@Menooker commented May 8, 2024

No description provided.

@Menooker Menooker requested review from ciyongch and ZhennanQin May 8, 2024 08:43
@Menooker Menooker requested a review from Devjiu May 8, 2024 08:44
@Menooker (Author) commented May 8, 2024

@Devjiu In this PR, I have added a unit test in lit. It requires LLVM to be built with LLVM_INSTALL_GTEST=ON.

@Menooker mentioned this pull request May 8, 2024
@Devjiu Devjiu requested a review from kurapov-peter May 8, 2024 11:39
@Devjiu (Contributor) commented May 8, 2024

> @Devjiu In this PR, I have added a unit test in lit. It requires LLVM to be built with LLVM_INSTALL_GTEST=ON.

I like the approach, but the PR is quite big. To change the LLVM build options, you need to update .github/workflows/build-llvm.yml, as is done in #39.
Also, in the doc you are sharing the result of bufferization; it's one-shot bufferization, I guess.
Maybe it's better to have a test that applies bufferization to your original MLP example:

func.func @mlp(%x: tensor<128x128xf32>, %y: tensor<128x128xf32>) -> tensor<128x128xf32> {
   %a0 = tensor.empty() : tensor<128x128xf32>
   %a = linalg.matmul ins(%x, %y: tensor<128x128xf32>, tensor<128x128xf32>) outs(%a0: tensor<128x128xf32>) -> tensor<128x128xf32>
   %b0 = tensor.empty() : tensor<128x128xf32>
   %b = linalg.matmul ins(%a, %y: tensor<128x128xf32>, tensor<128x128xf32>) outs(%b0: tensor<128x128xf32>) -> tensor<128x128xf32>
   %c0 = tensor.empty() : tensor<128x128xf32>
   %c = linalg.matmul ins(%b, %y: tensor<128x128xf32>, tensor<128x128xf32>) outs(%c0: tensor<128x128xf32>) -> tensor<128x128xf32>
   %d0 = tensor.empty() : tensor<128x128xf32>
   %d = linalg.matmul ins(%c, %y: tensor<128x128xf32>, tensor<128x128xf32>) outs(%d0: tensor<128x128xf32>) -> tensor<128x128xf32>
   return %d : tensor<128x128xf32>
}

In my case, ./build/src/gc-opt --one-shot-bufferize test/gc-dialects/Linlagx/mlp.mlir produces:

module {
  func.func @mlp(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>) -> tensor<128x128xf32> {
    %0 = bufferization.to_memref %arg1 : memref<128x128xf32, strided<[?, ?], offset: ?>>
    %1 = bufferization.to_memref %arg1 : memref<128x128xf32, strided<[?, ?], offset: ?>>
    %2 = bufferization.to_memref %arg1 : memref<128x128xf32, strided<[?, ?], offset: ?>>
    %3 = bufferization.to_memref %arg1 : memref<128x128xf32, strided<[?, ?], offset: ?>>
    %4 = bufferization.to_memref %arg0 : memref<128x128xf32, strided<[?, ?], offset: ?>>
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<128x128xf32>
    linalg.matmul ins(%4, %3 : memref<128x128xf32, strided<[?, ?], offset: ?>>, memref<128x128xf32, strided<[?, ?], offset: ?>>) outs(%alloc : memref<128x128xf32>)
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<128x128xf32>
    linalg.matmul ins(%alloc, %2 : memref<128x128xf32>, memref<128x128xf32, strided<[?, ?], offset: ?>>) outs(%alloc_0 : memref<128x128xf32>)
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<128x128xf32>
    linalg.matmul ins(%alloc_0, %1 : memref<128x128xf32>, memref<128x128xf32, strided<[?, ?], offset: ?>>) outs(%alloc_1 : memref<128x128xf32>)
    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<128x128xf32>
    linalg.matmul ins(%alloc_1, %0 : memref<128x128xf32>, memref<128x128xf32, strided<[?, ?], offset: ?>>) outs(%alloc_2 : memref<128x128xf32>)
    %5 = bufferization.to_tensor %alloc_2 : memref<128x128xf32>
    return %5 : tensor<128x128xf32>
  }
}

Also, MLIR has the methods getUses() and getUsers(), which can be used to compute live ranges in the same way you do with ticks.
MLIR also has a liveness analysis (https://mlir.llvm.org/doxygen/Liveness_8h_source.html); can we reuse it? Your approach is fine, but have you tried using the existing MLIR tools?
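For reference, a minimal sketch (not from this PR; it assumes the code runs on a func.func inside a pass) of how those existing utilities expose live-range information:

```cpp
// Sketch only: enumerate the users of each SSA result and query MLIR's
// Liveness analysis for the operation that ends the value's live range.
#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "llvm/Support/raw_ostream.h"

void dumpLiveRanges(mlir::func::FuncOp func) {
  mlir::Liveness liveness(func);
  func.walk([&](mlir::Operation *op) {
    for (mlir::Value result : op->getResults()) {
      // getUsers() walks all consumers of the value...
      for (mlir::Operation *user : result.getUsers())
        llvm::errs() << "user: " << user->getName().getStringRef() << "\n";
      // ...while Liveness reports where the value's live range ends.
      if (mlir::Operation *end = liveness.getEndOperation(result, op))
        llvm::errs() << "live range ends at: "
                     << end->getName().getStringRef() << "\n";
    }
  });
}
```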

@kurapov-peter (Contributor) left a comment


A couple of things: I'd suggest creating a similar RFC upstream to get initial feedback. As to the PR itself, I'd split the testing from the pass implementation, just for the sake of review.

@Menooker (Author) commented May 9, 2024

> MLIR also has a liveness analysis

Thanks for pointing that out. I had a brief look into it, and I find that MLIR's liveness analysis does not meet our requirements:

  • We need special handling for complex control flow, e.g., to make sure the buffers used in a for-loop have overlapping liveness; MLIR's liveness analysis simply checks the Block-Value-Operation relations.
  • Our compile-time memory allocator needs a serialized stream of alloc/dealloc "events", so we would need to serialize the liveness info from the analysis anyway.
  • We need to track memref.view ops on memref.alloc.

I believe we could indeed build our analysis on top of MLIR's liveness analysis, but I doubt it would reduce the code complexity, since we have some non-standard requirements on it.
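To make the tick idea concrete, here is a minimal hypothetical sketch (simplified; not this PR's actual code) of collecting first/last-use ticks for memref.alloc results in a single walk. The loop-aware widening of live ranges described in the first bullet above is omitted:

```cpp
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/DenseMap.h"

struct TickRange {
  int64_t first = -1, last = -1;
};

// Assign a monotonically increasing "tick" to each operation and record
// the tick range over which each allocation is referenced.
llvm::DenseMap<mlir::Operation *, TickRange>
collectTicks(mlir::Operation *root) {
  llvm::DenseMap<mlir::Operation *, TickRange> ranges;
  int64_t tick = 0;
  root->walk([&](mlir::Operation *op) {
    for (mlir::Value operand : op->getOperands()) {
      auto alloc = llvm::dyn_cast_or_null<mlir::memref::AllocOp>(
          operand.getDefiningOp());
      if (!alloc)
        continue;
      TickRange &r = ranges[alloc.getOperation()];
      if (r.first < 0)
        r.first = tick; // first use of the buffer
      r.last = tick;    // last use seen so far
    }
    ++tick;
  });
  return ranges;
}
```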

@Menooker (Author) commented May 9, 2024

> As to the PR itself, I'd split the testing from the pass implementation

Did you mean splitting the "adding our first unittest to lit" part from this PR, or splitting all the testing code from the implementation?

@Menooker (Author) commented May 9, 2024

@Devjiu @kurapov-peter I have opened a new PR #50 for updating the build system to enable unit tests in MLIR-related code. Shall we move the discussion on testing to that PR?

In that PR, I have updated .github/workflows/build-llvm.yml. Thanks for the suggestion!

@Menooker Menooker changed the base branch from main to yijie/unittest May 9, 2024 02:34
@Menooker (Author) commented May 9, 2024

> Maybe it's better to have a test that applies bufferization to your original MLP example.

That's good advice! I had missed that part of the testing, and after adding it, it did expose a bug. Thanks!

I have now added a .mlir test case for the tensor -> bufferization -> merge-alloc pipeline. It now works as expected.

Mei, Yijie added 2 commits May 9, 2024 15:19

FailureOr<size_t> getAllocSize(Operation *op) {
  auto refType = op->getResultTypes().front().cast<MemRefType>();
  int64_t size = refType.getElementTypeBitWidth() / 8;
A Contributor commented:

Need to check whether it works for i1 as it's a boolean type.
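A possible fix (a sketch under the assumption of static shapes; not the PR's final code) is to round the element bit width up to whole bytes, so that an i1 element does not truncate to zero:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"

mlir::FailureOr<size_t> getAllocSize(mlir::Operation *op) {
  auto refType = mlir::cast<mlir::MemRefType>(op->getResultTypes().front());
  // divideCeil(1, 8) == 1, so i1 elements occupy one byte instead of zero.
  int64_t size = llvm::divideCeil(refType.getElementTypeBitWidth(), 8);
  for (int64_t dim : refType.getShape())
    size *= dim; // assumes static shapes; dynamic dims would need a failure path
  return static_cast<size_t>(size);
}
```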

@Devjiu (Contributor) commented May 10, 2024

> MLIR also has a liveness analysis
>
> Thanks for pointing that out. I had a brief look into it, and I find that MLIR's liveness analysis does not meet our requirements:
>
>   • We need special handling for complex control flow, e.g., to make sure the buffers used in a for-loop have overlapping liveness; MLIR's liveness analysis simply checks the Block-Value-Operation relations.
>   • Our compile-time memory allocator needs a serialized stream of alloc/dealloc "events", so we would need to serialize the liveness info from the analysis anyway.
>   • We need to track memref.view ops on memref.alloc.
>
> I believe we could indeed build our analysis on top of MLIR's liveness analysis, but I doubt it would reduce the code complexity, since we have some non-standard requirements on it.

Can we discuss this approach a little bit more?
As I told Petr, this pass looks generic enough to be upstreamed.

E.g., in bufferization there is a pass "OwnershipBasedBufferDeallocation" ("ownership-based-buffer-deallocation"), and in its implementation:

FailureOr<Operation *>
BufferDeallocation::handleInterface(RegionBranchOpInterface op) {
  OpBuilder builder = OpBuilder::atBlockBegin(op->getBlock());

  // TODO: the RegionBranchOpInterface does not provide all the necessary
  // methods to perform this transformation without additional assumptions on
  // the structure. In particular, that
  // * additional values to be passed to the next region can be added to the end
  //   of the operand list, the end of the block argument list, and the end of
  //   the result value list. However, it seems to be the general guideline for
  //   operations implementing this interface to follow this structure.
  // * and that the block arguments and result values match the forwarded
  //   operands one-to-one (i.e., that there are no other values appended to the
  //   front).
  // These assumptions are satisfied by the `scf.if`, `scf.for`, and `scf.while`
  // operations.

  SmallVector<RegionSuccessor> regions;
  op.getSuccessorRegions(RegionBranchPoint::parent(), regions);
  assert(!regions.empty() && "Must have at least one successor region");
  SmallVector<Value> entryOperands(
      op.getEntrySuccessorOperands(regions.front()));
  unsigned numMemrefOperands = llvm::count_if(entryOperands, isMemref);

  // No ownership is acquired for any MemRefs that are passed to the region from
  // the outside.
  Value falseVal = buildBoolValue(builder, op.getLoc(), false);
  op->insertOperands(op->getNumOperands(),
                     SmallVector<Value>(numMemrefOperands, falseVal));

  int counter = op->getNumResults();
  unsigned numMemrefResults = llvm::count_if(op->getResults(), isMemref);
  SmallVector<Type> ownershipResults(numMemrefResults, builder.getI1Type());
  RegionBranchOpInterface newOp = appendOpResults(op, ownershipResults);

  for (auto result : llvm::make_filter_range(newOp->getResults(), isMemref)) {
    state.updateOwnership(result, newOp->getResult(counter++));
    state.addMemrefToDeallocate(result, newOp->getBlock());
  }

  return newOp.getOperation();
}

I understand that you already have a working implementation, but as we expect most of our code to be upstreamed, I think we should try to adapt the code to MLIR approaches, e.g., use side effects and getDefiningOp.

BTW, I don't fully understand the purpose of the "test.source" op?

@Menooker (Author) commented:

> Can we discuss this approach a little bit more?
> As I told Petr, this pass looks generic enough to be upstreamed.

Sure. We can discuss it directly on Teams if you wish. :)

> E.g., in bufferization there is a pass "OwnershipBasedBufferDeallocation" ("ownership-based-buffer-deallocation"), and in its implementation:

Sorry, I don't quite understand how it relates to our pass: "OwnershipBasedBufferDeallocation" doesn't need liveness analysis (at least in the code you posted). I totally agree that we should use the components already in MLIR; let's find out which ones we can reuse in our pass.

> BTW, I don't fully understand the purpose of the "test.source" op?

It is a generic operation that references a memref. In MLIR, you can write IR with unregistered dialects, and passes can still work on these unregistered operations as long as they don't need the operations' semantics. You can view the "test.source" op as standing in for any operation that references a memref, like memref.load, linalg.matmul, etc.
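For illustration, a minimal sketch (hypothetical test code, not from this PR) of how MLIR accepts such unregistered ops when the context allows them:

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"

int main() {
  mlir::MLIRContext ctx;
  ctx.loadDialect<mlir::func::FuncDialect>();
  // Accept ops from dialects that were never registered, e.g. "test.source".
  ctx.allowUnregisteredDialects();
  const char *ir = R"mlir(
    func.func @f(%buf: memref<8xf32>) {
      "test.source"(%buf) : (memref<8xf32>) -> ()
      return
    }
  )mlir";
  auto module = mlir::parseSourceString<mlir::ModuleOp>(ir, &ctx);
  return module ? 0 : 1;
}
```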

@Menooker (Author) commented:

I have submitted an RFC to the community: https://discourse.llvm.org/t/rfc-compile-time-memref-alloc-scheduling-merging-optimization/78872

@ciyongch (Contributor) commented:

> Can we discuss this approach a little bit more? As I told Petr, this pass looks generic enough to be upstreamed.
>
> E.g., in bufferization there is a pass "OwnershipBasedBufferDeallocation" ("ownership-based-buffer-deallocation"), and in its implementation:

Hi @Devjiu, this pass is designed to handle memref and thus runs after the bufferization pass; the bufferization pass itself doesn't do this kind of memory optimization, and the community is still in an active transition from conversion-based bufferization to one-shot bufferization (DPS style). The current algorithm is only verified on the CPU platform in Graph Compiler v1; it might require broader coverage for upstream. I would suggest we move on with this implementation and keep track of upstream feedback to meet our internal timeline.

@Menooker Menooker changed the base branch from yijie/unittest to main May 15, 2024 06:42
@ciyongch (Contributor) commented Jul 3, 2024

We might need an issue to track this as well.

@Menooker Menooker linked an issue Aug 15, 2024 that may be closed by this pull request
@Menooker (Author) commented:

Updated the code. Ready for review.

};

// the memory chunk that is split from another chunk
struct split_chunk_t : public MemoryChunk {
A Contributor commented:

Please use the same coding style for split_chunk_t

The author replied:

Sorry, I missed that one. There were too many of them during the code-style changes. :)

split_chunk_t(size_t size, MemoryChunk *parent, bool is_lhs)
    : MemoryChunk{ChunkType::SPLIT, size}, parent(parent), is_lhs_(is_lhs) {}
void move(int64_t startDiff) override {
  if (is_lhs_) {
A Contributor commented:

What about the case where a single memory chunk is split into more than 2 pieces? How should lhs and rhs be interpreted then?

The author replied:

Currently, we only support splitting into 2 pieces.
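To illustrate how the 2-piece restriction can still represent finer partitions, here is a hypothetical sketch (with simplified stand-in types, not the PR's real definitions) of expressing a 3-way split as two nested binary splits:

```cpp
#include <cstddef>

// Simplified stand-ins for the pass's chunk types.
struct MemoryChunk {
  size_t size;
};
struct SplitChunk : MemoryChunk {
  MemoryChunk *parent;
  bool isLhs; // true: left piece of the parent, false: right piece
  SplitChunk(size_t size, MemoryChunk *parent, bool isLhs)
      : MemoryChunk{size}, parent(parent), isLhs(isLhs) {}
};

int main() {
  // Split 96 bytes into 32 + 32 + 32: whole -> (left, rest), rest -> (mid, right).
  MemoryChunk whole{96};
  SplitChunk left(32, &whole, /*isLhs=*/true);
  SplitChunk rest(64, &whole, /*isLhs=*/false);
  SplitChunk mid(32, &rest, /*isLhs=*/true);
  SplitChunk right(32, &rest, /*isLhs=*/false);
  (void)left; (void)mid; (void)right;
  return 0;
}
```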

@ciyongch (Contributor) commented:

To address the issue of #269.

@ciyongch (Contributor) commented:

Please rebase the PR, and let's have a test with MLP to see the perf gain.

@ciyongch (Contributor) commented Sep 3, 2024

@yifeizh2 can you help pull this patch into the MLP pattern to see if there is any performance gain?

@yifeizh2 (Contributor) commented Sep 3, 2024

> @yifeizh2 can you help pull this patch into the MLP pattern to see if there is any performance gain?

It shows a performance gain in most cases. I will double-check the IR later to confirm.

| dtype | batch size | hidden list           | w/o mem merge | w/ mem merge | Ratio   |
| ----- | ---------- | --------------------- | ------------- | ------------ | ------- |
| bf16  | 128        | 16x512                |               |              |         |
| bf16  | 128        | 512x256               | 0.0527        | 0.0514       | 102.50% |
| bf16  | 128        | 256x128               | 0.0354        | 0.0384       | 92.16%  |
| bf16  | 128        | 512x1024              | 0.0501        | 0.0362       | 138.19% |
| bf16  | 128        | 1024x1024             | 0.1004        | 0.0899       | 111.70% |
| bf16  | 128        | 1024x512              | 0.1041        | 0.0950       | 109.53% |
| bf16  | 128        | 512x256               | 0.0519        | 0.0518       | 100.30% |
| bf16  | 128        | 16x512x256x128        |               |              |         |
| bf16  | 128        | 512x1024x1024x512x256 | 0.3246        | 0.3024       | 107.33% |

@ciyongch (Contributor) commented Sep 4, 2024

@kurapov-peter @ZhennanQin @zhczhong @Yun-Fly Please help to review and approve if no other comments :)

size_t scheduleMemoryAllocations(
    const Traces &traces, std::size_t alignment, bool hotFirst,
    const InplaceInfoMap &inplaceMap,
    std::unordered_map<uintptr_t, std::size_t> &outSchedule,
A Contributor commented:

One open question here: is it possible to estimate the allocation size during the fusion stage, just like what we do in v1?

The author replied:

Yes. The internals of the pass are already described in the doc. You can skip the IR mutation stage and get just the memory scheduling results.
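As a toy illustration of that idea (not the PR's actual scheduler, whose signature is shown above), a consumer of the serialized alloc/dealloc trace could compute static offsets without touching the IR:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>

enum class EventKind { Alloc, Dealloc };
struct Event {
  EventKind kind;
  uintptr_t bufferId;
  size_t size; // meaningful for Alloc events only
};

// Naive bump scheduler: returns the total arena size and fills
// bufferId -> offset. A real scheduler would also reuse the ranges
// released by Dealloc events to shrink the arena.
size_t scheduleToy(const std::vector<Event> &trace,
                   std::unordered_map<uintptr_t, size_t> &outOffsets) {
  size_t top = 0, peak = 0;
  for (const Event &e : trace) {
    if (e.kind == EventKind::Alloc) {
      outOffsets[e.bufferId] = top;
      top += e.size;
      peak = std::max(peak, top);
    }
  }
  return peak;
}
```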

A Contributor commented:

After an offline sync: this pass expects the memref dialect rather than the tensor dialect, so it would need more changes to be applied during the fusion stage.

@ciyongch merged commit b34382e into main on Sep 9, 2024
6 checks passed
Successfully merging this pull request may close these issues:

memref-merge optimization: Add memory scheduling