1- // Static cost analysis for Clarity expressions
1+ // Static cost analysis for Clarity contracts
22
3- use std:: collections:: HashMap ;
3+ use std:: collections:: { HashMap , HashSet } ;
44
55use clarity_types:: types:: { CharType , SequenceData , TraitIdentifier } ;
66use stacks_common:: types:: StacksEpochId ;
@@ -48,7 +48,7 @@ pub enum CostExprNode {
4848 FieldIdentifier ( TraitIdentifier ) ,
4949 TraitReference ( ClarityName ) ,
5050 // User function arguments
51- UserArgument ( ClarityName , ClarityName ) , // (argument_name, argument_type)
51+ UserArgument ( ClarityName , SymbolicExpressionType ) , // (argument_name, argument_type)
5252 // User-defined functions
5353 UserFunction ( ClarityName ) ,
5454}
@@ -94,7 +94,7 @@ impl StaticCost {
9494#[ derive( Debug , Clone ) ]
9595pub struct UserArgumentsContext {
9696 /// Map from argument name to argument type
97- pub arguments : HashMap < ClarityName , ClarityName > ,
97+ pub arguments : HashMap < ClarityName , SymbolicExpressionType > ,
9898}
9999
100100impl UserArgumentsContext {
@@ -104,15 +104,15 @@ impl UserArgumentsContext {
104104 }
105105 }
106106
107- pub fn add_argument ( & mut self , name : ClarityName , arg_type : ClarityName ) {
107+ pub fn add_argument ( & mut self , name : ClarityName , arg_type : SymbolicExpressionType ) {
108108 self . arguments . insert ( name, arg_type) ;
109109 }
110110
111111 pub fn is_user_argument ( & self , name : & ClarityName ) -> bool {
112112 self . arguments . contains_key ( name)
113113 }
114114
115- pub fn get_argument_type ( & self , name : & ClarityName ) -> Option < & ClarityName > {
115+ pub fn get_argument_type ( & self , name : & ClarityName ) -> Option < & SymbolicExpressionType > {
116116 self . arguments . get ( name)
117117 }
118118}
@@ -203,7 +203,8 @@ fn make_ast(
203203 Ok ( ast)
204204}
205205
206- /// somewhat of a passthrough since we don't have to build the whole context we can jsut return the cost of the single expression
206+ /// somewhat of a passthrough since we don't have to build the whole context we
207+ /// can jsut return the cost of the single expression
207208fn static_cost_native (
208209 source : & str ,
209210 cost_map : & HashMap < String , Option < StaticCost > > ,
@@ -221,18 +222,159 @@ fn static_cost_native(
221222 Ok ( summing_cost. into ( ) )
222223}
223224
225+ type MinMaxTraitCount = ( u64 , u64 ) ;
226+ type TraitCount = HashMap < String , MinMaxTraitCount > ;
227+
228+ // // "<trait-name>" -> "trait-name"
229+ // // ClarityName can't contain
230+ // fn strip_trait_surrounding_brackets(name: &ClarityName) -> ClarityName {
231+ // let stripped = name
232+ // .as_str()
233+ // .strip_prefix("<")
234+ // .and_then(|name| name.strip_suffix(">"));
235+ // if let Some(name) = stripped {
236+ // ClarityName::from(name)
237+ // } else {
238+ // name.clone()
239+ // }
240+ // }
241+ fn get_trait_count ( costs : & HashMap < String , CostAnalysisNode > ) -> Option < TraitCount > {
242+ let mut trait_counts = HashMap :: new ( ) ;
243+ let mut trait_names = HashMap :: new ( ) ;
244+ // walk tree
245+ for ( name, cost_analysis_node) in costs. iter ( ) {
246+ get_trait_count_from_node (
247+ cost_analysis_node,
248+ & mut trait_counts,
249+ & mut trait_names,
250+ name. clone ( ) ,
251+ 1 ,
252+ ) ;
253+ // trait_counts.extend(counts);
254+ }
255+ Some ( trait_counts)
256+ }
257+ fn get_trait_count_from_node (
258+ cost_analysis_node : & CostAnalysisNode ,
259+ mut trait_counts : & mut TraitCount ,
260+ mut trait_names : & mut HashMap < ClarityName , String > ,
261+ containing_fn_name : String ,
262+ multiplier : u64 ,
263+ ) -> TraitCount {
264+ match & cost_analysis_node. expr {
265+ CostExprNode :: UserArgument ( arg_name, arg_type) => match arg_type {
266+ SymbolicExpressionType :: TraitReference ( name, _) => {
267+ trait_names. insert ( arg_name. clone ( ) , name. clone ( ) . to_string ( ) ) ;
268+ trait_counts. entry ( name. to_string ( ) ) . or_insert ( ( 0 , 0 ) ) ;
269+ }
270+ _ => { }
271+ } ,
272+ CostExprNode :: NativeFunction ( native_function) => {
273+ println ! ( "native function: {:?}" , native_function) ;
274+ match native_function {
275+ // if map, filter, or fold, we need to check if traits are called
276+ NativeFunctions :: Map | NativeFunctions :: Filter | NativeFunctions :: Fold => {
277+ println ! ( "map: {:?}" , cost_analysis_node. children) ;
278+ let list_to_traverse = cost_analysis_node. children [ 1 ] . clone ( ) ;
279+ let multiplier = match list_to_traverse. expr {
280+ CostExprNode :: UserArgument ( _, arg_type) => match arg_type {
281+ SymbolicExpressionType :: List ( list) => {
282+ if list[ 0 ] . match_atom ( ) . unwrap ( ) . as_str ( ) == "list" {
283+ match list[ 1 ] . clone ( ) . expr {
284+ SymbolicExpressionType :: LiteralValue ( value) => {
285+ match value {
286+ Value :: Int ( value) => value as u64 ,
287+ _ => 1 ,
288+ }
289+ }
290+ _ => 1 ,
291+ }
292+ } else {
293+ 1
294+ }
295+ }
296+ _ => 1 ,
297+ } ,
298+ _ => 1 ,
299+ } ;
300+ println ! ( "multiplier: {:?}" , multiplier) ;
301+ cost_analysis_node. children . iter ( ) . for_each ( |child| {
302+ get_trait_count_from_node (
303+ child,
304+ & mut trait_counts,
305+ & mut trait_names,
306+ containing_fn_name. clone ( ) ,
307+ multiplier,
308+ ) ;
309+ } ) ;
310+ }
311+ _ => { }
312+ }
313+ }
314+ CostExprNode :: AtomValue ( atom_value) => {
315+ println ! ( "atom value: {:?}" , atom_value) ;
316+ // do nothing
317+ }
318+ CostExprNode :: Atom ( atom) => {
319+ println ! ( "atom: {:?}" , atom) ;
320+ if trait_names. get ( atom) . is_some ( ) {
321+ trait_counts
322+ . entry ( containing_fn_name. clone ( ) )
323+ . and_modify ( |( min, max) | {
324+ * min += 1 ;
325+ * max += multiplier;
326+ } )
327+ . or_insert ( ( 1 , multiplier) ) ;
328+ }
329+ // do nothing
330+ }
331+ CostExprNode :: FieldIdentifier ( field_identifier) => {
332+ println ! ( "field identifier: {:?}" , field_identifier) ;
333+ // do nothing
334+ }
335+ CostExprNode :: TraitReference ( trait_name) => {
336+ println ! ( "trait_name: {:?}" , trait_name) ;
337+ trait_counts
338+ . entry ( trait_name. to_string ( ) )
339+ . and_modify ( |( min, max) | {
340+ * min += 1 ;
341+ * max += multiplier;
342+ } )
343+ . or_insert ( ( 1 , multiplier) ) ;
344+ }
345+ CostExprNode :: UserFunction ( user_function) => {
346+ println ! ( "user function: {:?}" , user_function) ;
347+ cost_analysis_node. children . iter ( ) . for_each ( |child| {
348+ get_trait_count_from_node (
349+ child,
350+ & mut trait_counts,
351+ & mut trait_names,
352+ containing_fn_name. clone ( ) ,
353+ multiplier,
354+ ) ;
355+ } ) ;
356+ }
357+ }
358+ trait_counts. clone ( )
359+ }
360+
224361pub fn static_cost_from_ast (
225362 contract_ast : & crate :: vm:: ast:: ContractAST ,
226363 clarity_version : & ClarityVersion ,
227- ) -> Result < HashMap < String , StaticCost > , String > {
364+ ) -> Result < HashMap < String , ( StaticCost , Option < TraitCount > ) > , String > {
228365 let cost_trees = static_cost_tree_from_ast ( contract_ast, clarity_version) ?;
229366
230- Ok ( cost_trees
367+ let trait_count = get_trait_count ( & cost_trees) ;
368+ let costs: HashMap < String , StaticCost > = cost_trees
231369 . into_iter ( )
232370 . map ( |( name, cost_analysis_node) | {
233371 let summing_cost = calculate_total_cost_with_branching ( & cost_analysis_node) ;
234372 ( name, summing_cost. into ( ) )
235373 } )
374+ . collect ( ) ;
375+ Ok ( costs
376+ . into_iter ( )
377+ . map ( |( name, cost) | ( name, ( cost, trait_count. clone ( ) ) ) )
236378 . collect ( ) )
237379}
238380
@@ -288,7 +430,11 @@ pub fn static_cost(
288430 let epoch = env. global_context . epoch_id ;
289431 let ast = make_ast ( & contract_source, epoch, clarity_version) ?;
290432
291- static_cost_from_ast ( & ast, clarity_version)
433+ let costs = static_cost_from_ast ( & ast, clarity_version) ?;
434+ Ok ( costs
435+ . into_iter ( )
436+ . map ( |( name, ( cost, _trait_count) ) | ( name, cost) )
437+ . collect ( ) )
292438}
293439
294440/// same idea as `static_cost` but returns the root of the cost analysis tree for each function
@@ -447,23 +593,31 @@ fn build_function_definition_cost_analysis_tree(
447593 . match_atom ( )
448594 . ok_or ( "Expected atom for argument name" ) ?;
449595
450- let arg_type = match & arg_list[ 1 ] . expr {
451- SymbolicExpressionType :: Atom ( type_name) => type_name. clone ( ) ,
452- SymbolicExpressionType :: AtomValue ( value) => {
453- ClarityName :: from ( value. to_string ( ) . as_str ( ) )
454- }
455- SymbolicExpressionType :: LiteralValue ( value) => {
456- ClarityName :: from ( value. to_string ( ) . as_str ( ) )
457- }
458- _ => return Err ( "Argument type must be an atom or atom value" . to_string ( ) ) ,
459- } ;
596+ let arg_type = arg_list[ 1 ] . clone ( ) ;
597+ // let arg_type = match &arg_list[1].expr {
598+ // SymbolicExpressionType::Atom(type_name) => type_name.clone(),
599+ // SymbolicExpressionType::AtomValue(value) => {
600+ // ClarityName::from(value.to_string().as_str())
601+ // }
602+ // SymbolicExpressionType::LiteralValue(value) => {
603+ // ClarityName::from(value.to_string().as_str())
604+ // }
605+ // SymbolicExpressionType::TraitReference(trait_name, _trait_definition) => {
606+ // trait_name.clone()
607+ // }
608+ // SymbolicExpressionType::List(_) => ClarityName::from("list"),
609+ // _ => {
610+ // println!("arg: {:?}", arg_list[1].expr);
611+ // return Err("Argument type must be an atom or atom value".to_string());
612+ // }
613+ // };
460614
461615 // Add to function's user arguments context
462- function_user_args. add_argument ( arg_name. clone ( ) , arg_type. clone ( ) ) ;
616+ function_user_args. add_argument ( arg_name. clone ( ) , arg_type. clone ( ) . expr ) ;
463617
464618 // Create UserArgument node
465619 children. push ( CostAnalysisNode :: leaf (
466- CostExprNode :: UserArgument ( arg_name. clone ( ) , arg_type) ,
620+ CostExprNode :: UserArgument ( arg_name. clone ( ) , arg_type. clone ( ) . expr ) ,
467621 StaticCost :: ZERO ,
468622 ) ) ;
469623 }
@@ -757,6 +911,7 @@ fn get_cost_function_for_native(
757911 InsertEntry => Some ( Costs3 :: cost_set_entry) ,
758912 DeleteEntry => Some ( Costs3 :: cost_set_entry) ,
759913 StxBurn => Some ( Costs3 :: cost_stx_transfer) ,
914+ Secp256r1Verify => Some ( Costs3 :: cost_secp256r1verify) ,
760915 RestrictAssets => None , // TODO: add cost function
761916 AllowanceWithStx => None , // TODO: add cost function
762917 AllowanceWithFt => None , // TODO: add cost function
@@ -871,7 +1026,11 @@ mod tests {
8711026 ) -> Result < HashMap < String , StaticCost > , String > {
8721027 let epoch = StacksEpochId :: latest ( ) ;
8731028 let ast = make_ast ( source, epoch, clarity_version) ?;
874- static_cost_from_ast ( & ast, clarity_version)
1029+ let costs = static_cost_from_ast ( & ast, clarity_version) ?;
1030+ Ok ( costs
1031+ . into_iter ( )
1032+ . map ( |( name, ( cost, _trait_count) ) | ( name, cost) )
1033+ . collect ( ) )
8751034 }
8761035
8771036 fn build_test_ast ( src : & str ) -> crate :: vm:: ast:: ContractAST {
@@ -1081,15 +1240,15 @@ mod tests {
10811240 assert ! ( matches!( user_arg_x. expr, CostExprNode :: UserArgument ( _, _) ) ) ;
10821241 if let CostExprNode :: UserArgument ( arg_name, arg_type) = & user_arg_x. expr {
10831242 assert_eq ! ( arg_name. as_str( ) , "x" ) ;
1084- assert_eq ! ( arg_type . as_str ( ) , "uint" ) ;
1243+ assert ! ( matches! ( arg_type , SymbolicExpressionType :: Atom ( _ ) ) ) ;
10851244 }
10861245
10871246 // Second child should be UserArgument for (y u64)
10881247 let user_arg_y = & cost_tree. children [ 1 ] ;
10891248 assert ! ( matches!( user_arg_y. expr, CostExprNode :: UserArgument ( _, _) ) ) ;
10901249 if let CostExprNode :: UserArgument ( arg_name, arg_type) = & user_arg_y. expr {
10911250 assert_eq ! ( arg_name. as_str( ) , "y" ) ;
1092- assert_eq ! ( arg_type . as_str ( ) , "uint" ) ;
1251+ assert ! ( matches! ( arg_type , SymbolicExpressionType :: Atom ( _ ) ) ) ;
10931252 }
10941253
10951254 // Third child should be the function body (+ x y)
@@ -1108,11 +1267,11 @@ mod tests {
11081267
11091268 if let CostExprNode :: UserArgument ( name, arg_type) = & arg_x_ref. expr {
11101269 assert_eq ! ( name. as_str( ) , "x" ) ;
1111- assert_eq ! ( arg_type . as_str ( ) , "uint" ) ;
1270+ assert ! ( matches! ( arg_type , SymbolicExpressionType :: Atom ( _ ) ) ) ;
11121271 }
11131272 if let CostExprNode :: UserArgument ( name, arg_type) = & arg_y_ref. expr {
11141273 assert_eq ! ( name. as_str( ) , "y" ) ;
1115- assert_eq ! ( arg_type . as_str ( ) , "uint" ) ;
1274+ assert ! ( matches! ( arg_type , SymbolicExpressionType :: Atom ( _ ) ) ) ;
11161275 }
11171276 }
11181277
0 commit comments