Skip to content

Commit 530e666

Browse files
committed
attempt at trait counting walk
1 parent 6ac7b61 commit 530e666

File tree

2 files changed

+238
-31
lines changed

2 files changed

+238
-31
lines changed

clarity/src/vm/costs/analysis.rs

Lines changed: 186 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
// Static cost analysis for Clarity expressions
1+
// Static cost analysis for Clarity contracts
22

3-
use std::collections::HashMap;
3+
use std::collections::{HashMap, HashSet};
44

55
use clarity_types::types::{CharType, SequenceData, TraitIdentifier};
66
use stacks_common::types::StacksEpochId;
@@ -48,7 +48,7 @@ pub enum CostExprNode {
4848
FieldIdentifier(TraitIdentifier),
4949
TraitReference(ClarityName),
5050
// User function arguments
51-
UserArgument(ClarityName, ClarityName), // (argument_name, argument_type)
51+
UserArgument(ClarityName, SymbolicExpressionType), // (argument_name, argument_type)
5252
// User-defined functions
5353
UserFunction(ClarityName),
5454
}
@@ -94,7 +94,7 @@ impl StaticCost {
9494
#[derive(Debug, Clone)]
9595
pub struct UserArgumentsContext {
9696
/// Map from argument name to argument type
97-
pub arguments: HashMap<ClarityName, ClarityName>,
97+
pub arguments: HashMap<ClarityName, SymbolicExpressionType>,
9898
}
9999

100100
impl UserArgumentsContext {
@@ -104,15 +104,15 @@ impl UserArgumentsContext {
104104
}
105105
}
106106

107-
pub fn add_argument(&mut self, name: ClarityName, arg_type: ClarityName) {
107+
pub fn add_argument(&mut self, name: ClarityName, arg_type: SymbolicExpressionType) {
108108
self.arguments.insert(name, arg_type);
109109
}
110110

111111
pub fn is_user_argument(&self, name: &ClarityName) -> bool {
112112
self.arguments.contains_key(name)
113113
}
114114

115-
pub fn get_argument_type(&self, name: &ClarityName) -> Option<&ClarityName> {
115+
pub fn get_argument_type(&self, name: &ClarityName) -> Option<&SymbolicExpressionType> {
116116
self.arguments.get(name)
117117
}
118118
}
@@ -203,7 +203,8 @@ fn make_ast(
203203
Ok(ast)
204204
}
205205

206-
/// somewhat of a passthrough since we don't have to build the whole context we can jsut return the cost of the single expression
206+
/// somewhat of a passthrough since we don't have to build the whole context we
207+
/// can jsut return the cost of the single expression
207208
fn static_cost_native(
208209
source: &str,
209210
cost_map: &HashMap<String, Option<StaticCost>>,
@@ -221,18 +222,159 @@ fn static_cost_native(
221222
Ok(summing_cost.into())
222223
}
223224

225+
type MinMaxTraitCount = (u64, u64);
226+
type TraitCount = HashMap<String, MinMaxTraitCount>;
227+
228+
// // "<trait-name>" -> "trait-name"
229+
// // ClarityName can't contain
230+
// fn strip_trait_surrounding_brackets(name: &ClarityName) -> ClarityName {
231+
// let stripped = name
232+
// .as_str()
233+
// .strip_prefix("<")
234+
// .and_then(|name| name.strip_suffix(">"));
235+
// if let Some(name) = stripped {
236+
// ClarityName::from(name)
237+
// } else {
238+
// name.clone()
239+
// }
240+
// }
241+
fn get_trait_count(costs: &HashMap<String, CostAnalysisNode>) -> Option<TraitCount> {
242+
let mut trait_counts = HashMap::new();
243+
let mut trait_names = HashMap::new();
244+
// walk tree
245+
for (name, cost_analysis_node) in costs.iter() {
246+
get_trait_count_from_node(
247+
cost_analysis_node,
248+
&mut trait_counts,
249+
&mut trait_names,
250+
name.clone(),
251+
1,
252+
);
253+
// trait_counts.extend(counts);
254+
}
255+
Some(trait_counts)
256+
}
257+
fn get_trait_count_from_node(
258+
cost_analysis_node: &CostAnalysisNode,
259+
mut trait_counts: &mut TraitCount,
260+
mut trait_names: &mut HashMap<ClarityName, String>,
261+
containing_fn_name: String,
262+
multiplier: u64,
263+
) -> TraitCount {
264+
match &cost_analysis_node.expr {
265+
CostExprNode::UserArgument(arg_name, arg_type) => match arg_type {
266+
SymbolicExpressionType::TraitReference(name, _) => {
267+
trait_names.insert(arg_name.clone(), name.clone().to_string());
268+
trait_counts.entry(name.to_string()).or_insert((0, 0));
269+
}
270+
_ => {}
271+
},
272+
CostExprNode::NativeFunction(native_function) => {
273+
println!("native function: {:?}", native_function);
274+
match native_function {
275+
// if map, filter, or fold, we need to check if traits are called
276+
NativeFunctions::Map | NativeFunctions::Filter | NativeFunctions::Fold => {
277+
println!("map: {:?}", cost_analysis_node.children);
278+
let list_to_traverse = cost_analysis_node.children[1].clone();
279+
let multiplier = match list_to_traverse.expr {
280+
CostExprNode::UserArgument(_, arg_type) => match arg_type {
281+
SymbolicExpressionType::List(list) => {
282+
if list[0].match_atom().unwrap().as_str() == "list" {
283+
match list[1].clone().expr {
284+
SymbolicExpressionType::LiteralValue(value) => {
285+
match value {
286+
Value::Int(value) => value as u64,
287+
_ => 1,
288+
}
289+
}
290+
_ => 1,
291+
}
292+
} else {
293+
1
294+
}
295+
}
296+
_ => 1,
297+
},
298+
_ => 1,
299+
};
300+
println!("multiplier: {:?}", multiplier);
301+
cost_analysis_node.children.iter().for_each(|child| {
302+
get_trait_count_from_node(
303+
child,
304+
&mut trait_counts,
305+
&mut trait_names,
306+
containing_fn_name.clone(),
307+
multiplier,
308+
);
309+
});
310+
}
311+
_ => {}
312+
}
313+
}
314+
CostExprNode::AtomValue(atom_value) => {
315+
println!("atom value: {:?}", atom_value);
316+
// do nothing
317+
}
318+
CostExprNode::Atom(atom) => {
319+
println!("atom: {:?}", atom);
320+
if trait_names.get(atom).is_some() {
321+
trait_counts
322+
.entry(containing_fn_name.clone())
323+
.and_modify(|(min, max)| {
324+
*min += 1;
325+
*max += multiplier;
326+
})
327+
.or_insert((1, multiplier));
328+
}
329+
// do nothing
330+
}
331+
CostExprNode::FieldIdentifier(field_identifier) => {
332+
println!("field identifier: {:?}", field_identifier);
333+
// do nothing
334+
}
335+
CostExprNode::TraitReference(trait_name) => {
336+
println!("trait_name: {:?}", trait_name);
337+
trait_counts
338+
.entry(trait_name.to_string())
339+
.and_modify(|(min, max)| {
340+
*min += 1;
341+
*max += multiplier;
342+
})
343+
.or_insert((1, multiplier));
344+
}
345+
CostExprNode::UserFunction(user_function) => {
346+
println!("user function: {:?}", user_function);
347+
cost_analysis_node.children.iter().for_each(|child| {
348+
get_trait_count_from_node(
349+
child,
350+
&mut trait_counts,
351+
&mut trait_names,
352+
containing_fn_name.clone(),
353+
multiplier,
354+
);
355+
});
356+
}
357+
}
358+
trait_counts.clone()
359+
}
360+
224361
pub fn static_cost_from_ast(
225362
contract_ast: &crate::vm::ast::ContractAST,
226363
clarity_version: &ClarityVersion,
227-
) -> Result<HashMap<String, StaticCost>, String> {
364+
) -> Result<HashMap<String, (StaticCost, Option<TraitCount>)>, String> {
228365
let cost_trees = static_cost_tree_from_ast(contract_ast, clarity_version)?;
229366

230-
Ok(cost_trees
367+
let trait_count = get_trait_count(&cost_trees);
368+
let costs: HashMap<String, StaticCost> = cost_trees
231369
.into_iter()
232370
.map(|(name, cost_analysis_node)| {
233371
let summing_cost = calculate_total_cost_with_branching(&cost_analysis_node);
234372
(name, summing_cost.into())
235373
})
374+
.collect();
375+
Ok(costs
376+
.into_iter()
377+
.map(|(name, cost)| (name, (cost, trait_count.clone())))
236378
.collect())
237379
}
238380

@@ -288,7 +430,11 @@ pub fn static_cost(
288430
let epoch = env.global_context.epoch_id;
289431
let ast = make_ast(&contract_source, epoch, clarity_version)?;
290432

291-
static_cost_from_ast(&ast, clarity_version)
433+
let costs = static_cost_from_ast(&ast, clarity_version)?;
434+
Ok(costs
435+
.into_iter()
436+
.map(|(name, (cost, _trait_count))| (name, cost))
437+
.collect())
292438
}
293439

294440
/// same idea as `static_cost` but returns the root of the cost analysis tree for each function
@@ -447,23 +593,31 @@ fn build_function_definition_cost_analysis_tree(
447593
.match_atom()
448594
.ok_or("Expected atom for argument name")?;
449595

450-
let arg_type = match &arg_list[1].expr {
451-
SymbolicExpressionType::Atom(type_name) => type_name.clone(),
452-
SymbolicExpressionType::AtomValue(value) => {
453-
ClarityName::from(value.to_string().as_str())
454-
}
455-
SymbolicExpressionType::LiteralValue(value) => {
456-
ClarityName::from(value.to_string().as_str())
457-
}
458-
_ => return Err("Argument type must be an atom or atom value".to_string()),
459-
};
596+
let arg_type = arg_list[1].clone();
597+
// let arg_type = match &arg_list[1].expr {
598+
// SymbolicExpressionType::Atom(type_name) => type_name.clone(),
599+
// SymbolicExpressionType::AtomValue(value) => {
600+
// ClarityName::from(value.to_string().as_str())
601+
// }
602+
// SymbolicExpressionType::LiteralValue(value) => {
603+
// ClarityName::from(value.to_string().as_str())
604+
// }
605+
// SymbolicExpressionType::TraitReference(trait_name, _trait_definition) => {
606+
// trait_name.clone()
607+
// }
608+
// SymbolicExpressionType::List(_) => ClarityName::from("list"),
609+
// _ => {
610+
// println!("arg: {:?}", arg_list[1].expr);
611+
// return Err("Argument type must be an atom or atom value".to_string());
612+
// }
613+
// };
460614

461615
// Add to function's user arguments context
462-
function_user_args.add_argument(arg_name.clone(), arg_type.clone());
616+
function_user_args.add_argument(arg_name.clone(), arg_type.clone().expr);
463617

464618
// Create UserArgument node
465619
children.push(CostAnalysisNode::leaf(
466-
CostExprNode::UserArgument(arg_name.clone(), arg_type),
620+
CostExprNode::UserArgument(arg_name.clone(), arg_type.clone().expr),
467621
StaticCost::ZERO,
468622
));
469623
}
@@ -757,6 +911,7 @@ fn get_cost_function_for_native(
757911
InsertEntry => Some(Costs3::cost_set_entry),
758912
DeleteEntry => Some(Costs3::cost_set_entry),
759913
StxBurn => Some(Costs3::cost_stx_transfer),
914+
Secp256r1Verify => Some(Costs3::cost_secp256r1verify),
760915
RestrictAssets => None, // TODO: add cost function
761916
AllowanceWithStx => None, // TODO: add cost function
762917
AllowanceWithFt => None, // TODO: add cost function
@@ -871,7 +1026,11 @@ mod tests {
8711026
) -> Result<HashMap<String, StaticCost>, String> {
8721027
let epoch = StacksEpochId::latest();
8731028
let ast = make_ast(source, epoch, clarity_version)?;
874-
static_cost_from_ast(&ast, clarity_version)
1029+
let costs = static_cost_from_ast(&ast, clarity_version)?;
1030+
Ok(costs
1031+
.into_iter()
1032+
.map(|(name, (cost, _trait_count))| (name, cost))
1033+
.collect())
8751034
}
8761035

8771036
fn build_test_ast(src: &str) -> crate::vm::ast::ContractAST {
@@ -1081,15 +1240,15 @@ mod tests {
10811240
assert!(matches!(user_arg_x.expr, CostExprNode::UserArgument(_, _)));
10821241
if let CostExprNode::UserArgument(arg_name, arg_type) = &user_arg_x.expr {
10831242
assert_eq!(arg_name.as_str(), "x");
1084-
assert_eq!(arg_type.as_str(), "uint");
1243+
assert!(matches!(arg_type, SymbolicExpressionType::Atom(_)));
10851244
}
10861245

10871246
// Second child should be UserArgument for (y u64)
10881247
let user_arg_y = &cost_tree.children[1];
10891248
assert!(matches!(user_arg_y.expr, CostExprNode::UserArgument(_, _)));
10901249
if let CostExprNode::UserArgument(arg_name, arg_type) = &user_arg_y.expr {
10911250
assert_eq!(arg_name.as_str(), "y");
1092-
assert_eq!(arg_type.as_str(), "uint");
1251+
assert!(matches!(arg_type, SymbolicExpressionType::Atom(_)));
10931252
}
10941253

10951254
// Third child should be the function body (+ x y)
@@ -1108,11 +1267,11 @@ mod tests {
11081267

11091268
if let CostExprNode::UserArgument(name, arg_type) = &arg_x_ref.expr {
11101269
assert_eq!(name.as_str(), "x");
1111-
assert_eq!(arg_type.as_str(), "uint");
1270+
assert!(matches!(arg_type, SymbolicExpressionType::Atom(_)));
11121271
}
11131272
if let CostExprNode::UserArgument(name, arg_type) = &arg_y_ref.expr {
11141273
assert_eq!(name.as_str(), "y");
1115-
assert_eq!(arg_type.as_str(), "uint");
1274+
assert!(matches!(arg_type, SymbolicExpressionType::Atom(_)));
11161275
}
11171276
}
11181277

0 commit comments

Comments
 (0)