Skip to content

Commit

Permalink
restrict classical action of certain arithmetic bloqs (#1518)
Browse files Browse the repository at this point in the history
  • Loading branch information
NoureldinYosri authored Jan 21, 2025
1 parent 1bd4f70 commit b3c8707
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 7 deletions.
15 changes: 10 additions & 5 deletions qualtran/bloqs/arithmetic/subtraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Dict, Optional, Tuple, TYPE_CHECKING, Union

import numpy as np
Expand Down Expand Up @@ -40,6 +39,7 @@
from qualtran.bloqs.bookkeeping import Allocate, Cast, Free
from qualtran.bloqs.mcmt.multi_target_cnot import MultiTargetCNOT
from qualtran.drawing import Text
from qualtran.simulation.classical_sim import add_ints

if TYPE_CHECKING:
from qualtran.drawing import WireSymbol
Expand Down Expand Up @@ -270,10 +270,15 @@ def signature(self):
def on_classical_vals(
self, a: 'ClassicalValT', b: 'ClassicalValT'
) -> Dict[str, 'ClassicalValT']:
unsigned = isinstance(self.dtype, (QUInt, QMontgomeryUInt))
bitsize = self.dtype.bitsize
N = 2**bitsize if unsigned else 2 ** (bitsize - 1)
return {'a': a, 'b': int(math.fmod(b - a, N))}
return {
'a': a,
'b': add_ints(
int(b),
-int(a),
num_bits=int(self.dtype.bitsize),
is_signed=isinstance(self.dtype, QInt),
),
}

def wire_symbol(
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
Expand Down
9 changes: 9 additions & 0 deletions qualtran/bloqs/arithmetic/subtraction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,12 @@ def test_subtract_from_bloq_decomposition():
want[(a << 4) | c][a_b] = 1
got = gate.tensor_contract()
np.testing.assert_allclose(got, want)


@pytest.mark.parametrize('bitsize', range(2, 5))
def test_subtractfrom_classical_action(bitsize):
dtype = QInt(bitsize)
blq = SubtractFrom(dtype)
qlt_testing.assert_consistent_classical_action(
blq, a=tuple(dtype.get_classical_domain()), b=tuple(dtype.get_classical_domain())
)
30 changes: 28 additions & 2 deletions qualtran/bloqs/mod_arithmetic/mod_addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,17 @@ def signature(self) -> 'Signature':
def on_classical_vals(
self, x: 'ClassicalValT', y: 'ClassicalValT'
) -> Dict[str, 'ClassicalValT']:
return {'x': x, 'y': (x + y) % self.mod}
if not (0 <= x < self.mod):
raise ValueError(
f'{x=} is outside the valid interval for modular addition [0, {self.mod})'
)
if not (0 <= y < self.mod):
raise ValueError(
f'{y=} is outside the valid interval for modular addition [0, {self.mod})'
)

y = (x + y) % self.mod
return {'x': x, 'y': y}

def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet, y: Soquet) -> Dict[str, 'SoquetT']:
if is_symbolic(self.bitsize):
Expand Down Expand Up @@ -307,6 +317,12 @@ def on_classical_vals(
return {'ctrl': 0, 'x': x}

assert ctrl == 1, 'Bad ctrl value.'

if not (0 <= x < self.mod):
raise ValueError(
f'{x=} is outside the valid interval for modular addition [0, {self.mod})'
)

x = (x + self.k) % self.mod
return {'ctrl': ctrl, 'x': x}

Expand Down Expand Up @@ -492,7 +508,17 @@ def on_classical_vals(
if ctrl != self.cv:
return {'ctrl': ctrl, 'x': x, 'y': y}

return {'ctrl': ctrl, 'x': x, 'y': (x + y) % self.mod}
if not (0 <= x < self.mod):
raise ValueError(
f'{x=} is outside the valid interval for modular addition [0, {self.mod})'
)
if not (0 <= y < self.mod):
raise ValueError(
f'{y=} is outside the valid interval for modular addition [0, {self.mod})'
)

y = (x + y) % self.mod
return {'ctrl': ctrl, 'x': x, 'y': y}

def build_composite_bloq(
self, bb: 'BloqBuilder', ctrl, x: Soquet, y: Soquet
Expand Down
12 changes: 12 additions & 0 deletions qualtran/bloqs/mod_arithmetic/mod_addition_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,15 @@ def test_cmod_add_complexity_vs_ref():
def test_mod_add_classical_action(bitsize, prime):
b = ModAdd(bitsize, prime)
assert_consistent_classical_action(b, x=range(prime), y=range(prime))


def test_cmodadd_tensor():
blq = CModAddK(bitsize=4, mod=7, k=1)
want = np.zeros((7, 7))
for i in range(7):
j = (i + 1) % 7
want[j, i] = 1

tn = blq.tensor_contract()
np.testing.assert_allclose(tn[:7, :7], np.eye(7)) # ctrl = 0
np.testing.assert_allclose(tn[16 : 16 + 7, 16 : 16 + 7], want) # ctrl = 1

0 comments on commit b3c8707

Please sign in to comment.