Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in KaliskiStep3 and add tests for all steps #1496

Merged
merged 2 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions qualtran/bloqs/arithmetic/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,20 +1462,17 @@ def on_classical_vals(
c: Optional['ClassicalValT'] = None,
target: Optional['ClassicalValT'] = None,
) -> Dict[str, 'ClassicalValT']:
if self._op_symbol in ('>', '<='):
c_val = add_ints(-int(a), int(b), num_bits=self.dtype.bitsize + 1, is_signed=False)
else:
c_val = add_ints(int(a), -int(b), num_bits=self.dtype.bitsize + 1, is_signed=False)
if self.uncompute:
assert c == add_ints(
int(a),
int(b),
num_bits=int(self.dtype.bitsize),
is_signed=isinstance(self.dtype, QInt),
)
assert c == c_val
assert target == self._classical_comparison(a, b)
return {'a': a, 'b': b}
if self._op_symbol in ('>', '<='):
c = add_ints(-int(a), int(b), num_bits=self.dtype.bitsize + 1, is_signed=False)
else:
c = add_ints(int(a), -int(b), num_bits=self.dtype.bitsize + 1, is_signed=False)
return {'a': a, 'b': b, 'c': c, 'target': int(self._classical_comparison(a, b))}
assert c is None
assert target is None
return {'a': a, 'b': b, 'c': c_val, 'target': int(self._classical_comparison(a, b))}

def _compute(self, bb: 'BloqBuilder', a: 'Soquet', b: 'Soquet') -> Dict[str, 'SoquetT']:
if self._op_symbol in ('>', '<='):
Expand Down
2 changes: 1 addition & 1 deletion qualtran/bloqs/factoring/ecc/ec_add_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def test_ec_add_symbolic_cost():
# toffoli cost for Kaliski Mod Inverse, n extra toffolis in ModNeg, 2n extra toffolis to do n
# 3-controlled toffolis in step 2. The expression is written with rationals because sympy
# comparison fails with floats.
assert total_toff == sympy.Rational(253, 2) * n**2 + sympy.Rational(391, 2) * n - 31
assert total_toff == sympy.Rational(253, 2) * n**2 + sympy.Rational(407, 2) * n - 31


def test_ec_add(bloq_autotester):
Expand Down
23 changes: 11 additions & 12 deletions qualtran/bloqs/mod_arithmetic/mod_division.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ def signature(self) -> 'Signature':
def on_classical_vals(
self, v: int, m: int, f: int, is_terminal: int
) -> Dict[str, 'ClassicalValT']:
print('here')
assert False
m ^= f & (v == 0)
assert is_terminal == 0
is_terminal ^= m
Expand Down Expand Up @@ -101,10 +99,10 @@ def build_composite_bloq(

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
if is_symbolic(self.bitsize):
cvs: Union[HasLength, List[int]] = HasLength(self.bitsize)
cvs: Union[HasLength, List[int]] = HasLength(self.bitsize + 1)
else:
cvs = [0] * int(self.bitsize)
return {MultiAnd(cvs=cvs): 1, MultiAnd(cvs=cvs).adjoint(): 1, CNOT(): 2}
cvs = [0] * int(self.bitsize) + [1]
return {MultiAnd(cvs=cvs): 1, MultiAnd(cvs=cvs).adjoint(): 1, CNOT(): 3}


@frozen
Expand Down Expand Up @@ -197,25 +195,25 @@ def on_classical_vals(
def build_composite_bloq(
self, bb: 'BloqBuilder', u: Soquet, v: Soquet, b: Soquet, a: Soquet, m: Soquet, f: Soquet
) -> Dict[str, 'SoquetT']:
u, v, junk, greater_than = bb.add(
u, v, junk_c, greater_than = bb.add(
LinearDepthHalfGreaterThan(QMontgomeryUInt(self.bitsize)), a=u, b=v
)

(greater_than, f, b), junk, ctrl = bb.add(
(greater_than, f, b), junk_m, ctrl = bb.add(
MultiAnd(cvs=(1, 1, 0)), ctrl=(greater_than, f, b)
)

ctrl, a = bb.add(CNOT(), ctrl=ctrl, target=a)
ctrl, m = bb.add(CNOT(), ctrl=ctrl, target=m)

greater_than, f, b = bb.add(
MultiAnd(cvs=(1, 1, 0)).adjoint(), ctrl=(greater_than, f, b), junk=junk, target=ctrl
MultiAnd(cvs=(1, 1, 0)).adjoint(), ctrl=(greater_than, f, b), junk=junk_m, target=ctrl
)
u, v = bb.add(
LinearDepthHalfGreaterThan(QMontgomeryUInt(self.bitsize)).adjoint(),
a=u,
b=v,
c=junk,
c=junk_c,
target=greater_than,
)
return {'u': u, 'v': v, 'b': b, 'a': a, 'm': m, 'f': f}
Expand Down Expand Up @@ -391,7 +389,7 @@ def build_composite_bloq(

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {
CNOT(): 4,
CNOT(): 3,
XGate(): 2,
ModDbl(QMontgomeryUInt(self.bitsize), self.mod): 1,
CSwapApprox(self.bitsize): 2,
Expand Down Expand Up @@ -475,7 +473,7 @@ def on_classical_vals(
of `f` and `m`.
"""
assert m == 0
is_terminal = f == 1 and v == 0
is_terminal = int(f == 1 and v == 0)
if f == 0:
# When `f = 0` this means that the algorithm is nearly over and that we just need to
# double the value of `r`.
Expand All @@ -489,7 +487,8 @@ def on_classical_vals(
f = 0
r = (r << 1) % self.mod
else:
m = (u % 2 == 1) & (v % 2 == 0)
m = ((u % 2 == 1) & (v % 2 == 0)) or (u % 2 == 1 and v % 2 == 1 and u > v)
m = int(m)
# Kaliski iteration as described in Fig7 of https://arxiv.org/pdf/2001.09580.
swap = (u % 2 == 0 and v % 2 == 1) or (u % 2 == 1 and v % 2 == 1 and u > v)
if swap:
Expand Down
82 changes: 79 additions & 3 deletions qualtran/bloqs/mod_arithmetic/mod_division_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import qualtran.testing as qlt_testing
from qualtran import QMontgomeryUInt
from qualtran.bloqs.mod_arithmetic import mod_division
from qualtran.bloqs.mod_arithmetic.mod_division import _kaliskimodinverse_example, KaliskiModInverse
from qualtran.resource_counting import get_cost_value, QECGatesCost
from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join
Expand All @@ -36,7 +37,7 @@ def test_kaliski_mod_inverse_classical_action(bitsize, mod):
continue
x_montgomery = dtype.uint_to_montgomery(x, mod)
res = blq.call_classically(x=x_montgomery)
print(x, x_montgomery)

assert res == cblq.call_classically(x=x_montgomery)
assert len(res) == 2
assert res[0] == dtype.montgomery_inverse(x_montgomery, mod)
Expand Down Expand Up @@ -85,11 +86,11 @@ def test_kaliski_symbolic_cost():
# construction this is just $n-1$ (BitwiseNot -> Add(p+1)).
# - The cost of an iteration in Litinski $13n$ since they ignore constants.
# Our construction is exactly the same but we also count the constants
# which amout to $3$. for a total cost of $13n + 3$.
# which amout to $3$. for a total cost of $13n + 4$.
# For example the cost of ModDbl is 2n+1. In their figure 8, they report
# it as just $2n$. ModDbl gets executed within the 2n loop so its contribution
# to the overal cost should be 4n^2 + 2n instead of just 4n^2.
assert total_toff == 26 * n**2 + 7 * n - 1
assert total_toff == 26 * n**2 + 9 * n - 1


def test_kaliskimodinverse_example(bloq_autotester):
Expand All @@ -99,3 +100,78 @@ def test_kaliskimodinverse_example(bloq_autotester):
@pytest.mark.notebook
def test_notebook():
qlt_testing.execute_notebook('mod_division')


def test_kaliski_iteration_decomposition():
mod = 7
bitsize = 5
b = mod_division._KaliskiIteration(bitsize, mod)
cb = b.decompose_bloq()
for x in range(mod):
u = mod
v = x
r = 0
s = 1
f = 1

for _ in range(2 * bitsize):
inputs = {'u': u, 'v': v, 'r': r, 's': s, 'm': 0, 'f': f, 'is_terminal': 0}
res = b.call_classically(**inputs)
assert res == cb.call_classically(**inputs), f'{inputs=}'
u, v, r, s, _, f, _ = res # type: ignore

qlt_testing.assert_valid_bloq_decomposition(b)
qlt_testing.assert_equivalent_bloq_counts(b, generalizer=(ignore_alloc_free, ignore_split_join))


def test_kaliski_steps():
bitsize = 5
mod = 7
steps = [
mod_division._KaliskiIterationStep1(bitsize),
mod_division._KaliskiIterationStep2(bitsize),
mod_division._KaliskiIterationStep3(bitsize),
mod_division._KaliskiIterationStep4(bitsize),
mod_division._KaliskiIterationStep5(bitsize),
mod_division._KaliskiIterationStep6(bitsize, mod),
]
csteps = [b.decompose_bloq() for b in steps]

# check decomposition is valid.
for step in steps:
qlt_testing.assert_valid_bloq_decomposition(step)
qlt_testing.assert_equivalent_bloq_counts(
step, generalizer=(ignore_alloc_free, ignore_split_join)
)

# check that for all inputs all 2n iteration work when excuted directly on the 6 steps
# and their decompositions.
for x in range(mod):
u, v, r, s, f = mod, x, 0, 1, 1

for _ in range(2 * bitsize):
a = b = m = is_terminal = 0

res = steps[0].call_classically(v=v, m=m, f=f, is_terminal=is_terminal)
assert res == csteps[0].call_classically(v=v, m=m, f=f, is_terminal=is_terminal)
v, m, f, is_terminal = res # type: ignore

res = steps[1].call_classically(u=u, v=v, b=b, a=a, m=m, f=f)
assert res == csteps[1].call_classically(u=u, v=v, b=b, a=a, m=m, f=f)
u, v, b, a, m, f = res # type: ignore

res = steps[2].call_classically(u=u, v=v, b=b, a=a, m=m, f=f)
assert res == csteps[2].call_classically(u=u, v=v, b=b, a=a, m=m, f=f)
u, v, b, a, m, f = res # type: ignore

res = steps[3].call_classically(u=u, v=v, r=r, s=s, a=a)
assert res == csteps[3].call_classically(u=u, v=v, r=r, s=s, a=a)
u, v, r, s, a = res # type: ignore

res = steps[4].call_classically(u=u, v=v, r=r, s=s, b=b, f=f)
assert res == csteps[4].call_classically(u=u, v=v, r=r, s=s, b=b, f=f)
u, v, r, s, b, f = res # type: ignore

res = steps[5].call_classically(u=u, v=v, r=r, s=s, b=b, a=a, m=m, f=f)
assert res == csteps[5].call_classically(u=u, v=v, r=r, s=s, b=b, a=a, m=m, f=f)
u, v, r, s, b, a, m, f = res # type: ignore
Loading