Skip to content

Commit

Permalink
Merge branch 'main' into raultorres/measures_and_dshapearrays
Browse files Browse the repository at this point in the history
  • Loading branch information
rauletorresc authored Oct 21, 2024
2 parents 0a4c25b + 05da8ac commit 21f5cc3
Show file tree
Hide file tree
Showing 12 changed files with 148 additions and 42 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,6 @@ doc/code/api

# Development
venv

# Cache files
.cache
18 changes: 15 additions & 3 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
Available MLIR passes are now documented and available within the
[catalyst.passes module documentation](https://docs.pennylane.ai/projects/catalyst/en/stable/code/__init__.html#module-catalyst.passes).

* A peephole merge rotations pass is now available in MLIR. It can be added to `catalyst.passes.pipeline`, or the
* A peephole merge rotations pass is now available in MLIR. It can be added to `catalyst.passes.pipeline`, or the
Python function `catalyst.passes.merge_rotations` can be directly called on a `QNode`.
[(#1162)](https://github.com/PennyLaneAI/catalyst/pull/1162)
[(#1206)](https://github.com/PennyLaneAI/catalyst/pull/1206)
Expand All @@ -144,7 +144,7 @@

```python
from catalys.passes import merge_rotations

@qjit
@merge_rotations
@qml.qnode(qml.device("lightning.qubit", wires=1))
Expand Down Expand Up @@ -187,6 +187,9 @@

<h3>Improvements</h3>

* Implement a Catalyst runtime plugin that mocks out all functions in the QuantumDevice interface.
[(#1179)](https://github.com/PennyLaneAI/catalyst/pull/1179)

* Scalar tensors are eliminated from control flow operations in the program, and are replaced with
bare scalars instead. This improves compilation time and memory usage at runtime by avoiding heap
allocations and reducing the amount of instructions.
Expand Down Expand Up @@ -249,6 +252,7 @@

* The compiler pass `-remove-chained-self-inverse` can now also cancel adjoints of arbitrary unitary operations (in addition to the named Hermitian gates).
[(#1186)](https://github.com/PennyLaneAI/catalyst/pull/1186)
[(#1211)](https://github.com/PennyLaneAI/catalyst/pull/1211)

<h3>Breaking changes</h3>

Expand All @@ -270,7 +274,11 @@

<h3>Bug fixes</h3>

* Resolve a bug where `mitigate_with_zne` does not work properly with shots and devices
* Fix a bug preventing the target of `qml.adjoint` and `qml.ctrl` calls from being transformed by
AutoGraph.
[(#1212)](https://github.com/PennyLaneAI/catalyst/pull/1212)

* Resolve a bug where `mitigate_with_zne` does not work properly with shots and devices
supporting only Counts and Samples (e.g. Qrack). (transform: `measurements_from_sample`).
[(#1165)](https://github.com/PennyLaneAI/catalyst/pull/1165)

Expand All @@ -283,6 +291,9 @@
* Fixes taking gradient of nested accelerate callbacks.
[(#1156)](https://github.com/PennyLaneAI/catalyst/pull/1156)

* Registers the func dialect as a requirement for running the scatter lowering pass.
[(#1216)](https://github.com/PennyLaneAI/catalyst/pull/1216)

<h3>Internal changes</h3>

* Remove deprecated pennylane code across the frontend.
Expand Down Expand Up @@ -341,6 +352,7 @@

This release contains contributions from (in alphabetical order):

Amintor Dusko,
Joey Carter,
Spencer Comin,
Lillian M.A. Frederiksen,
Expand Down
2 changes: 1 addition & 1 deletion frontend/catalyst/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.9.0-dev36"
__version__ = "0.9.0-dev40"
2 changes: 2 additions & 0 deletions frontend/catalyst/autograph/ag_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,9 @@ def converted_call(fn, args, kwargs, caller_fn_scope=None, options=None):
# HOTFIX: pass through calls of known Catalyst wrapper functions
if fn in (
catalyst.adjoint,
qml.adjoint,
catalyst.ctrl,
qml.ctrl,
catalyst.grad,
catalyst.value_and_grad,
catalyst.jacobian,
Expand Down
30 changes: 7 additions & 23 deletions frontend/test/pytest/test_autograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,7 @@
from jax.errors import TracerBoolConversionError
from numpy.testing import assert_allclose

from catalyst import (
AutoGraphError,
adjoint,
autograph_source,
cond,
ctrl,
debug,
disable_autograph,
for_loop,
grad,
jacobian,
jvp,
measure,
qjit,
run_autograph,
vjp,
vmap,
while_loop,
)
from catalyst import *
from catalyst.autograph.transformer import TRANSFORMER
from catalyst.utils.dummy import dummy_func
from catalyst.utils.exceptions import CompileError
Expand Down Expand Up @@ -295,7 +277,8 @@ def fn(x: float):
assert check_cache(inner.user_function.func)
assert fn(np.pi) == -1

def test_adjoint_wrapper(self):
@pytest.mark.parametrize("adjoint_fn", [adjoint, qml.adjoint])
def test_adjoint_wrapper(self, adjoint_fn):
"""Test conversion is happening succesfully on functions wrapped with 'adjoint'."""

def inner(x):
Expand All @@ -304,14 +287,15 @@ def inner(x):
@qjit(autograph=True)
@qml.qnode(qml.device("lightning.qubit", wires=1))
def fn(x: float):
adjoint(inner)(x)
adjoint_fn(inner)(x)
return qml.probs()

assert hasattr(fn.user_function, "ag_unconverted")
assert check_cache(inner)
assert np.allclose(fn(np.pi), [0.0, 1.0])

def test_ctrl_wrapper(self):
@pytest.mark.parametrize("ctrl_fn", [ctrl, qml.ctrl])
def test_ctrl_wrapper(self, ctrl_fn):
"""Test conversion is happening succesfully on functions wrapped with 'ctrl'."""

def inner(x):
Expand All @@ -320,7 +304,7 @@ def inner(x):
@qjit(autograph=True)
@qml.qnode(qml.device("lightning.qubit", wires=2))
def fn(x: float):
ctrl(inner, control=1)(x)
ctrl_fn(inner, control=1)(x)
return qml.probs()

assert hasattr(fn.user_function, "ag_unconverted")
Expand Down
1 change: 1 addition & 0 deletions mlir/include/Catalyst/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def ScatterLoweringPass : Pass<"scatter-lowering"> {
let summary = "Lower scatter op from Stable HLO to loops.";

let dependentDialects = [
"mlir::func::FuncDialect",
"index::IndexDialect",
"mhlo::MhloDialect",
"scf::SCFDialect"
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Catalyst/Transforms/scatter_lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mhlo/IR/hlo_ops.h"
#include "mhlo/transforms/passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Pass/Pass.h"
Expand Down Expand Up @@ -57,4 +58,4 @@ std::unique_ptr<Pass> createScatterLoweringPass()
return std::make_unique<ScatterLoweringPass>();
}

} // namespace catalyst
} // namespace catalyst
10 changes: 7 additions & 3 deletions mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,13 @@ struct ChainedNamedHermitianOpRewritePattern : public mlir::OpRewritePattern<Cus

// Replace uses
ValueRange InQubits = op.getInQubits();
auto ParentOp = dyn_cast_or_null<CustomOp>(InQubits[0].getDefiningOp());
ValueRange simplifiedVal = ParentOp.getInQubits();
rewriter.replaceOp(op, simplifiedVal);
auto parentOp = cast<CustomOp>(InQubits[0].getDefiningOp());

// TODO: it would make more sense for getQubitOperands()
// to return ValueRange, like the other getters
std::vector<mlir::Value> originalQubits = parentOp.getQubitOperands();

rewriter.replaceOp(op, originalQubits);
return success();
}
};
Expand Down
53 changes: 53 additions & 0 deletions mlir/test/Quantum/ChainedSelfInverseTest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -604,3 +604,56 @@ func.func @test_chained_self_inverse(%arg0: f64) -> (!quantum.bit, !quantum.bit,
// CHECK: quantum.multirz
return %mrz_out#0, %mrz_out#1, %mrz_out#2 : !quantum.bit, !quantum.bit, !quantum.bit
}

// -----


// test with matched control wires on named Hermitian gate

// CHECK-LABEL: test_chained_self_inverse
func.func @test_chained_self_inverse() -> (!quantum.bit, !quantum.bit, !quantum.bit) {
%true = llvm.mlir.constant (1 : i1) :i1
%false = llvm.mlir.constant (0 : i1) :i1

// CHECK: quantum.alloc
// CHECK: [[IN0:%.+]] = quantum.extract {{.+}}[ 0]
// CHECK: [[IN1:%.+]] = quantum.extract {{.+}}[ 1]
// CHECK: [[IN2:%.+]] = quantum.extract {{.+}}[ 2]
%reg = quantum.alloc( 3) : !quantum.reg
%0 = quantum.extract %reg[ 0] : !quantum.reg -> !quantum.bit
%1 = quantum.extract %reg[ 1] : !quantum.reg -> !quantum.bit
%2 = quantum.extract %reg[ 2] : !quantum.reg -> !quantum.bit

%out_qubits, %out_ctrl_qubits:2 = quantum.custom "Hadamard"() %0 ctrls(%1, %2) ctrlvals(%true, %false) : !quantum.bit ctrls !quantum.bit, !quantum.bit
%out_qubits_1, %out_ctrl_qubits_1:2 = quantum.custom "Hadamard"() %out_qubits ctrls(%out_ctrl_qubits#0, %out_ctrl_qubits#1) ctrlvals(%true, %false) : !quantum.bit ctrls !quantum.bit, !quantum.bit

// CHECK-NOT: quantum.custom
// CHECK: return [[IN0]], [[IN1]], [[IN2]]
return %out_qubits_1, %out_ctrl_qubits_1#0, %out_ctrl_qubits_1#1 : !quantum.bit, !quantum.bit, !quantum.bit
}

// -----


// test with unmatched control wires on named Hermitian gate

// CHECK-LABEL: test_chained_self_inverse
func.func @test_chained_self_inverse() -> (!quantum.bit, !quantum.bit, !quantum.bit) {
%true = llvm.mlir.constant (1 : i1) :i1
%false = llvm.mlir.constant (0 : i1) :i1

// CHECK: quantum.alloc
// CHECK: [[IN0:%.+]] = quantum.extract {{.+}}[ 0]
// CHECK: [[IN1:%.+]] = quantum.extract {{.+}}[ 1]
// CHECK: [[IN2:%.+]] = quantum.extract {{.+}}[ 2]
%reg = quantum.alloc( 3) : !quantum.reg
%0 = quantum.extract %reg[ 0] : !quantum.reg -> !quantum.bit
%1 = quantum.extract %reg[ 1] : !quantum.reg -> !quantum.bit
%2 = quantum.extract %reg[ 2] : !quantum.reg -> !quantum.bit

%out_qubits, %out_ctrl_qubits:2 = quantum.custom "Hadamard"() %0 ctrls(%1, %2) ctrlvals(%true, %false) : !quantum.bit ctrls !quantum.bit, !quantum.bit
%out_qubits_1, %out_ctrl_qubits_1:2 = quantum.custom "Hadamard"() %out_qubits ctrls(%out_ctrl_qubits#1, %out_ctrl_qubits#0) ctrlvals(%true, %false) : !quantum.bit ctrls !quantum.bit, !quantum.bit

// CHECK: quantum.custom "Hadamard"
return %out_qubits_1, %out_ctrl_qubits_1#0, %out_ctrl_qubits_1#1 : !quantum.bit, !quantum.bit, !quantum.bit
}
8 changes: 5 additions & 3 deletions runtime/lib/backend/null_qubit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ add_library(rtd_null_qubit SHARED NullQubit.cpp)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

target_include_directories(rtd_null_qubit PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}
${runtime_includes}
)
target_include_directories(rtd_null_qubit PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}
${runtime_includes}
${backend_includes}
)

set_property(TARGET rtd_null_qubit PROPERTY POSITION_INDEPENDENT_CODE ON)
35 changes: 30 additions & 5 deletions runtime/lib/backend/null_qubit/NullQubit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include <algorithm> // generate_n
#include <complex>
#include <memory>
#include <optional>
Expand All @@ -22,6 +23,7 @@

#include "DataView.hpp"
#include "QuantumDevice.hpp"
#include "QubitManager.hpp"
#include "Types.h"

namespace Catalyst::Runtime::Devices {
Expand Down Expand Up @@ -51,7 +53,11 @@ struct NullQubit final : public Catalyst::Runtime::QuantumDevice {
*
* @return `QubitIdType`
*/
auto AllocateQubit() -> QubitIdType { return 0; }
auto AllocateQubit() -> QubitIdType
{
num_qubits_++; // next_id
return this->qubit_manager.Allocate(num_qubits_);
}

/**
* @brief Allocate a vector of qubits.
Expand All @@ -62,25 +68,40 @@ struct NullQubit final : public Catalyst::Runtime::QuantumDevice {
*/
auto AllocateQubits(size_t num_qubits) -> std::vector<QubitIdType>
{
return std::vector<QubitIdType>(num_qubits, 0);
if (!num_qubits) {
return {};
}
std::vector<QubitIdType> result(num_qubits);
std::generate_n(result.begin(), num_qubits, [this]() { return AllocateQubit(); });
return result;
}

/**
* @brief Doesn't Release a qubit.
*/
void ReleaseQubit(QubitIdType) {}
void ReleaseQubit(QubitIdType q)
{
if (!num_qubits_) {
num_qubits_--;
this->qubit_manager.Release(q);
}
}

/**
* @brief Doesn't Release all qubits.
*/
void ReleaseAllQubits() {}
void ReleaseAllQubits()
{
num_qubits_ = 0;
this->qubit_manager.ReleaseAll();
}

/**
* @brief Doesn't Get the number of allocated qubits.
*
* @return `size_t`
*/
[[nodiscard]] auto GetNumQubits() const -> size_t { return 0; }
[[nodiscard]] auto GetNumQubits() const -> size_t { return num_qubits_; }

/**
* @brief Doesn't Set the number of device shots.
Expand Down Expand Up @@ -295,5 +316,9 @@ struct NullQubit final : public Catalyst::Runtime::QuantumDevice {
{
return {0, 0, 0, {}, {}};
}

private:
std::size_t num_qubits_{0};
Catalyst::Runtime::QubitManager<QubitIdType, size_t> qubit_manager{};
};
} // namespace Catalyst::Runtime::Devices
Loading

0 comments on commit 21f5cc3

Please sign in to comment.