Skip to content
This repository was archived by the owner on Jan 9, 2025. It is now read-only.

Optimize uint256_add and uint256_sub #1070

Merged
merged 3 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion kakarot_scripts/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,11 @@
payload = json.loads(response.text)

chain_id = int(payload["result"], 16)
except (requests.exceptions.ConnectionError, requests.exceptions.MissingSchema):
except (
requests.exceptions.ConnectionError,
requests.exceptions.MissingSchema,
requests.exceptions.InvalidSchema,
):
chain_id = int.from_bytes(b"KKRT", "big")


Expand Down
3 changes: 2 additions & 1 deletion src/kakarot/accounts/library.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ from starkware.cairo.common.bool import FALSE
from starkware.cairo.common.cairo_builtins import HashBuiltin, BitwiseBuiltin
from starkware.cairo.common.math import unsigned_div_rem, split_int, split_felt
from starkware.cairo.common.memcpy import memcpy
from starkware.cairo.common.uint256 import Uint256, uint256_not, uint256_add, uint256_le
from starkware.cairo.common.uint256 import Uint256, uint256_not, uint256_le
from starkware.cairo.common.math_cmp import is_le
from starkware.cairo.common.math import assert_not_zero
from starkware.starknet.common.syscalls import (
Expand All @@ -29,6 +29,7 @@ from kakarot.interfaces.interfaces import IERC20, IKakarot
from kakarot.errors import Errors
from kakarot.constants import Constants
from utils.eth_transaction import EthTransaction
from utils.uint256 import uint256_add

// @dev: should always be zero for EOAs
@storage_var
Expand Down
4 changes: 2 additions & 2 deletions src/kakarot/instructions/environmental_information.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ from starkware.cairo.common.cairo_builtins import HashBuiltin, BitwiseBuiltin
from starkware.cairo.common.memset import memset
from starkware.cairo.common.math import unsigned_div_rem, split_felt
from starkware.cairo.common.math_cmp import is_not_zero, is_le
from starkware.cairo.common.uint256 import Uint256, uint256_le, uint256_add, uint256_eq
from starkware.cairo.common.uint256 import Uint256, uint256_le, uint256_eq

from kakarot.account import Account
from kakarot.interfaces.interfaces import ICairo1Helpers
Expand All @@ -22,7 +22,7 @@ from kakarot.state import State
from kakarot.storages import Kakarot_precompiles_class_hash
from utils.array import slice
from utils.bytes import bytes_to_bytes8_little_endian
from utils.uint256 import uint256_to_uint160
from utils.uint256 import uint256_to_uint160, uint256_add
from utils.utils import Helpers

// @title Environmental information opcodes.
Expand Down
6 changes: 3 additions & 3 deletions src/kakarot/instructions/stop_and_math_operations.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.bool import FALSE
from starkware.cairo.common.registers import get_fp_and_pc
from starkware.cairo.common.uint256 import (
uint256_add,
uint256_and,
uint256_eq,
uint256_lt,
Expand All @@ -19,10 +18,11 @@ from starkware.cairo.common.uint256 import (
uint256_shr,
uint256_signed_div_rem,
uint256_signed_lt,
uint256_sub,
uint256_unsigned_div_rem,
uint256_xor,
Uint256,
SHIFT,
ALL_ONES,
)

from kakarot.constants import Constants, opcodes_label
Expand All @@ -31,7 +31,7 @@ from kakarot.evm import EVM
from kakarot.stack import Stack
from kakarot.gas import Gas
from kakarot.state import State
from utils.uint256 import uint256_fast_exp, uint256_signextend
from utils.uint256 import uint256_fast_exp, uint256_signextend, uint256_sub, uint256_add
from utils.utils import Helpers

// @title Stop and Math operations opcodes.
Expand Down
3 changes: 2 additions & 1 deletion src/kakarot/interpreter.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ from starkware.cairo.common.cairo_builtins import HashBuiltin, BitwiseBuiltin
from starkware.cairo.common.math_cmp import is_le, is_not_zero, is_nn
from starkware.cairo.common.math import split_felt
from starkware.cairo.lang.compiler.lib.registers import get_fp_and_pc
from starkware.cairo.common.uint256 import Uint256, uint256_le, uint256_sub, uint256_add
from starkware.cairo.common.uint256 import Uint256, uint256_le
from starkware.cairo.common.math import unsigned_div_rem
from starkware.starknet.common.syscalls import get_tx_info

Expand All @@ -36,6 +36,7 @@ from kakarot.state import State
from kakarot.gas import Gas
from utils.utils import Helpers
from utils.array import count_not_zero
from utils.uint256 import uint256_sub, uint256_add

// @title EVM instructions processing.
// @notice This file contains functions related to the processing of EVM instructions.
Expand Down
3 changes: 2 additions & 1 deletion src/kakarot/state.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ from starkware.cairo.common.dict import dict_read, dict_write
from starkware.cairo.common.dict_access import DictAccess
from starkware.cairo.common.memcpy import memcpy
from starkware.cairo.common.registers import get_fp_and_pc
from starkware.cairo.common.uint256 import Uint256, uint256_add, uint256_sub, uint256_le, uint256_eq
from starkware.cairo.common.uint256 import Uint256, uint256_le, uint256_eq
from starkware.cairo.common.bool import FALSE, TRUE

from kakarot.account import Account
from kakarot.model import model
from kakarot.gas import Gas
from utils.dict import default_dict_copy
from utils.utils import Helpers
from utils.uint256 import uint256_add, uint256_sub

namespace State {
// @dev Create a new empty State
Expand Down
4 changes: 2 additions & 2 deletions src/utils/modexp/modexp_utils.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ from starkware.cairo.common.uint256 import (
uint256_mul,
uint256_eq,
uint256_lt,
uint256_sub,
uint256_unsigned_div_rem,
uint256_add,
)
from starkware.cairo.common.bitwise import bitwise_and
from starkware.cairo.common.registers import get_label_location
from starkware.cairo.common.bool import FALSE

from utils.uint256 import uint256_sub, uint256_add

// @title ModExpHelpersUint256 Functions
// @notice This file contains a selection of helper functions for modular exponentiation and gas cost calculation.
// @author @dragan2234
Expand Down
95 changes: 93 additions & 2 deletions src/utils/uint256.cairo
Original file line number Diff line number Diff line change
@@ -1,16 +1,107 @@
from starkware.cairo.common.uint256 import (
Uint256,
uint256_eq,
uint256_sub,
uint256_mul,
uint256_unsigned_div_rem,
uint256_le,
uint256_pow2,
uint256_add,
SHIFT,
ALL_ONES,
)
from starkware.cairo.common.math import unsigned_div_rem
from starkware.cairo.common.bool import FALSE

// Adds two integers. Returns the result as a 256-bit integer and the (1-bit) carry.
// Strictly equivalent and faster version of common.uint256.uint256_add using the same whitelisted hint.
func uint256_add{range_check_ptr}(a: Uint256, b: Uint256) -> (res: Uint256, carry: felt) {
alloc_locals;
local carry_low: felt;
local carry_high: felt;
%{
sum_low = ids.a.low + ids.b.low
ids.carry_low = 1 if sum_low >= ids.SHIFT else 0
sum_high = ids.a.high + ids.b.high + ids.carry_low
ids.carry_high = 1 if sum_high >= ids.SHIFT else 0
%}

if (carry_low != 0) {
if (carry_high != 0) {
tempvar range_check_ptr = range_check_ptr + 2;
tempvar res = Uint256(low=a.low + b.low - SHIFT, high=a.high + b.high + 1 - SHIFT);
assert [range_check_ptr - 2] = res.low;
assert [range_check_ptr - 1] = res.high;
return (res, 1);
} else {
tempvar range_check_ptr = range_check_ptr + 2;
tempvar res = Uint256(low=a.low + b.low - SHIFT, high=a.high + b.high + 1);
assert [range_check_ptr - 2] = res.low;
assert [range_check_ptr - 1] = res.high;
return (res, 0);
}
} else {
if (carry_high != 0) {
tempvar range_check_ptr = range_check_ptr + 2;
tempvar res = Uint256(low=a.low + b.low, high=a.high + b.high - SHIFT);
assert [range_check_ptr - 2] = res.low;
assert [range_check_ptr - 1] = res.high;
return (res, 1);
} else {
tempvar range_check_ptr = range_check_ptr + 2;
tempvar res = Uint256(low=a.low + b.low, high=a.high + b.high);
assert [range_check_ptr - 2] = res.low;
assert [range_check_ptr - 1] = res.high;
return (res, 0);
}
}
}

// Subtracts two integers. Returns the result as a 256-bit integer.
// Strictly equivalent and faster version of common.uint256.uint256_sub using uint256_add's whitelisted hint.
func uint256_sub{range_check_ptr}(a: Uint256, b: Uint256) -> (res: Uint256) {
alloc_locals;
// Reference "b" as -b.
local b: Uint256 = Uint256(ALL_ONES - b.low + 1, ALL_ONES - b.high);
// Computes a + (-b)
local carry_low: felt;
local carry_high: felt;
%{
sum_low = ids.a.low + ids.b.low
ids.carry_low = 1 if sum_low >= ids.SHIFT else 0
sum_high = ids.a.high + ids.b.high + ids.carry_low
ids.carry_high = 1 if sum_high >= ids.SHIFT else 0
%}

if (carry_low != 0) {
if (carry_high != 0) {
tempvar range_check_ptr = range_check_ptr + 2;
tempvar res = Uint256(low=a.low + b.low - SHIFT, high=a.high + b.high + 1 - SHIFT);
assert [range_check_ptr - 2] = res.low;
assert [range_check_ptr - 1] = res.high;
return (res,);
} else {
tempvar range_check_ptr = range_check_ptr + 2;
tempvar res = Uint256(low=a.low + b.low - SHIFT, high=a.high + b.high + 1);
assert [range_check_ptr - 2] = res.low;
assert [range_check_ptr - 1] = res.high;
return (res,);
}
} else {
if (carry_high != 0) {
tempvar range_check_ptr = range_check_ptr + 2;
tempvar res = Uint256(low=a.low + b.low, high=a.high + b.high - SHIFT);
assert [range_check_ptr - 2] = res.low;
assert [range_check_ptr - 1] = res.high;
return (res,);
} else {
tempvar range_check_ptr = range_check_ptr + 2;
tempvar res = Uint256(low=a.low + b.low, high=a.high + b.high);
assert [range_check_ptr - 2] = res.low;
assert [range_check_ptr - 1] = res.high;
return (res,);
}
}
}

// @notice Internal exponentiation of two 256-bit integers.
// @dev The result is modulo 2^256.
// @param value - The base.
Expand Down
28 changes: 27 additions & 1 deletion tests/src/utils/test_uint256.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ from starkware.cairo.common.cairo_builtins import HashBuiltin
from starkware.cairo.common.uint256 import Uint256
from starkware.cairo.common.alloc import alloc

from utils.uint256 import uint256_to_uint160
from utils.uint256 import uint256_to_uint160, uint256_add, uint256_sub

func test__uint256_to_uint160{range_check_ptr}() {
// Given
Expand All @@ -23,3 +23,29 @@ func test__uint256_to_uint160{range_check_ptr}() {

return ();
}

func test__uint256_add{range_check_ptr}() -> (felt, felt, felt) {
alloc_locals;
let (a_ptr) = alloc();
let (b_ptr) = alloc();
%{
segments.write_arg(ids.a_ptr, program_input["a"])
segments.write_arg(ids.b_ptr, program_input["b"])
%}
let (res, carry) = uint256_add([cast(a_ptr, Uint256*)], [cast(b_ptr, Uint256*)]);

return (res.low, res.high, carry);
}

func test__uint256_sub{range_check_ptr}() -> Uint256 {
alloc_locals;
let (a_ptr) = alloc();
let (b_ptr) = alloc();
%{
segments.write_arg(ids.a_ptr, program_input["a"])
segments.write_arg(ids.b_ptr, program_input["b"])
%}
let (res) = uint256_sub([cast(a_ptr, Uint256*)], [cast(b_ptr, Uint256*)]);

return res;
}
31 changes: 30 additions & 1 deletion tests/src/utils/test_uint256.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from datetime import timedelta

import pytest
from hypothesis import given, settings
from hypothesis.strategies import integers

from tests.utils.uint256 import int_to_uint256
from tests.utils.uint256 import int_to_uint256, uint256_to_int


class TestUint256:
Expand All @@ -10,3 +14,28 @@ def test_should_cast_value(self, cairo_run, n):
cairo_run(
"test__uint256_to_uint160", x=int_to_uint256(n), expected=n % 2**160
)

class TestUint256Add:
@given(
a=integers(min_value=0, max_value=2**256 - 1),
b=integers(min_value=0, max_value=2**256 - 1),
)
@settings(deadline=timedelta(milliseconds=30_000), max_examples=50)
def test_add(self, cairo_run, a, b):
low, high, carry = cairo_run(
"test__uint256_add", a=int_to_uint256(a), b=int_to_uint256(b)
)
assert uint256_to_int(low, high) == (a + b) % 2**256
assert carry == (a + b) // 2**256

class TestUint256Sub:
@given(
a=integers(min_value=0, max_value=2**256 - 1),
b=integers(min_value=0, max_value=2**256 - 1),
)
@settings(deadline=timedelta(milliseconds=30_000), max_examples=50)
def test_sub(self, cairo_run, a, b):
res = cairo_run(
"test__uint256_sub", a=int_to_uint256(a), b=int_to_uint256(b)
)
assert int(res, 16) == (a - b) % 2**256
8 changes: 1 addition & 7 deletions tests/utils/helpers.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,7 @@ from starkware.cairo.common.cairo_builtins import HashBuiltin, BitwiseBuiltin
from starkware.cairo.common.default_dict import default_dict_new
from starkware.cairo.common.math import split_felt
from starkware.cairo.common.memset import memset
from starkware.cairo.common.uint256 import (
Uint256,
uint256_check,
uint256_add,
uint256_eq,
assert_uint256_eq,
)
from starkware.cairo.common.uint256 import Uint256, uint256_check, uint256_eq, assert_uint256_eq
from starkware.cairo.common.dict_access import DictAccess

from backend.starknet import Starknet
Expand Down
Loading