Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add merge rotation pass #1162

Merged
merged 60 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from 49 commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
0d85eaf
Add pattern
rmoyard Sep 30, 2024
248ab5e
Update structure
rmoyard Sep 30, 2024
9ccb01a
Update
rmoyard Oct 3, 2024
3e5d1d1
Working draft
rmoyard Oct 3, 2024
c4f46ba
Update
rmoyard Oct 3, 2024
c9d56b6
Merge branch 'main' into merge_rotations
rmoyard Oct 7, 2024
7286abc
renamed to `ChainedNamedHermitianOpRewritePattern`
paul0403 Oct 7, 2024
e1f54e5
Add test
rmoyard Oct 7, 2024
4f3f9e2
Merge branch 'merge_rotations' of https://github.com/PennyLaneAI/cata…
rmoyard Oct 7, 2024
043a5f0
MLIR test: CRY switch qubits
rmoyard Oct 7, 2024
1f87943
add pattern
paul0403 Oct 7, 2024
1458937
preprocess with cse pass so we can check param SSA values;
paul0403 Oct 7, 2024
b8b9913
tests
paul0403 Oct 7, 2024
43fcce7
format
paul0403 Oct 7, 2024
799d96e
test with explicit rotation angles
paul0403 Oct 7, 2024
b60519e
test with different explicit params
paul0403 Oct 7, 2024
84c758c
cano test
rmoyard Oct 7, 2024
e3447ef
Update
rmoyard Oct 7, 2024
a43d9ac
changelog
paul0403 Oct 7, 2024
05f5a5b
Initial draft multiRZ
rmoyard Oct 7, 2024
bc5faaf
Typo
rmoyard Oct 7, 2024
485c6d6
ctrl gates
paul0403 Oct 8, 2024
9dcfdc6
Merge remote-tracking branch 'origin/main' into cancel_inverse_adjoint
paul0403 Oct 8, 2024
90ddcf1
remove template type in parent getter (a value will have just one def…
paul0403 Oct 8, 2024
b9acfe6
factor out a parent gate verifier analysis, so it can be reused with …
paul0403 Oct 8, 2024
13eb094
add all test cases for ctrl
paul0403 Oct 9, 2024
91beb68
Merge remote-tracking branch 'origin/main' into cancel_inverse_adjoint
paul0403 Oct 9, 2024
ebb5169
make the named hermitian pattern use the common analysis as well
paul0403 Oct 9, 2024
5c77fb3
one more test
paul0403 Oct 9, 2024
0823116
Merge remote-tracking branch 'origin/main' into cancel_inverse_adjoint
paul0403 Oct 9, 2024
9ae9ba7
follow include order guideline
paul0403 Oct 9, 2024
738c96a
`verified` --> `succeeded`
paul0403 Oct 9, 2024
e78cca1
move namecheck before wire verification
paul0403 Oct 10, 2024
93b1ed1
Merge remote-tracking branch 'origin/cancel_inverse_adjoint' into mer…
rmoyard Oct 10, 2024
6324fd0
Add analysis integration
rmoyard Oct 10, 2024
a6b8424
MultiRz case
rmoyard Oct 10, 2024
b24703a
Split verifier into a "normal" one and an aggressive one.
paul0403 Oct 10, 2024
2f95838
use aggressive for named gates
paul0403 Oct 10, 2024
d092540
Merge remote-tracking branch 'origin/main' into cancel_inverse_adjoint
paul0403 Oct 10, 2024
9bc658b
add multirz
paul0403 Oct 10, 2024
1fe2622
Merge remote-tracking branch 'origin/main' into cancel_inverse_adjoint
paul0403 Oct 10, 2024
552fae2
Merge remote-tracking branch 'origin/main' into cancel_inverse_adjoint
paul0403 Oct 11, 2024
9a7800a
Merge branch 'cancel_inverse_adjoint' into merge_rotations
rmoyard Oct 11, 2024
0660b68
changelog
paul0403 Oct 11, 2024
a7cb5af
change aggressive name to VerifyParentGateAndNameAnalysis
paul0403 Oct 11, 2024
b54fcb7
Merge branch 'cancel_inverse_adjoint' into merge_rotations
rmoyard Oct 11, 2024
02dc927
changelog grammar
paul0403 Oct 11, 2024
7ba892f
Add multirz test
rmoyard Oct 11, 2024
4fb926f
Merge branch 'cancel_inverse_adjoint' into merge_rotations
rmoyard Oct 11, 2024
4ae4281
Update doc
rmoyard Oct 11, 2024
9cee8c9
Merge branch 'merge_rotations' of https://github.com/PennyLaneAI/cata…
rmoyard Oct 11, 2024
1b970da
Update mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp
rmoyard Oct 11, 2024
ca6df1b
Update
rmoyard Oct 11, 2024
f394838
Merge branch 'merge_rotations' of https://github.com/PennyLaneAI/cata…
rmoyard Oct 11, 2024
090d285
Merge branch 'main' into merge_rotations
rmoyard Oct 11, 2024
d0789f0
Merge branch 'merge_rotations' of https://github.com/PennyLaneAI/cata…
rmoyard Oct 11, 2024
bd0ddb6
Update
rmoyard Oct 11, 2024
01cd3f5
Remove erase
rmoyard Oct 11, 2024
def5b3a
Update
rmoyard Oct 11, 2024
0213ae7
Pylint
rmoyard Oct 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@
* Samples on lightning.qubit/kokkos can now be seeded with `qjit(seed=...)`.
[(#1164)](https://github.com/PennyLaneAI/catalyst/pull/1164)

* The compiler pass `-remove-chained-self-inverse` can now also cancel adjoints of arbitrary unitary operations (in addition to the named Hermitian gates).
[(#1186)](https://github.com/PennyLaneAI/catalyst/pull/1186)

<h3>Breaking changes</h3>

Expand Down
2 changes: 1 addition & 1 deletion mlir/include/Quantum/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ std::unique_ptr<mlir::Pass> createAdjointLoweringPass();
std::unique_ptr<mlir::Pass> createRemoveChainedSelfInversePass();
std::unique_ptr<mlir::Pass> createAnnotateFunctionPass();
std::unique_ptr<mlir::Pass> createSplitMultipleTapesPass();
std::unique_ptr<mlir::Pass> createMergeRotationPass();
std::unique_ptr<mlir::Pass> createMergeRotationsPass();
rmoyard marked this conversation as resolved.
Show resolved Hide resolved

} // namespace catalyst
4 changes: 2 additions & 2 deletions mlir/include/Quantum/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ def RemoveChainedSelfInversePass : Pass<"remove-chained-self-inverse"> {
let options = QuantumCircuitTransformationPass.options;
}

def MergeRotationPass : Pass<"merge-rotation"> {
def MergeRotationsPass : Pass<"merge-rotations"> {
let summary = "merge rotation boilerplate words";

let constructor = "catalyst::createMergeRotationPass()";
rmoyard marked this conversation as resolved.
Show resolved Hide resolved
let constructor = "catalyst::createMergeRotationsPass()";
let options = !listconcat(
QuantumCircuitTransformationPass.options,
[
Expand Down
1 change: 1 addition & 0 deletions mlir/include/Quantum/Transforms/Patterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ void populateBufferizationPatterns(mlir::TypeConverter &, mlir::RewritePatternSe
void populateQIRConversionPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &);
void populateAdjointPatterns(mlir::RewritePatternSet &);
void populateSelfInversePatterns(mlir::RewritePatternSet &);
void populateMergeRotationsPatterns(mlir::RewritePatternSet &);

} // namespace quantum
} // namespace catalyst
3 changes: 2 additions & 1 deletion mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,14 @@ void catalyst::registerAllCatalystPasses()
mlir::registerPass(catalyst::createInlineNestedModulePass);
mlir::registerPass(catalyst::createMemrefCopyToLinalgCopyPass);
mlir::registerPass(catalyst::createMemrefToLLVMWithTBAAPass);
mlir::registerPass(catalyst::createMergeRotationPass);
mlir::registerPass(catalyst::createMergeRotationsPass);
mlir::registerPass(catalyst::createMitigationLoweringPass);
mlir::registerPass(catalyst::createQnodeToAsyncLoweringPass);
mlir::registerPass(catalyst::createQuantumBufferizationPass);
mlir::registerPass(catalyst::createQuantumConversionPass);
mlir::registerPass(catalyst::createRegisterInactiveCallbackPass);
mlir::registerPass(catalyst::createRemoveChainedSelfInversePass);
mlir::registerPass(catalyst::createMergeRotationsPass);
rmoyard marked this conversation as resolved.
Show resolved Hide resolved
mlir::registerPass(catalyst::createScatterLoweringPass);
mlir::registerPass(catalyst::createSplitMultipleTapesPass);
mlir::registerPass(catalyst::createTestPass);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Quantum/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ file(GLOB SRC
remove_chained_self_inverse.cpp
SplitMultipleTapes.cpp
merge_rotation.cpp
MergeRotationsPatterns.cpp
)

get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
Expand Down
106 changes: 93 additions & 13 deletions mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@

#define DEBUG_TYPE "chained-self-inverse"

#include "Quantum/IR/QuantumOps.h"
#include "Quantum/Transforms/Patterns.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Errc.h"

#include "Quantum/IR/QuantumOps.h"
#include "Quantum/Transforms/Patterns.h"

#include "VerifyParentGateAnalysis.hpp"

using llvm::dbgs;
using namespace mlir;
using namespace catalyst;
Expand All @@ -29,7 +33,7 @@ static const mlir::StringSet<> HermitianOps = {"Hadamard", "PauliX", "PauliY", "

namespace {

struct ChainedHadamardOpRewritePattern : public mlir::OpRewritePattern<CustomOp> {
struct ChainedNamedHermitianOpRewritePattern : public mlir::OpRewritePattern<CustomOp> {
using mlir::OpRewritePattern<CustomOp>::OpRewritePattern;

/// We simplify consecutive Hermitian quantum gates by removing them.
Expand All @@ -41,22 +45,90 @@ struct ChainedHadamardOpRewritePattern : public mlir::OpRewritePattern<CustomOp>
LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n");

StringRef OpGateName = op.getGateName();
if (!HermitianOps.contains(OpGateName))
if (!HermitianOps.contains(OpGateName)) {
return failure();
}

VerifyParentGateAndNameAnalysis<CustomOp> vpga(op);
if (!vpga.getVerifierResult()) {
return failure();
}

// Replace uses
ValueRange InQubits = op.getInQubits();
auto ParentOp = dyn_cast_or_null<CustomOp>(InQubits[0].getDefiningOp());
if (!ParentOp || ParentOp.getGateName() != OpGateName)
ValueRange simplifiedVal = ParentOp.getInQubits();
rewriter.replaceOp(op, simplifiedVal);
return success();
}
};

template <typename OpType>
struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern<OpType> {
using mlir::OpRewritePattern<OpType>::OpRewritePattern;

bool verifyParentGateParams(OpType op, OpType parentOp) const
{
// Verify that the parent gate has the same parameters
ValueRange opParams = op.getAllParams();
ValueRange parentOpParams = parentOp.getAllParams();

if (opParams.size() != parentOpParams.size()) {
return false;
}

for (auto [opParam, parentOpParam] : llvm::zip(opParams, parentOpParams)) {
if (opParam != parentOpParam) {
return false;
}
}

return true;
}

bool verifyOneAdjoint(OpType op, OpType parentOp) const
{
// Verify that exactly one of the neighbouring pair is an adjoint
bool opIsAdj = op->hasAttr("adjoint");
bool parentIsAdj = parentOp->hasAttr("adjoint");
return opIsAdj != parentIsAdj; // "XOR" to check just one true
}

/// Remove generic neighbouring gate pairs of the form
/// --- gate --- gate{adjoint} ---
/// Conditions:
/// 1. Parent gate verification must pass. See VerifyParentGateAnalysis.hpp.
/// 2. If there are parameters, both gate must have the same parameters.
/// [This pattern assumes the IR is already processed by CSE]
mlir::LogicalResult matchAndRewrite(OpType op, mlir::PatternRewriter &rewriter) const override
{
LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n");

VerifyParentGateAndNameAnalysis<OpType> vpga(op);
if (!vpga.getVerifierResult()) {
return failure();
}

ValueRange InQubits = op.getInQubits();
auto parentOp = dyn_cast_or_null<OpType>(InQubits[0].getDefiningOp());

if (!verifyParentGateParams(op, parentOp)) {
return failure();
}

ValueRange ParentOutQubits = ParentOp.getOutQubits();
// Check if the input qubits to the current operation match the output qubits of the parent.
for (const auto &[Idx, Qubit] : llvm::enumerate(InQubits)) {
if (Qubit.getDefiningOp<CustomOp>() != ParentOp || Qubit != ParentOutQubits[Idx])
return failure();
if (!verifyOneAdjoint(op, parentOp)) {
return failure();
}

// Replace uses
ValueRange originalNonCtrlQubits = parentOp.getNonCtrlQubitOperands();
ValueRange originalCtrlQubits = parentOp.getCtrlQubitOperands();
for (const auto &[idx, nonCtrlQubitResult] : llvm::enumerate(op.getNonCtrlQubitResults())) {
nonCtrlQubitResult.replaceAllUsesWith(originalNonCtrlQubits[idx]);
}
for (const auto &[idx, ctrlQubitResult] : llvm::enumerate(op.getCtrlQubitResults())) {
ctrlQubitResult.replaceAllUsesWith(originalCtrlQubits[idx]);
}
ValueRange simplifiedVal = ParentOp.getInQubits();
rewriter.replaceOp(op, simplifiedVal);
return success();
}
};
Expand All @@ -68,7 +140,15 @@ namespace quantum {

void populateSelfInversePatterns(RewritePatternSet &patterns)
{
patterns.add<ChainedHadamardOpRewritePattern>(patterns.getContext(), 1);
patterns.add<ChainedNamedHermitianOpRewritePattern>(patterns.getContext(), 1);

// TODO: better organize the quantum dialect
// There is an interface `QuantumGate` for all the unitary gate operations,
// but interfaces cannot be accepted by pattern matchers, since pattern
// matchers require the target operations to have concrete names in the IR.
patterns.add<ChainedUUadjOpRewritePattern<CustomOp>>(patterns.getContext(), 1);
patterns.add<ChainedUUadjOpRewritePattern<QubitUnitaryOp>>(patterns.getContext(), 1);
patterns.add<ChainedUUadjOpRewritePattern<MultiRZOp>>(patterns.getContext(), 1);
}

} // namespace quantum
Expand Down
131 changes: 131 additions & 0 deletions mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
// Copyright 2024 Xanadu Quantum Technologies Inc.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#define DEBUG_TYPE "merge-rotations"

#include "Quantum/IR/QuantumOps.h"
#include "Quantum/Transforms/Patterns.h"
#include "VerifyParentGateAnalysis.hpp"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Errc.h"

using llvm::dbgs;
using namespace mlir;
using namespace catalyst::quantum;

static const mlir::StringSet<> rotationsSet = {"RX", "RY", "RZ", "PhaseShift", "Rot",
"CRX", "CRY", "CRZ", "ControlledPhaseShift", "CRot"};

namespace {

struct MergeRotationsRewritePattern : public mlir::OpRewritePattern<CustomOp> {
using mlir::OpRewritePattern<CustomOp>::OpRewritePattern;

mlir::LogicalResult matchAndRewrite(CustomOp op, mlir::PatternRewriter &rewriter) const override
{
LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n");
auto loc = op.getLoc();
StringRef opGateName = op.getGateName();

if (!rotationsSet.contains(opGateName))
return failure();
ValueRange inQubits = op.getInQubits();
auto parentOp = dyn_cast_or_null<CustomOp>(inQubits[0].getDefiningOp());

VerifyParentGateAndNameAnalysis<CustomOp> vpga(op);
rmoyard marked this conversation as resolved.
Show resolved Hide resolved
if (!vpga.getVerifierResult()) {
return failure();
}

TypeRange outQubitsTypes = op.getOutQubits().getTypes();
TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits().getTypes();
ValueRange parentInQubits = parentOp.getInQubits();
ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits();
ValueRange parentInCtrlValues = parentOp.getInCtrlValues();

auto parentParams = parentOp.getParams();
auto params = op.getParams();
SmallVector<mlir::Value> sumParams;
for (auto [param, parentParam] : llvm::zip(params, parentParams)) {
mlir::Value sumParam =
rewriter.create<arith::AddFOp>(loc, parentParam, param).getResult();
sumParams.push_back(sumParam);
};
auto mergeOp = rewriter.create<CustomOp>(loc, outQubitsTypes, outQubitsCtrlTypes, sumParams,
parentInQubits, opGateName, nullptr,
parentInCtrlQubits, parentInCtrlValues);

op.replaceAllUsesWith(mergeOp);
op.erase();
rmoyard marked this conversation as resolved.
Show resolved Hide resolved
parentOp.erase();

return success();
}
};

struct MergeMultiRZRewritePattern : public mlir::OpRewritePattern<MultiRZOp> {
using mlir::OpRewritePattern<MultiRZOp>::OpRewritePattern;

mlir::LogicalResult matchAndRewrite(MultiRZOp op,
mlir::PatternRewriter &rewriter) const override
{
LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n");
auto loc = op.getLoc();

VerifyParentGateAnalysis<MultiRZOp> vpga(op);
if (!vpga.getVerifierResult()) {
return failure();
}

ValueRange inQubits = op.getInQubits();
auto parentOp = dyn_cast_or_null<MultiRZOp>(inQubits[0].getDefiningOp());
if (!parentOp)
return failure();

TypeRange outQubitsTypes = op.getOutQubits().getTypes();
TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits().getTypes();
ValueRange parentInQubits = parentOp.getInQubits();
ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits();
ValueRange parentInCtrlValues = parentOp.getInCtrlValues();

auto parentTheta = parentOp.getTheta();
auto theta = op.getTheta();

mlir::Value sumParam = rewriter.create<arith::AddFOp>(loc, parentTheta, theta).getResult();

auto mergeOp = rewriter.create<MultiRZOp>(loc, outQubitsTypes, outQubitsCtrlTypes,
sumParam, parentInQubits, nullptr,
parentInCtrlQubits, parentInCtrlValues);
op.replaceAllUsesWith(mergeOp);
op.erase();
parentOp.erase();

return success();
}
};
} // namespace

namespace catalyst {
namespace quantum {

void populateMergeRotationsPatterns(RewritePatternSet &patterns)
{
patterns.add<MergeRotationsRewritePattern>(patterns.getContext(), 1);
patterns.add<MergeMultiRZRewritePattern>(patterns.getContext(), 1);
}

} // namespace quantum
} // namespace catalyst
Loading
Loading