Skip to content

Commit

Permalink
Tensor from classical action (#1514)
Browse files Browse the repository at this point in the history
* WIP tensor from classical action

* fix errors

* fix syntax error?

* move test, add timeout

* rename

* notebooks

* cleanup

* `my_tensors` from classical

* rename and docs

* more tests

* imports + doc

* nits

* fix docstring

* make files private (matching others)

---------

Co-authored-by: Matthew Harrigan <mpharrigan@google.com>
  • Loading branch information
anurudhp and mpharrigan authored Jan 24, 2025
1 parent d92bb1e commit 7bbec3b
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 3 deletions.
2 changes: 1 addition & 1 deletion qualtran/bloqs/arithmetic/sorting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
},
"outputs": [],
"source": [
"comparator = Comparator(7)"
"comparator = Comparator(3)"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion qualtran/bloqs/arithmetic/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def build_composite_bloq(

@bloq_example
def _comparator() -> Comparator:
comparator = Comparator(7)
comparator = Comparator(3)
return comparator


Expand Down
2 changes: 1 addition & 1 deletion qualtran/bloqs/basic_gates/swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def _swap_small() -> Swap:
return swap_small


@bloq_example
@bloq_example(generalizer=ignore_split_join)
def _swap_large() -> Swap:
swap_large = Swap(bitsize=64)
return swap_large
Expand Down
5 changes: 5 additions & 0 deletions qualtran/bloqs/basic_gates/swap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
_cswap_large,
_cswap_small,
_swap,
_swap_large,
_swap_matrix,
_swap_small,
Swap,
Expand Down Expand Up @@ -224,6 +225,10 @@ def test_swap_small(bloq_autotester):
bloq_autotester(_swap_small)


def test_swap_large(bloq_autotester):
bloq_autotester(_swap_large)


def test_swap_symb(bloq_autotester):
if bloq_autotester.check_name == 'serialize':
pytest.skip("Sympy equality with assumptions.")
Expand Down
4 changes: 4 additions & 0 deletions qualtran/simulation/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@
tensor_out_inp_shape_from_signature,
tensor_shape_from_signature,
)
from ._tensor_from_classical import (
bloq_to_dense_via_classical_action,
my_tensors_from_classical_action,
)
139 changes: 139 additions & 0 deletions qualtran/simulation/tensor/_tensor_from_classical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Iterable, TYPE_CHECKING

import numpy as np
from numpy.typing import NDArray

if TYPE_CHECKING:
import quimb.tensor as qtn

from qualtran import Bloq, ConnectionT, Register
from qualtran.simulation.classical_sim import ClassicalValT


def _bits_to_classical_reg_data(reg: 'Register', bits: NDArray[np.uint8]) -> 'ClassicalValT':
if reg.shape == ():
return reg.dtype.from_bits([*bits.flat])
return reg.dtype.from_bits_array(np.reshape(bits, reg.shape + (reg.dtype.num_qubits,)))


def _bloq_to_dense_via_classical_action(bloq: 'Bloq') -> NDArray:
"""Internal method to compute the tensor of a bloq using its classical action.
Args:
bloq: the Bloq
Returns:
an NDArray of shape (2, 2, ...) indexed by the output bits followed by input bits.
"""
left_qubit_counts = tuple(reg.total_bits() for reg in bloq.signature.lefts())
left_qubit_splits = np.cumsum(left_qubit_counts)

n_qubits_left = sum(left_qubit_counts)
n_qubits_right = sum(reg.total_bits() for reg in bloq.signature.rights())

if n_qubits_left + n_qubits_right > 40:
raise ValueError(f"tensor is too large: {n_qubits_left + n_qubits_right} total qubits")

matrix = np.zeros((2,) * (n_qubits_right + n_qubits_left))

for input_t in itertools.product((0, 1), repeat=n_qubits_left):
*inputs_t, last = np.split(input_t, left_qubit_splits)
assert np.size(last) == 0

input_kwargs = {
reg.name: _bits_to_classical_reg_data(reg, bits)
for reg, bits in zip(bloq.signature.lefts(), inputs_t)
}
output_args = bloq.call_classically(**input_kwargs)

if output_args:
output_t = np.concatenate(
[
reg.dtype.to_bits_array(np.asarray(vals)).flat
for reg, vals in zip(bloq.signature.rights(), output_args)
]
)
else:
output_t = np.array([])

matrix[tuple([*np.atleast_1d(output_t), *np.atleast_1d(input_t)])] = 1

return matrix


def bloq_to_dense_via_classical_action(bloq: 'Bloq') -> NDArray:
"""Return a contracted, dense ndarray representing the bloq, using its classical action.
Args:
bloq: The bloq
Raises:
ValueError: if the bloq does not have a classical action.
"""
try:
matrix = _bloq_to_dense_via_classical_action(bloq)
except ValueError as e:
raise ValueError(f"cannot compute tensor for {bloq}: {str(e)}") from e

n_qubits_left = sum(reg.total_bits() for reg in bloq.signature.lefts())
n_qubits_right = sum(reg.total_bits() for reg in bloq.signature.rights())

shape: tuple[int, ...]
if n_qubits_left == 0 and n_qubits_right == 0:
shape = ()
elif n_qubits_left == 0 or n_qubits_right == 0:
shape = (2 ** max(n_qubits_left, n_qubits_right),)
else:
shape = (2**n_qubits_right, 2**n_qubits_left)

return matrix.reshape(shape)


def my_tensors_from_classical_action(
bloq: 'Bloq', incoming: dict[str, 'ConnectionT'], outgoing: dict[str, 'ConnectionT']
) -> list['qtn.Tensor']:
"""Returns the quimb tensors for the bloq derived from its `on_classical_vals` method.
This function has the same signature as `bloq.my_tensors`, and can be used as a
replacement for it when the bloq has a known classical action.
For example:
```py
class ClassicalBloq(Bloq):
...
def on_classical_vals(...):
...
def my_tensors(self, incoming, outgoing):
return my_tensors_from_classical_action(self, incoming, outgoing)
```
"""
import quimb.tensor as qtn

def _signature_to_inds(registers: Iterable['Register'], cxns: dict[str, 'ConnectionT']):
for reg in registers:
for cxn in np.asarray(cxns[reg.name]).flat:
for j in range(reg.dtype.num_qubits):
yield cxn, j

data = _bloq_to_dense_via_classical_action(bloq)
incoming_inds = _signature_to_inds(bloq.signature.lefts(), incoming)
outgoing_inds = _signature_to_inds(bloq.signature.rights(), outgoing)
inds = [*outgoing_inds, *incoming_inds]

return [qtn.Tensor(data=data, inds=inds)]
61 changes: 61 additions & 0 deletions qualtran/simulation/tensor/_tensor_from_classical_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import quimb.tensor as qtn

from qualtran import Bloq, ConnectionT, QAny, QUInt, Signature
from qualtran.bloqs.arithmetic import Add, Xor
from qualtran.bloqs.basic_gates import Toffoli, TwoBitCSwap, XGate
from qualtran.simulation.classical_sim import ClassicalValT
from qualtran.simulation.tensor._tensor_from_classical import (
bloq_to_dense_via_classical_action,
my_tensors_from_classical_action,
)


@pytest.mark.parametrize(
"bloq", [XGate(), TwoBitCSwap(), Toffoli(), Add(QUInt(3)), Xor(QAny(3))], ids=str
)
def test_tensor_consistent_with_classical(bloq: Bloq):
from_classical = bloq_to_dense_via_classical_action(bloq)
from_tensor = bloq.tensor_contract()

np.testing.assert_allclose(from_classical, from_tensor)


class TestClassicalBloq(Bloq):
@property
def signature(self) -> 'Signature':
return Signature.build(a=1, b=1, c=1)

def on_classical_vals(
self, a: 'ClassicalValT', b: 'ClassicalValT', c: 'ClassicalValT'
) -> dict[str, 'ClassicalValT']:
if a == 1 and b == 1:
c = c ^ 1
return {'a': a, 'b': b, 'c': c}

def my_tensors(
self, incoming: dict[str, 'ConnectionT'], outgoing: dict[str, 'ConnectionT']
) -> list['qtn.Tensor']:
return my_tensors_from_classical_action(self, incoming, outgoing)


def test_my_tensors_from_classical_action():
bloq = TestClassicalBloq()

expected_tensor = Toffoli().tensor_contract()
actual_tensor = bloq.tensor_contract()
np.testing.assert_allclose(actual_tensor, expected_tensor)

0 comments on commit 7bbec3b

Please sign in to comment.