Skip to content

Commit 1e75681

Browse files
committed
2-pass for dependent function costs
1 parent b933704 commit 1e75681

File tree

2 files changed

+94
-27
lines changed

2 files changed

+94
-27
lines changed

clarity/src/vm/costs/analysis.rs

Lines changed: 69 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ fn make_ast(
207207
/// somewhat of a passthrough since we don't have to build the whole context we can jsut return the cost of the single expression
208208
fn static_cost_native(
209209
source: &str,
210-
cost_map: &HashMap<String, StaticCost>,
210+
cost_map: &HashMap<String, Option<StaticCost>>,
211211
clarity_version: &ClarityVersion,
212212
) -> Result<StaticCost, String> {
213213
let epoch = StacksEpochId::latest(); // XXX this should be matched with the clarity version
@@ -234,26 +234,53 @@ pub fn static_cost(
234234
}
235235
let exprs = &ast.expressions;
236236
let user_args = UserArgumentsContext::new();
237-
let mut costs = HashMap::new();
237+
let mut costs: HashMap<String, Option<StaticCost>> = HashMap::new();
238+
239+
// First pass registers all function definitions
240+
for expr in exprs {
241+
if let Some(function_name) = extract_function_name(expr) {
242+
costs.insert(function_name, None);
243+
}
244+
}
245+
246+
// Second pass computes costs
238247
for expr in exprs {
239-
let (_, cost_analysis_tree) =
240-
build_cost_analysis_tree(expr, &user_args, &costs, clarity_version)?;
241-
242-
let summing_cost = calculate_total_cost_with_branching(&cost_analysis_tree);
243-
costs.insert(
244-
expr.match_atom()
245-
.map(|name| name.to_string())
246-
.unwrap_or_default(),
247-
summing_cost.into(),
248-
);
249-
}
250-
Ok(costs)
248+
if let Some(function_name) = extract_function_name(expr) {
249+
let (_, cost_analysis_tree) =
250+
build_cost_analysis_tree(expr, &user_args, &costs, clarity_version)?;
251+
252+
let summing_cost = calculate_total_cost_with_branching(&cost_analysis_tree);
253+
costs.insert(function_name, Some(summing_cost.into()));
254+
}
255+
}
256+
257+
Ok(costs
258+
.into_iter()
259+
.filter_map(|(name, cost)| cost.map(|c| (name, c)))
260+
.collect())
261+
}
262+
263+
/// Extract function name from a symbolic expression
264+
fn extract_function_name(expr: &SymbolicExpression) -> Option<String> {
265+
if let Some(list) = expr.match_list() {
266+
if let Some(first_atom) = list.first().and_then(|first| first.match_atom()) {
267+
if is_function_definition(first_atom.as_str()) {
268+
if let Some(signature) = list.get(1).and_then(|sig| sig.match_list()) {
269+
return signature
270+
.first()
271+
.and_then(|name| name.match_atom())
272+
.map(|name| name.to_string());
273+
}
274+
}
275+
}
276+
}
277+
None
251278
}
252279

253280
pub fn build_cost_analysis_tree(
254281
expr: &SymbolicExpression,
255282
user_args: &UserArgumentsContext,
256-
cost_map: &HashMap<String, StaticCost>,
283+
cost_map: &HashMap<String, Option<StaticCost>>,
257284
clarity_version: &ClarityVersion,
258285
) -> Result<(Option<String>, CostAnalysisNode), String> {
259286
match &expr.expr {
@@ -339,7 +366,7 @@ fn parse_atom_expression(
339366
fn build_function_definition_cost_analysis_tree(
340367
list: &[SymbolicExpression],
341368
_user_args: &UserArgumentsContext,
342-
cost_map: &HashMap<String, StaticCost>,
369+
cost_map: &HashMap<String, Option<StaticCost>>,
343370
clarity_version: &ClarityVersion,
344371
) -> Result<(String, CostAnalysisNode), String> {
345372
let define_type = list[0]
@@ -417,7 +444,7 @@ fn get_function_name(expr: &SymbolicExpression) -> Result<ClarityName, String> {
417444
fn build_listlike_cost_analysis_tree(
418445
exprs: &[SymbolicExpression],
419446
user_args: &UserArgumentsContext,
420-
cost_map: &HashMap<String, StaticCost>,
447+
cost_map: &HashMap<String, Option<StaticCost>>,
421448
clarity_version: &ClarityVersion,
422449
) -> Result<CostAnalysisNode, String> {
423450
let mut children = Vec::new();
@@ -442,8 +469,9 @@ fn build_listlike_cost_analysis_tree(
442469
(CostExprNode::NativeFunction(native_function), cost)
443470
} else {
444471
// If not a native function, treat as user-defined function and look it up
472+
println!("in user-defined function");
445473
let expr_node = CostExprNode::UserFunction(function_name.clone());
446-
let cost = calculate_function_cost(function_name.to_string(), cost_map)?;
474+
let cost = calculate_function_cost(function_name.to_string(), cost_map, clarity_version)?;
447475
(expr_node, cost)
448476
};
449477

@@ -459,14 +487,31 @@ fn build_listlike_cost_analysis_tree(
459487
Ok(CostAnalysisNode::new(expr_node, cost, children))
460488
}
461489

462-
// this is a bit tricky, we need to ensure the previously defined function is
463-
// within the cost_map already or we need to find it and compute the cost first
490+
// Calculate function cost with lazy evaluation support
464491
fn calculate_function_cost(
465492
function_name: String,
466-
cost_map: &HashMap<String, StaticCost>,
493+
cost_map: &HashMap<String, Option<StaticCost>>,
494+
_clarity_version: &ClarityVersion,
467495
) -> Result<StaticCost, String> {
468-
let cost = cost_map.get(&function_name).unwrap_or(&StaticCost::ZERO);
469-
Ok(cost.clone())
496+
match cost_map.get(&function_name) {
497+
Some(Some(cost)) => {
498+
// Cost already computed
499+
Ok(cost.clone())
500+
}
501+
Some(None) => {
502+
// Function exists but cost not yet computed - this indicates a circular dependency
503+
// For now, return zero cost to avoid infinite recursion
504+
println!(
505+
"Circular dependency detected for function: {}",
506+
function_name
507+
);
508+
Ok(StaticCost::ZERO)
509+
}
510+
None => {
511+
// Function not found
512+
Ok(StaticCost::ZERO)
513+
}
514+
}
470515
}
471516
/// This function is no longer needed - we now use NativeFunctions::lookup_by_name_at_version
472517
/// directly in build_listlike_cost_analysis_tree
@@ -760,7 +805,7 @@ mod tests {
760805
source: &str,
761806
clarity_version: &ClarityVersion,
762807
) -> Result<StaticCost, String> {
763-
let cost_map = HashMap::new();
808+
let cost_map: HashMap<String, Option<StaticCost>> = HashMap::new();
764809
static_cost_native(source, &cost_map, clarity_version)
765810
}
766811

clarity/src/vm/tests/analysis.rs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,18 +181,18 @@ fn test_complex_trait_implementation_costs(
181181

182182
#[test]
183183
fn test_build_cost_analysis_tree_function_definition() {
184-
let source = r#"(define-public (somefunc (a uint))
184+
let src = r#"(define-public (somefunc (a uint))
185185
(ok (+ a 1))
186186
)"#;
187187

188188
let contract_id = QualifiedContractIdentifier::transient();
189189
let ast = ast::parse(
190190
&contract_id,
191-
source,
191+
src,
192192
ClarityVersion::Clarity3,
193193
StacksEpochId::Epoch32,
194194
)
195-
.expect("Failed to parse source code");
195+
.expect("Failed to parse");
196196

197197
let expr = &ast[0];
198198
let user_args = UserArgumentsContext::new();
@@ -214,3 +214,25 @@ fn test_build_cost_analysis_tree_function_definition() {
214214
}
215215
}
216216
}
217+
218+
#[test]
219+
fn test_dependent_function_calls() {
220+
let src = r#"(define-public (add-one (a uint))
221+
(begin
222+
(print "somefunc")
223+
(somefunc a)
224+
)
225+
)
226+
(define-private (somefunc (a uint))
227+
(ok (+ a 1))
228+
)"#;
229+
230+
let contract_id = QualifiedContractIdentifier::transient();
231+
let function_map = static_cost(src, &ClarityVersion::Clarity3).unwrap();
232+
233+
let add_one_cost = function_map.get("add-one").unwrap();
234+
let somefunc_cost = function_map.get("somefunc").unwrap();
235+
236+
assert!(add_one_cost.min.runtime >= somefunc_cost.min.runtime);
237+
assert!(add_one_cost.max.runtime >= somefunc_cost.max.runtime);
238+
}

0 commit comments

Comments
 (0)