Skip to content

Commit

Permalink
Create a generalized uniform superposition state gate (#6506)
Browse files Browse the repository at this point in the history
* Create generalized_uniform_superposition_gate.py

 Creates a generalized uniform superposition state, $\frac{1}{\sqrt{M}} \sum_{j=0}^{M-1}  \ket{j} $ (where 1< M <= 2^n), 
    using n qubits, according to the Shukla-Vedula algorithm [SV24].

    Note: The Shukla-Vedula algorithm [SV24] offers an efficient approach for creation of a generalized uniform superposition 
    state of the form, $\frac{1}{\sqrt{M}} \sum_{j=0}^{M-1}  \ket{j} $, requiring only $O(log_2 (M))$ qubits and $O(log_2 (M))$ 
    gates. This provides an exponential improvement (in the context of reduced resources and complexity) over other approaches in the literature.

Reference:
[SV24] A. Shukla and P. Vedula, “An efficient quantum algorithm for preparation of uniform quantum superposition states,” 
 Quantum Information Processing, 23(38): pp. 1-32 (2024).
  • Loading branch information
prag16 authored May 21, 2024
1 parent df07e94 commit aa04196
Show file tree
Hide file tree
Showing 7 changed files with 227 additions and 0 deletions.
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@
ZPowGate,
ZZ,
ZZPowGate,
UniformSuperpositionGate,
)

from cirq.transformers import (
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def _symmetricalqidpair(qids):
'ZipLongest': cirq.ZipLongest,
'ZPowGate': cirq.ZPowGate,
'ZZPowGate': cirq.ZZPowGate,
'UniformSuperpositionGate': cirq.UniformSuperpositionGate,
# Old types, only supported for backwards-compatibility
'BooleanHamiltonian': _boolean_hamiltonian_gate_op, # Removed in v0.15
'CrossEntropyResult': _cross_entropy_result, # Removed in v0.16
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,5 @@
from cirq.ops.state_preparation_channel import StatePreparationChannel

from cirq.ops.control_values import AbstractControlValues, ProductOfSums, SumOfProducts

from cirq.ops.uniform_superposition_gate import UniformSuperpositionGate
123 changes: 123 additions & 0 deletions cirq-core/cirq/ops/uniform_superposition_gate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright 2024 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Sequence, Any, Dict, TYPE_CHECKING

import numpy as np
from cirq.ops.common_gates import H, ry
from cirq.ops.pauli_gates import X
from cirq.ops import raw_types


if TYPE_CHECKING:
import cirq


class UniformSuperpositionGate(raw_types.Gate):
r"""Creates a uniform superposition state on the states $[0, M)$
The gate creates the state $\frac{1}{\sqrt{M}}\sum_{j=0}^{M-1}\ket{j}$
(where $1\leq M \leq 2^n$), using n qubits, according to the Shukla-Vedula algorithm [SV24].
References:
[SV24]
[An efficient quantum algorithm for preparation of uniform quantum superposition
states](https://arxiv.org/abs/2306.11747)
"""

def __init__(self, m_value: int, num_qubits: int) -> None:
"""Initializes UniformSuperpositionGate.
Args:
m_value: The number of computational basis states.
num_qubits: The number of qubits used.
Raises:
ValueError: If `m_value` is not a positive integer, or
if `num_qubits` is not an integer greater than or equal to log2(m_value).
"""
if not (isinstance(m_value, int) and (m_value > 0)):
raise ValueError("m_value must be a positive integer.")
log_two_m_value = m_value.bit_length()

if (m_value & (m_value - 1)) == 0:
log_two_m_value = log_two_m_value - 1
if not (isinstance(num_qubits, int) and (num_qubits >= log_two_m_value)):
raise ValueError(
"num_qubits must be an integer greater than or equal to log2(m_value)."
)
self._m_value = m_value
self._num_qubits = num_qubits

def _decompose_(self, qubits: Sequence["cirq.Qid"]) -> "cirq.OP_TREE":
"""Decomposes the gate into a sequence of standard gates.
Implements the construction from https://arxiv.org/pdf/2306.11747.
"""
qreg = list(qubits)
qreg.reverse()

if self._m_value == 1: # if m_value is 1, do nothing
return
if (self._m_value & (self._m_value - 1)) == 0: # if m_value is an integer power of 2
m = self._m_value.bit_length() - 1
yield H.on_each(qreg[:m])
return
k = self._m_value.bit_length()
l_value = []
for i in range(self._m_value.bit_length()):
if (self._m_value >> i) & 1:
l_value.append(i) # Locations of '1's

yield X.on_each(qreg[q_bit] for q_bit in l_value[1:k])
m_current = 2 ** (l_value[0])
theta = -2 * np.arccos(np.sqrt(m_current / self._m_value))
if l_value[0] > 0: # if m_value is even
yield H.on_each(qreg[: l_value[0]])

yield ry(theta).on(qreg[l_value[1]])

for i in range(l_value[0], l_value[1]):
yield H(qreg[i]).controlled_by(qreg[l_value[1]], control_values=[False])

for m in range(1, len(l_value) - 1):
theta = -2 * np.arccos(np.sqrt(2 ** l_value[m] / (self._m_value - m_current)))
yield ry(theta).on(qreg[l_value[m + 1]]).controlled_by(
qreg[l_value[m]], control_values=[0]
)
for i in range(l_value[m], l_value[m + 1]):
yield H.on(qreg[i]).controlled_by(qreg[l_value[m + 1]], control_values=[0])

m_current = m_current + 2 ** (l_value[m])

def num_qubits(self) -> int:
return self._num_qubits

@property
def m_value(self) -> int:
return self._m_value

def __eq__(self, other):
if isinstance(other, UniformSuperpositionGate):
return (self._m_value == other._m_value) and (self._num_qubits == other._num_qubits)
return False

def __repr__(self) -> str:
return f'UniformSuperpositionGate(m_value={self._m_value}, num_qubits={self._num_qubits})'

def _json_dict_(self) -> Dict[str, Any]:
d = {}
d['m_value'] = self._m_value
d['num_qubits'] = self._num_qubits
return d

def __str__(self) -> str:
return f'UniformSuperpositionGate(m_value={self._m_value}, num_qubits={self._num_qubits})'
94 changes: 94 additions & 0 deletions cirq-core/cirq/ops/uniform_superposition_gate_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2024 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest
import cirq


@pytest.mark.parametrize(
["m", "n"],
[[int(m), n] for n in range(3, 7) for m in np.random.randint(1, 1 << n, size=3)]
+ [(1, 2), (4, 2), (6, 3), (7, 3)],
)
def test_generated_unitary_is_uniform(m: int, n: int) -> None:
r"""The code checks that the unitary matrix corresponds to the generated uniform superposition
states (see uniform_superposition_gate.py). It is enough to check that the
first colum of the unitary matrix (which corresponds to the action of the gate on
$\ket{0}^n$ is $\frac{1}{\sqrt{M}} [1 1 \cdots 1 0 \cdots 0]^T$, where the first $M$
entries are all "1"s (excluding the normalization factor of $\frac{1}{\sqrt{M}}$ and the
remaining $2^n-M$ entries are all "0"s.
"""
gate = cirq.UniformSuperpositionGate(m, n)
matrix = np.array(cirq.unitary(gate))
np.testing.assert_allclose(
matrix[:, 0], (1 / np.sqrt(m)) * np.array([1] * m + [0] * (2**n - m)), atol=1e-8
)


@pytest.mark.parametrize(["m", "n"], [(1, 1), (-2, 1), (-3.1, 2), (6, -4), (5, 6.1)])
def test_incompatible_m_value_and_qubit_args(m: int, n: int) -> None:
r"""The code checks that test errors are raised if the arguments m (number of
superposition states and n (number of qubits) are positive integers and are compatible
(i.e., n >= log2(m)).
"""

if not (isinstance(m, int)):
with pytest.raises(ValueError, match="m_value must be a positive integer."):
cirq.UniformSuperpositionGate(m, n)
elif not (isinstance(n, int)):
with pytest.raises(
ValueError,
match="num_qubits must be an integer greater than or equal to log2\\(m_value\\).",
):
cirq.UniformSuperpositionGate(m, n)
elif m < 1:
with pytest.raises(ValueError, match="m_value must be a positive integer."):
cirq.UniformSuperpositionGate(int(m), int(n))
elif n < np.log2(m):
with pytest.raises(
ValueError,
match="num_qubits must be an integer greater than or equal to log2\\(m_value\\).",
):
cirq.UniformSuperpositionGate(m, n)


def test_repr():
assert (
repr(cirq.UniformSuperpositionGate(7, 3))
== 'UniformSuperpositionGate(m_value=7, num_qubits=3)'
)


def test_uniform_superposition_gate_json_dict():
assert cirq.UniformSuperpositionGate(7, 3)._json_dict_() == {'m_value': 7, 'num_qubits': 3}


def test_str():
assert (
str(cirq.UniformSuperpositionGate(7, 3))
== 'UniformSuperpositionGate(m_value=7, num_qubits=3)'
)


@pytest.mark.parametrize(["m", "n"], [(5, 3), (10, 4)])
def test_eq(m: int, n: int) -> None:
a = cirq.UniformSuperpositionGate(m, n)
b = cirq.UniformSuperpositionGate(m, n)
c = cirq.UniformSuperpositionGate(m + 1, n)
d = cirq.X
assert a.m_value == b.m_value
assert a.__eq__(b)
assert not (a.__eq__(c))
assert not (a.__eq__(d))
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"cirq_type": "UniformSuperpositionGate",
"m_value": 7,
"num_qubits": 3
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.UniformSuperpositionGate(m_value=7, num_qubits=3)

0 comments on commit aa04196

Please sign in to comment.