1
1
use crate :: {
2
- primitives:: U256 , Error , Precompile , PrecompileAddress , PrecompileResult , StandardPrecompileFn ,
2
+ primitives:: U256 ,
3
+ utilities:: { get_right_padded, get_right_padded_vec, left_padding, left_padding_vec} ,
4
+ Error , Precompile , PrecompileAddress , PrecompileResult , StandardPrecompileFn ,
3
5
} ;
4
6
use alloc:: vec:: Vec ;
5
- use core:: {
6
- cmp:: { max, min, Ordering } ,
7
- mem:: size_of,
8
- } ;
9
- use num:: { BigUint , One , Zero } ;
7
+ use aurora_engine_modexp:: modexp;
8
+ use core:: cmp:: { max, min} ;
10
9
11
10
pub const BYZANTIUM : PrecompileAddress = PrecompileAddress (
12
11
crate :: u64_to_address ( 5 ) ,
@@ -32,121 +31,96 @@ pub fn berlin_run(input: &[u8], gas_limit: u64) -> PrecompileResult {
32
31
} )
33
32
}
34
33
35
- fn calculate_iteration_count ( exp_length : u64 , exp_highp : & BigUint ) -> u64 {
34
+ fn calculate_iteration_count ( exp_length : u64 , exp_highp : & U256 ) -> u64 {
36
35
let mut iteration_count: u64 = 0 ;
37
36
38
- if exp_length <= 32 && exp_highp. is_zero ( ) {
37
+ if exp_length <= 32 && * exp_highp == U256 :: ZERO {
39
38
iteration_count = 0 ;
40
39
} else if exp_length <= 32 {
41
- iteration_count = exp_highp. bits ( ) - 1 ;
40
+ iteration_count = exp_highp. bit_len ( ) as u64 - 1 ;
42
41
} else if exp_length > 32 {
43
- iteration_count = ( 8 * ( exp_length - 32 ) ) + max ( 1 , exp_highp. bits ( ) ) - 1 ;
42
+ iteration_count = ( 8 * ( exp_length - 32 ) ) + max ( 1 , exp_highp. bit_len ( ) as u64 ) - 1 ;
44
43
}
45
44
46
45
max ( iteration_count, 1 )
47
46
}
48
47
49
- macro_rules! read_u64_with_overflow {
50
- ( $input: expr, $from: expr, $to: expr, $overflow_limit: expr) => { {
51
- const SPLIT : usize = 32 - size_of:: <u64 >( ) ;
52
- let len = $input. len( ) ;
53
- let from_zero = min( $from, len) ;
54
- let from = min( from_zero + SPLIT , len) ;
55
- let to = min( $to, len) ;
56
- let overflow_bytes = & $input[ from_zero..from] ;
57
-
58
- let mut len_bytes = [ 0u8 ; size_of:: <u64 >( ) ] ;
59
- len_bytes[ ..to - from] . copy_from_slice( & $input[ from..to] ) ;
60
- let out = u64 :: from_be_bytes( len_bytes) as usize ;
61
- let overflow = !( out < $overflow_limit && overflow_bytes. iter( ) . all( |& x| x == 0 ) ) ;
62
- ( out, overflow)
63
- } } ;
64
- }
65
-
66
48
fn run_inner < F > ( input : & [ u8 ] , gas_limit : u64 , min_gas : u64 , calc_gas : F ) -> PrecompileResult
67
49
where
68
- F : FnOnce ( u64 , u64 , u64 , & BigUint ) -> u64 ,
50
+ F : FnOnce ( u64 , u64 , u64 , & U256 ) -> u64 ,
69
51
{
70
- let len = input. len ( ) ;
71
- let ( base_len, base_overflow) = read_u64_with_overflow ! ( input, 0 , 32 , u32 :: MAX as usize ) ;
72
- let ( exp_len, exp_overflow) = read_u64_with_overflow ! ( input, 32 , 64 , u32 :: MAX as usize ) ;
73
- let ( mod_len, mod_overflow) = read_u64_with_overflow ! ( input, 64 , 96 , u32 :: MAX as usize ) ;
74
-
75
- if base_overflow || mod_overflow {
76
- return Err ( Error :: ModexpBaseOverflow ) ;
52
+ // If there is no minimum gas, return error.
53
+ if min_gas > gas_limit {
54
+ return Err ( Error :: OutOfGas ) ;
77
55
}
78
-
79
- if mod_overflow {
56
+ // The format of input is:
57
+ // <length_of_BASE> <length_of_EXPONENT> <length_of_MODULUS> <BASE> <EXPONENT> <MODULUS>
58
+ // Where every length is a 32-byte left-padded integer representing the number of bytes
59
+ // to be taken up by the next value
60
+ const HEADER_LENGTH : usize = 96 ;
61
+
62
+ // Extract the header.
63
+ let base_len = U256 :: from_be_bytes ( get_right_padded :: < 32 > ( input, 0 ) ) ;
64
+ let exp_len = U256 :: from_be_bytes ( get_right_padded :: < 32 > ( input, 32 ) ) ;
65
+ let mod_len = U256 :: from_be_bytes ( get_right_padded :: < 32 > ( input, 64 ) ) ;
66
+
67
+ // cast base and modulus to usize, it does not make sense to handle larger values
68
+ let Ok ( base_len) = usize:: try_from ( base_len) else {
69
+ return Err ( Error :: ModexpBaseOverflow ) ;
70
+ } ;
71
+ let Ok ( mod_len) = usize:: try_from ( mod_len) else {
80
72
return Err ( Error :: ModexpModOverflow ) ;
73
+ } ;
74
+
75
+ // Handle a special case when both the base and mod length is zero
76
+ if base_len == 0 && mod_len == 0 {
77
+ return Ok ( ( min_gas, Vec :: new ( ) ) ) ;
81
78
}
82
79
83
- let ( r, gas_cost) = if base_len == 0 && mod_len == 0 {
84
- if min_gas > gas_limit {
85
- return Err ( Error :: OutOfGas ) ;
86
- }
87
- ( BigUint :: zero ( ) , min_gas)
88
- } else {
89
- // set limit for exp overflow
90
- if exp_overflow {
91
- return Err ( Error :: ModexpExpOverflow ) ;
92
- }
93
- let base_start = 96 ;
94
- let base_end = base_start + base_len;
95
- let exp_end = base_end + exp_len;
96
- let exp_highp_end = base_end + min ( 32 , exp_len) ;
97
- let mod_end = exp_end + mod_len;
98
-
99
- let exp_highp = {
100
- let mut out = [ 0 ; 32 ] ;
101
- let from = min ( base_end, len) ;
102
- let to = min ( exp_highp_end, len) ;
103
- let target_from = 32 - ( exp_highp_end - base_end) ; // 32 - exp length
104
- let target_to = target_from + ( to - from) ; // beginning + size to copy
105
- out[ target_from..target_to] . copy_from_slice ( & input[ from..to] ) ;
106
- BigUint :: from_bytes_be ( & out)
107
- } ;
108
-
109
- let gas_cost = calc_gas ( base_len as u64 , exp_len as u64 , mod_len as u64 , & exp_highp) ;
110
- if gas_cost > gas_limit {
111
- return Err ( Error :: OutOfGas ) ;
112
- }
80
+ // cast exponent length to usize, it does not make sense to handle larger values.
81
+ let Ok ( exp_len) = usize:: try_from ( exp_len) else {
82
+ return Err ( Error :: ModexpModOverflow ) ;
83
+ } ;
113
84
114
- let read_big = |from : usize , to : usize | {
115
- let mut out = vec ! [ 0 ; to - from] ;
116
- let from = min ( from, len) ;
117
- let to = min ( to, len) ;
118
- out[ ..to - from] . copy_from_slice ( & input[ from..to] ) ;
119
- BigUint :: from_bytes_be ( & out)
120
- } ;
85
+ // Used to extract ADJUSTED_EXPONENT_LENGTH.
86
+ let exp_highp_len = min ( exp_len, 32 ) ;
121
87
122
- let base = read_big ( base_start, base_end) ;
123
- let exponent = read_big ( base_end, exp_end) ;
124
- let modulus = read_big ( exp_end, mod_end) ;
88
+ // throw away the header data as we already extracted lengths.
89
+ let input = if input. len ( ) >= 96 {
90
+ & input[ HEADER_LENGTH ..]
91
+ } else {
92
+ // or set input to zero if there is no more data
93
+ & [ ]
94
+ } ;
125
95
126
- if modulus. is_zero ( ) || modulus. is_one ( ) {
127
- ( BigUint :: zero ( ) , gas_cost)
128
- } else {
129
- ( base. modpow ( & exponent, & modulus) , gas_cost)
130
- }
96
+ let exp_highp = {
97
+ // get right padded bytes so if data.len is less then exp_len we will get right padded zeroes.
98
+ let right_padded_highp = get_right_padded :: < 32 > ( input, base_len) ;
99
+ // If exp_len is less then 32 bytes get only exp_len bytes and do left padding.
100
+ let out = left_padding :: < 32 > ( & right_padded_highp[ ..exp_highp_len] ) ;
101
+ U256 :: from_be_bytes ( out)
131
102
} ;
132
103
133
- // write output to given memory, left padded and same length as the modulus.
134
- let bytes = r. to_bytes_be ( ) ;
135
- // always true except in the case of zero-length modulus, which leads to
136
- // output of length and value 1.
137
- match bytes. len ( ) . cmp ( & mod_len) {
138
- Ordering :: Equal => Ok ( ( gas_cost, bytes) ) ,
139
- Ordering :: Less => {
140
- let mut ret = Vec :: with_capacity ( mod_len) ;
141
- ret. extend ( core:: iter:: repeat ( 0 ) . take ( mod_len - bytes. len ( ) ) ) ;
142
- ret. extend_from_slice ( & bytes[ ..] ) ;
143
- Ok ( ( gas_cost, ret) )
144
- }
145
- Ordering :: Greater => Ok ( ( gas_cost, Vec :: new ( ) ) ) ,
104
+ // calculate gas spent.
105
+ let gas_cost = calc_gas ( base_len as u64 , exp_len as u64 , mod_len as u64 , & exp_highp) ;
106
+ // check if we have enough gas.
107
+ if gas_cost > gas_limit {
108
+ return Err ( Error :: OutOfGas ) ;
146
109
}
110
+
111
+ // Padding is needed if the input does not contain all 3 values.
112
+ let base = get_right_padded_vec ( input, 0 , base_len) ;
113
+ let exponent = get_right_padded_vec ( input, base_len, exp_len) ;
114
+ let modulus = get_right_padded_vec ( input, base_len. saturating_add ( exp_len) , mod_len) ;
115
+
116
+ // Call the modexp.
117
+ let output = modexp ( & base, & exponent, & modulus) ;
118
+
119
+ // left pad the result to modulus length. bytes will always by less or equal to modulus length.
120
+ Ok ( ( gas_cost, left_padding_vec ( & output, mod_len) ) )
147
121
}
148
122
149
- fn byzantium_gas_calc ( base_len : u64 , exp_len : u64 , mod_len : u64 , exp_highp : & BigUint ) -> u64 {
123
+ fn byzantium_gas_calc ( base_len : u64 , exp_len : u64 , mod_len : u64 , exp_highp : & U256 ) -> u64 {
150
124
// ouput of this function is bounded by 2^128
151
125
fn mul_complexity ( x : u64 ) -> U256 {
152
126
if x <= 64 {
@@ -175,7 +149,7 @@ fn byzantium_gas_calc(base_len: u64, exp_len: u64, mod_len: u64, exp_highp: &Big
175
149
176
150
// Calculate gas cost according to EIP 2565:
177
151
// https://eips.ethereum.org/EIPS/eip-2565
178
- fn berlin_gas_calc ( base_length : u64 , exp_length : u64 , mod_length : u64 , exp_highp : & BigUint ) -> u64 {
152
+ fn berlin_gas_calc ( base_length : u64 , exp_length : u64 , mod_length : u64 , exp_highp : & U256 ) -> u64 {
179
153
fn calculate_multiplication_complexity ( base_length : u64 , mod_length : u64 ) -> U256 {
180
154
let max_length = max ( base_length, mod_length) ;
181
155
let mut words = max_length / 8 ;
0 commit comments