Skip to content

Commit

Permalink
feat: Verify mul, div, rem constraints with z3 (DelphinusLab#137)
Browse files Browse the repository at this point in the history
  • Loading branch information
suzuhara030 authored Jul 31, 2023
1 parent cbaad78 commit 00cb4b0
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 4 deletions.
4 changes: 2 additions & 2 deletions smt/addi64.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
lhs, rhs, res = Ints('lhs rhs res')
s.add(is_i64(lhs)),
s.add(is_i64(rhs)),
s.add(is_i64(res)),

wasm_add_i64 = Function('WamsAddI64', IntSort(), IntSort(), IntSort())
wasm_add_i64 = Function('WasmAddI64', IntSort(), IntSort(), IntSort())
s.add(ForAll([lhs, rhs], wasm_add_i64(lhs, rhs) == (lhs + rhs) % I64_MODULUS))

# define var
Expand All @@ -19,6 +18,7 @@
constrain = Function('Constrain', IntSort(), BoolSort())
constraints = [
is_bit(overflow),
is_i64(res),
# c.bin.add
fr_sub(fr_sub(fr_add(fr_mul(overflow, I64_MODULUS), res), rhs), lhs) == 0
]
Expand Down
75 changes: 75 additions & 0 deletions smt/div_u_rem_u_i64.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from z3 import *
from utils import *
from functools import reduce

s = init_z3_solver()

# define spec
lhs, rhs, res_d, res_m = Ints('lhs rhs res_d res_m')
s.add(is_i64(lhs))
s.add(is_i64(rhs))
s.add(rhs != 0)

wasm_div_u_i64 = Function('WasmDivUI64', IntSort(), IntSort(), IntSort())
s.add(ForAll([lhs, rhs], wasm_div_u_i64(lhs, rhs) == lhs / rhs))
wasm_rem_u_i64 = Function('WasmRemUI64', IntSort(), IntSort(), IntSort())
s.add(ForAll([lhs, rhs], wasm_rem_u_i64(lhs, rhs) == lhs % rhs))

# define var
aux1 = Int('aux1')
aux2 = Int('aux2')
aux3 = Int('aux3')
intermediate = Int('intermediate')

constrain = Function('Constrain', IntSort(), IntSort(),
IntSort(), IntSort(), BoolSort())
constraints = [
is_i64(aux1),
is_i64(aux2),
is_i64(aux3),
is_i64(res_d),
is_i64(res_m),
# c.bin.div_u/rem_u
intermediate == fr_add(fr_mul(rhs, aux1), aux2),
fr_sub(intermediate, lhs) == 0,
fr_sub(fr_add(fr_add(aux2, aux3), 1), rhs) == 0,
fr_sub(res_d, aux1) == 0,
fr_sub(res_m, aux2) == 0,
]

s.add(ForAll([aux1, aux2, aux3, intermediate], And(constrain(
aux1, aux2, aux3, intermediate) == reduce(lambda x, y: And(x, y), constraints))))

s.push()
# Soundness
s.add(And(constrain(aux1, aux2, aux3, intermediate),
Or(wasm_div_u_i64(lhs, rhs) != res_d, wasm_rem_u_i64(lhs, rhs) != res_m)))

check_res = s.check()
print('--------------Soundness---------------')
if check_res.r == Z3_L_TRUE:
print('Verify: Fail')
print(s.model())
elif check_res.r == Z3_L_FALSE:
print('Verify: Pass')
else:
print('Verify: Fail')
s.pop()


s.push()
# Completeness
s.add(And(wasm_div_u_i64(lhs, rhs) == res_d, wasm_rem_u_i64(lhs, rhs) == res_m,
ForAll([aux1, aux2, aux3, intermediate],
Not(constrain(aux1, aux2, aux3, intermediate)))))

check_res = s.check()
print('-------------Completeness----------------')
if check_res.r == Z3_L_TRUE:
print('Verify: Fail')
print(s.model())
elif check_res.r == Z3_L_FALSE:
print('Verify: Pass')
else:
print('Verify: Fail')
s.pop()
63 changes: 63 additions & 0 deletions smt/muli64.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from z3 import *
from utils import *
from functools import reduce

s = init_z3_solver()

# define spec
lhs, rhs, res = Ints('lhs rhs res')
s.add(is_i64(lhs))
s.add(is_i64(rhs))

wasm_mul_i64 = Function('WasmMulI64', IntSort(), IntSort(), IntSort())
s.add(ForAll([lhs, rhs], wasm_mul_i64(lhs, rhs) == (lhs * rhs) % I64_MODULUS))

# define var
aux = Int('aux')
intermediate1, intermediate2 = Ints('intermediate1 intermediate2')

constrain = Function('Constrain', IntSort(), IntSort(), IntSort(), BoolSort())
constraints = [
is_i64(aux),
is_i64(res),
# c.bin.mul
intermediate1 == fr_mul(rhs, lhs),
intermediate2 == fr_mul(aux, I64_MODULUS),
fr_sub(fr_sub(intermediate1, intermediate2), res) == 0
]

s.add(ForAll([aux, intermediate1, intermediate2], And(constrain(aux, intermediate1, intermediate2) == reduce(
lambda x, y: And(x, y), constraints))))

s.push()
# Soundness
s.add(And(constrain(aux, intermediate1,
intermediate2), wasm_mul_i64(lhs, rhs) != res))

check_res = s.check()
print('--------------Soundness---------------')
if check_res.r == Z3_L_TRUE:
print('Verify: Fail')
print(s.model())
elif check_res.r == Z3_L_FALSE:
print('Verify: Pass')
else:
print('Verify: Fail')
s.pop()


s.push()
# Completeness
s.add(And(wasm_mul_i64(lhs, rhs) == res, ForAll(
[aux, intermediate1, intermediate2], Not(constrain(aux, intermediate1, intermediate2)))))

check_res = s.check()
print('-------------Completeness----------------')
if check_res.r == Z3_L_TRUE:
print('Verify: Fail')
print(s.model())
elif check_res.r == Z3_L_FALSE:
print('Verify: Pass')
else:
print('Verify: Fail')
s.pop()
4 changes: 2 additions & 2 deletions smt/subi64.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
lhs, rhs, res = Ints('lhs rhs res')
s.add(is_i64(lhs))
s.add(is_i64(rhs))
s.add(is_i64(res))

wasm_sub_i64 = Function('WamsSubI64', IntSort(), IntSort(), IntSort())
wasm_sub_i64 = Function('WasmSubI64', IntSort(), IntSort(), IntSort())
s.add(ForAll([lhs, rhs], wasm_sub_i64(lhs, rhs) == (lhs - rhs + I64_MODULUS) % I64_MODULUS))

# define var
Expand All @@ -19,6 +18,7 @@
constrain = Function('Constrain', IntSort(), BoolSort())
constraints = [
is_bit(overflow),
is_i64(res),
# c.bin.sub
fr_sub(fr_sub(fr_add(fr_mul(overflow, I64_MODULUS), lhs), res), rhs) == 0
]
Expand Down

0 comments on commit 00cb4b0

Please sign in to comment.