Skip to content

Conversation

erick-xanadu
Copy link
Contributor

@erick-xanadu erick-xanadu commented Apr 24, 2025

Context:
This work is based on #1027

As part of the mlir update, the bufferization of the custom catalyst dialects need to migrate to the new one-shot bufferization interface, as opposed to the old pattern-rewrite style bufferization passes. See more context in #1027.

As an example, here is how the new bufferization interface is used for mlir's core arith dialect:
https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h
https://github.com/llvm/llvm-project/blob/7ee0097b486b31be8b9a1750b2cd47580efd9587/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp#L54

Description of the Change:
On the current mlir commit we track, both the old and new bufferization styles exist. The old pattern rewrite style is deprecated.

To ease the workflow organization, we migrate one dialect at a time. This PR migrates the Quantum dialect's bufferization to the new one-shot interface.

Note that the new one-shot interface is supposed to be called only once in the pipeline. However, because we haven't migrated all the dialects yet, we simply swap out the old --quantum--bufferize pass in-place, with the new one-shot bufferization pass running on the quantum dialect only.

Benefits:
Align with mlir practices; one step closer to updating mlir.

[sc-71487]

@paul0403 paul0403 changed the title Eochoa/2025 04 16/updating to0jaxv0.4.29.0 Migrate to new one-shot bufferization in mlir Apr 28, 2025
@paul0403 paul0403 changed the title Migrate to new one-shot bufferization in mlir [WIP] Migrate to new one-shot bufferization in mlir Apr 28, 2025
Copy link
Contributor Author

@erick-xanadu erick-xanadu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is awesome Paul!

Copy link
Contributor

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work 💯 It was a good idea to chunk it up into different PRs, much easier to review that way :)

One last thing, can we run the new pathway on some large-ish example code and see if the MLIR changes meaningful (e.g. more copies or allocations than before)? Alternatively we also could patch in some quick instrumentation into the runtime alloc/free functions to check if their call stats have changed in any way.

We won't have to do this on all dialect PRs, but at least once would be good for us to understand potential compilation differences :)

@paul0403
Copy link
Member

paul0403 commented Apr 30, 2025

can we run the new pathway on some large-ish example code and see if the MLIR changes meaningful

Yes, I plan to run the benchmark suite after these big-ish version updates. Right now there's still no support to request a benchmark run through catalyst PR comment section directly yet, so I've just been giving @jzaia18 the commit I want to run manually (I was doing this for the jax update already). I can also quickly patch the allocate/free runtime stubs just to log some counts.

I expect there won't be any major differences, since we still generate the same mlir. However, there indeed could be memory/runtime differences inside the bufferization pass itself. We will have to see what happens with the benchmarks!

@dime10
Copy link
Contributor

dime10 commented May 1, 2025

Yes, I plan to run the benchmark suite after these big-ish version updates.

Benchmark suite is one thing, but it only spit out numbers for the overall execution, it might not help us understand what differences are exactly (if any) in how programs are bufferized.

@dime10
Copy link
Contributor

dime10 commented May 1, 2025

Yes, I plan to run the benchmark suite after these big-ish version updates.

Benchmark suite is one thing, but it only spit out numbers for the overall execution, it might not help us understand what differences are exactly (if any) in how programs are bufferized.

Actually, looking at the ops that we're bufferizing here there really shouldn't be any difference, since already always allocate new buffers for the measurement ops and only read in the matrix ops, so we should be good :) We should revisit this though once we switch over the func bufferization and the buffer deallocation.

paul0403 and others added 4 commits May 1, 2025 14:49
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
These involve replacing tensor results with new memref allocations.
Copy link
Contributor

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great 🍓

@paul0403 paul0403 merged commit e06eba4 into main May 1, 2025
43 checks passed
@paul0403 paul0403 deleted the eochoa/2025-04-16/updating-to0jaxv0.4.29.0 branch May 1, 2025 20:23
paul0403 added a commit that referenced this pull request May 12, 2025
**Context:**
This work is based on #1027 .

As part of the mlir update, the bufferization of the custom catalyst
dialects need to migrate to the new one-shot bufferization interface, as
opposed to the old pattern-rewrite style bufferization passes.
See more context in #1027.

The `Quantum` dialect was migrated in #1686 .

**Description of the Change:**
MIgrate `Catalyst` dialect to one-shot bufferization.

**Benefits:**
Align with mlir practices; one step closer to updating mlir.

[sc-71487]

---------

Co-authored-by: Tzung-Han Juang <tzunghan.juang@gmail.com>
paul0403 added a commit that referenced this pull request May 23, 2025
**Context:**
This work is based on #1027 .

As part of the mlir update, the bufferization of the custom catalyst
dialects need to migrate to the new one-shot bufferization interface, as
opposed to the old pattern-rewrite style bufferization passes.
See more context in #1027.

The `Quantum` dialect was migrated in #1686 .
The `Catalyst` dialect was migrated in #1708 .

Note that #1139 refactors the gradient dialect's bufferization into
preprocess, bufferization, and postprocess.
Only the middle bufferization stage is supposed to be replaced by
one-shot bufferization.

**Description of the Change:**
Migrate `Gradient` dialect to one-shot bufferization.

**Benefits:**
Align with mlir practices; one step closer to updating mlir.

[sc-71487]

---------

Co-authored-by: Tzung-Han Juang <tzunghan.juang@gmail.com>
paul0403 added a commit that referenced this pull request May 26, 2025
…1751)

**Context:**
This work is based on #1027.

Now that we have migrated all the individual dialects, we should migrate
the entire bufferization pipeline.

The `Quantum` dialect was migrated in
#1686 .
The `Catalyst` dialect was migrated in
#1708 .
The `Gradient` dialect was migrated in
#1740 .

See more context in #1027. 

Upstream changes in llvm were required for this bufferization update. As
a result, the llvm version and mlir-hlo version were updated to
```
mhlo=25b008569f413d76cfa8f481f3a84e82b89c47f4
llvm=5f74671c85877e03622e8d308aee15ed73ccee7c
```
These are the versions tracked by jax 0.4.32. 
These are the earliest jax-tagged versions with complete upstream
bufferization changes.

**Related GitHub Issues:**
[sc-71487]

---------

Co-authored-by: Tzung-Han Juang <tzunghan.juang@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants