-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Tensor from classical action (#1514)
* WIP tensor from classical action * fix errors * fix syntax error? * move test, add timeout * rename * notebooks * cleanup * `my_tensors` from classical * rename and docs * more tests * imports + doc * nits * fix docstring * make files private (matching others) --------- Co-authored-by: Matthew Harrigan <mpharrigan@google.com>
- Loading branch information
1 parent
d92bb1e
commit 7bbec3b
Showing
7 changed files
with
212 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -86,7 +86,7 @@ | |
}, | ||
"outputs": [], | ||
"source": [ | ||
"comparator = Comparator(7)" | ||
"comparator = Comparator(3)" | ||
] | ||
}, | ||
{ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import itertools | ||
from typing import Iterable, TYPE_CHECKING | ||
|
||
import numpy as np | ||
from numpy.typing import NDArray | ||
|
||
if TYPE_CHECKING: | ||
import quimb.tensor as qtn | ||
|
||
from qualtran import Bloq, ConnectionT, Register | ||
from qualtran.simulation.classical_sim import ClassicalValT | ||
|
||
|
||
def _bits_to_classical_reg_data(reg: 'Register', bits: NDArray[np.uint8]) -> 'ClassicalValT': | ||
if reg.shape == (): | ||
return reg.dtype.from_bits([*bits.flat]) | ||
return reg.dtype.from_bits_array(np.reshape(bits, reg.shape + (reg.dtype.num_qubits,))) | ||
|
||
|
||
def _bloq_to_dense_via_classical_action(bloq: 'Bloq') -> NDArray: | ||
"""Internal method to compute the tensor of a bloq using its classical action. | ||
Args: | ||
bloq: the Bloq | ||
Returns: | ||
an NDArray of shape (2, 2, ...) indexed by the output bits followed by input bits. | ||
""" | ||
left_qubit_counts = tuple(reg.total_bits() for reg in bloq.signature.lefts()) | ||
left_qubit_splits = np.cumsum(left_qubit_counts) | ||
|
||
n_qubits_left = sum(left_qubit_counts) | ||
n_qubits_right = sum(reg.total_bits() for reg in bloq.signature.rights()) | ||
|
||
if n_qubits_left + n_qubits_right > 40: | ||
raise ValueError(f"tensor is too large: {n_qubits_left + n_qubits_right} total qubits") | ||
|
||
matrix = np.zeros((2,) * (n_qubits_right + n_qubits_left)) | ||
|
||
for input_t in itertools.product((0, 1), repeat=n_qubits_left): | ||
*inputs_t, last = np.split(input_t, left_qubit_splits) | ||
assert np.size(last) == 0 | ||
|
||
input_kwargs = { | ||
reg.name: _bits_to_classical_reg_data(reg, bits) | ||
for reg, bits in zip(bloq.signature.lefts(), inputs_t) | ||
} | ||
output_args = bloq.call_classically(**input_kwargs) | ||
|
||
if output_args: | ||
output_t = np.concatenate( | ||
[ | ||
reg.dtype.to_bits_array(np.asarray(vals)).flat | ||
for reg, vals in zip(bloq.signature.rights(), output_args) | ||
] | ||
) | ||
else: | ||
output_t = np.array([]) | ||
|
||
matrix[tuple([*np.atleast_1d(output_t), *np.atleast_1d(input_t)])] = 1 | ||
|
||
return matrix | ||
|
||
|
||
def bloq_to_dense_via_classical_action(bloq: 'Bloq') -> NDArray: | ||
"""Return a contracted, dense ndarray representing the bloq, using its classical action. | ||
Args: | ||
bloq: The bloq | ||
Raises: | ||
ValueError: if the bloq does not have a classical action. | ||
""" | ||
try: | ||
matrix = _bloq_to_dense_via_classical_action(bloq) | ||
except ValueError as e: | ||
raise ValueError(f"cannot compute tensor for {bloq}: {str(e)}") from e | ||
|
||
n_qubits_left = sum(reg.total_bits() for reg in bloq.signature.lefts()) | ||
n_qubits_right = sum(reg.total_bits() for reg in bloq.signature.rights()) | ||
|
||
shape: tuple[int, ...] | ||
if n_qubits_left == 0 and n_qubits_right == 0: | ||
shape = () | ||
elif n_qubits_left == 0 or n_qubits_right == 0: | ||
shape = (2 ** max(n_qubits_left, n_qubits_right),) | ||
else: | ||
shape = (2**n_qubits_right, 2**n_qubits_left) | ||
|
||
return matrix.reshape(shape) | ||
|
||
|
||
def my_tensors_from_classical_action( | ||
bloq: 'Bloq', incoming: dict[str, 'ConnectionT'], outgoing: dict[str, 'ConnectionT'] | ||
) -> list['qtn.Tensor']: | ||
"""Returns the quimb tensors for the bloq derived from its `on_classical_vals` method. | ||
This function has the same signature as `bloq.my_tensors`, and can be used as a | ||
replacement for it when the bloq has a known classical action. | ||
For example: | ||
```py | ||
class ClassicalBloq(Bloq): | ||
... | ||
def on_classical_vals(...): | ||
... | ||
def my_tensors(self, incoming, outgoing): | ||
return my_tensors_from_classical_action(self, incoming, outgoing) | ||
``` | ||
""" | ||
import quimb.tensor as qtn | ||
|
||
def _signature_to_inds(registers: Iterable['Register'], cxns: dict[str, 'ConnectionT']): | ||
for reg in registers: | ||
for cxn in np.asarray(cxns[reg.name]).flat: | ||
for j in range(reg.dtype.num_qubits): | ||
yield cxn, j | ||
|
||
data = _bloq_to_dense_via_classical_action(bloq) | ||
incoming_inds = _signature_to_inds(bloq.signature.lefts(), incoming) | ||
outgoing_inds = _signature_to_inds(bloq.signature.rights(), outgoing) | ||
inds = [*outgoing_inds, *incoming_inds] | ||
|
||
return [qtn.Tensor(data=data, inds=inds)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import numpy as np | ||
import pytest | ||
import quimb.tensor as qtn | ||
|
||
from qualtran import Bloq, ConnectionT, QAny, QUInt, Signature | ||
from qualtran.bloqs.arithmetic import Add, Xor | ||
from qualtran.bloqs.basic_gates import Toffoli, TwoBitCSwap, XGate | ||
from qualtran.simulation.classical_sim import ClassicalValT | ||
from qualtran.simulation.tensor._tensor_from_classical import ( | ||
bloq_to_dense_via_classical_action, | ||
my_tensors_from_classical_action, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"bloq", [XGate(), TwoBitCSwap(), Toffoli(), Add(QUInt(3)), Xor(QAny(3))], ids=str | ||
) | ||
def test_tensor_consistent_with_classical(bloq: Bloq): | ||
from_classical = bloq_to_dense_via_classical_action(bloq) | ||
from_tensor = bloq.tensor_contract() | ||
|
||
np.testing.assert_allclose(from_classical, from_tensor) | ||
|
||
|
||
class TestClassicalBloq(Bloq): | ||
@property | ||
def signature(self) -> 'Signature': | ||
return Signature.build(a=1, b=1, c=1) | ||
|
||
def on_classical_vals( | ||
self, a: 'ClassicalValT', b: 'ClassicalValT', c: 'ClassicalValT' | ||
) -> dict[str, 'ClassicalValT']: | ||
if a == 1 and b == 1: | ||
c = c ^ 1 | ||
return {'a': a, 'b': b, 'c': c} | ||
|
||
def my_tensors( | ||
self, incoming: dict[str, 'ConnectionT'], outgoing: dict[str, 'ConnectionT'] | ||
) -> list['qtn.Tensor']: | ||
return my_tensors_from_classical_action(self, incoming, outgoing) | ||
|
||
|
||
def test_my_tensors_from_classical_action(): | ||
bloq = TestClassicalBloq() | ||
|
||
expected_tensor = Toffoli().tensor_contract() | ||
actual_tensor = bloq.tensor_contract() | ||
np.testing.assert_allclose(actual_tensor, expected_tensor) |