Skip to content

Commit 6ac7b61

Browse files
committed
return the cost analysis node root in static_cost_tree
1 parent fc009fe commit 6ac7b61

File tree

2 files changed

+53
-36
lines changed

2 files changed

+53
-36
lines changed

clarity/src/vm/costs/analysis.rs

Lines changed: 52 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ use crate::vm::{ClarityVersion, Value};
2121
// type-checking
2222
// lookups
2323
// unwrap evaluates both branches (https://github.com/clarity-lang/reference/issues/59)
24-
// possibly use ContractContext? Enviornment? we need to use this somehow to
25-
// provide full view of a contract, rather than passing in source
2624

2725
const STRING_COST_BASE: u64 = 36;
2826
const STRING_COST_MULTIPLIER: u64 = 3;
@@ -227,53 +225,58 @@ pub fn static_cost_from_ast(
227225
contract_ast: &crate::vm::ast::ContractAST,
228226
clarity_version: &ClarityVersion,
229227
) -> Result<HashMap<String, StaticCost>, String> {
230-
let exprs = &contract_ast.expressions;
228+
let cost_trees = static_cost_tree_from_ast(contract_ast, clarity_version)?;
231229

232-
if exprs.is_empty() {
233-
return Err("No expressions found in contract AST".to_string());
234-
}
230+
Ok(cost_trees
231+
.into_iter()
232+
.map(|(name, cost_analysis_node)| {
233+
let summing_cost = calculate_total_cost_with_branching(&cost_analysis_node);
234+
(name, summing_cost.into())
235+
})
236+
.collect())
237+
}
235238

239+
fn static_cost_tree_from_ast(
240+
ast: &crate::vm::ast::ContractAST,
241+
clarity_version: &ClarityVersion,
242+
) -> Result<HashMap<String, CostAnalysisNode>, String> {
243+
let exprs = &ast.expressions;
236244
let user_args = UserArgumentsContext::new();
237-
let mut costs: HashMap<String, Option<StaticCost>> = HashMap::new();
238-
239-
// First pass registers all function definitions
245+
let costs_map: HashMap<String, Option<StaticCost>> = HashMap::new();
246+
let mut costs: HashMap<String, Option<CostAnalysisNode>> = HashMap::new();
240247
for expr in exprs {
241248
if let Some(function_name) = extract_function_name(expr) {
242249
costs.insert(function_name, None);
243250
}
244251
}
245-
246-
// Second pass computes costs
247252
for expr in exprs {
248253
if let Some(function_name) = extract_function_name(expr) {
249254
let (_, cost_analysis_tree) =
250-
build_cost_analysis_tree(expr, &user_args, &costs, clarity_version)?;
251-
252-
let summing_cost = calculate_total_cost_with_branching(&cost_analysis_tree);
253-
costs.insert(function_name, Some(summing_cost.into()));
255+
build_cost_analysis_tree(expr, &user_args, &costs_map, clarity_version)?;
256+
costs.insert(function_name, Some(cost_analysis_tree));
254257
}
255258
}
256-
257259
Ok(costs
258260
.into_iter()
259261
.filter_map(|(name, cost)| cost.map(|c| (name, c)))
260262
.collect())
261263
}
262264

263265
/// Calculate static execution cost for functions using Environment context
264-
/// This replaces the old source-string based approach with Environment integration
266+
/// returns the top level cost for specific functions
267+
/// function_name -> cost
265268
pub fn static_cost(
266269
env: &mut Environment,
267270
contract_identifier: &QualifiedContractIdentifier,
268271
) -> Result<HashMap<String, StaticCost>, String> {
269-
// Get the contract source from the environment's database
272+
// Get contract source from the environment's database
270273
let contract_source = env
271274
.global_context
272275
.database
273276
.get_contract_src(contract_identifier)
274277
.ok_or_else(|| "Contract source not found in database".to_string())?;
275278

276-
// Get the contract's clarity version from the environment
279+
// Get clarity version from the environment
277280
let contract = env
278281
.global_context
279282
.database
@@ -288,11 +291,32 @@ pub fn static_cost(
288291
static_cost_from_ast(&ast, clarity_version)
289292
}
290293

291-
// pub fn static_cost_tree(
292-
// source: &str,
293-
// clarity_version: &ClarityVersion,
294-
// ) -> Result<HashMap<String, CostAnalysisNode>, String> {
295-
// }
294+
/// same idea as `static_cost` but returns the root of the cost analysis tree for each function
295+
pub fn static_cost_tree(
296+
env: &mut Environment,
297+
contract_identifier: &QualifiedContractIdentifier,
298+
) -> Result<HashMap<String, CostAnalysisNode>, String> {
299+
// Get contract source from the environment's database
300+
let contract_source = env
301+
.global_context
302+
.database
303+
.get_contract_src(contract_identifier)
304+
.ok_or_else(|| "Contract source not found in database".to_string())?;
305+
306+
// Get clarity version from the environment
307+
let contract = env
308+
.global_context
309+
.database
310+
.get_contract(contract_identifier)
311+
.map_err(|e| format!("Failed to get contract: {:?}", e))?;
312+
313+
let clarity_version = contract.contract_context.get_clarity_version();
314+
315+
let epoch = env.global_context.epoch_id;
316+
let ast = make_ast(&contract_source, epoch, clarity_version)?;
317+
318+
static_cost_tree_from_ast(&ast, clarity_version)
319+
}
296320

297321
/// Extract function name from a symbolic expression
298322
fn extract_function_name(expr: &SymbolicExpression) -> Option<String> {
@@ -686,27 +710,25 @@ fn get_cost_function_for_native(
686710
ContractOf => Some(Costs3::cost_contract_of),
687711
PrincipalOf => Some(Costs3::cost_principal_of),
688712
AtBlock => Some(Costs3::cost_at_block),
689-
// CreateMap => Some(Costs3::cost_create_map),
690-
// CreateVar => Some(Costs3::cost_create_var),
691-
// CreateNonFungibleToken => Some(Costs3::cost_create_nft),
692-
// CreateFungibleToken => Some(Costs3::cost_create_ft),
713+
// => Some(Costs3::cost_create_map),
714+
// => Some(Costs3::cost_create_var),
715+
// ContractStorage => Some(Costs3::cost_contract_storage),
693716
FetchEntry => Some(Costs3::cost_fetch_entry),
694717
SetEntry => Some(Costs3::cost_set_entry),
695718
FetchVar => Some(Costs3::cost_fetch_var),
696719
SetVar => Some(Costs3::cost_set_var),
697-
// ContractStorage => Some(Costs3::cost_contract_storage),
698720
GetBlockInfo => Some(Costs3::cost_block_info),
699721
GetBurnBlockInfo => Some(Costs3::cost_burn_block_info),
700722
GetStxBalance => Some(Costs3::cost_stx_balance),
701723
StxTransfer => Some(Costs3::cost_stx_transfer),
702724
StxTransferMemo => Some(Costs3::cost_stx_transfer_memo),
703725
StxGetAccount => Some(Costs3::cost_stx_account),
704726
MintToken => Some(Costs3::cost_ft_mint),
727+
MintAsset => Some(Costs3::cost_nft_mint),
705728
TransferToken => Some(Costs3::cost_ft_transfer),
706729
GetTokenBalance => Some(Costs3::cost_ft_balance),
707730
GetTokenSupply => Some(Costs3::cost_ft_get_supply),
708731
BurnToken => Some(Costs3::cost_ft_burn),
709-
MintAsset => Some(Costs3::cost_nft_mint),
710732
TransferAsset => Some(Costs3::cost_nft_transfer),
711733
GetAssetOwner => Some(Costs3::cost_nft_owner),
712734
BurnAsset => Some(Costs3::cost_nft_burn),
@@ -1052,7 +1074,6 @@ mod tests {
10521074
build_cost_analysis_tree(expr, &user_args, &cost_map, &ClarityVersion::Clarity3)
10531075
.unwrap();
10541076

1055-
// Should have 3 children: UserArgument for (x uint), UserArgument for (y uint), and the body (+ x y)
10561077
assert_eq!(cost_tree.children.len(), 3);
10571078

10581079
// First child should be UserArgument for (x uint)
@@ -1100,11 +1121,9 @@ mod tests {
11001121
let source = "(define-public (add (a uint) (b uint)) (+ a b))";
11011122
let ast_cost = static_cost_test(source, &ClarityVersion::Clarity3).unwrap();
11021123

1103-
// Should have one function
11041124
assert_eq!(ast_cost.len(), 1);
11051125
assert!(ast_cost.contains_key("add"));
11061126

1107-
// Check that the cost is reasonable (non-zero for addition)
11081127
let add_cost = ast_cost.get("add").unwrap();
11091128
assert!(add_cost.min.runtime > 0);
11101129
assert!(add_cost.max.runtime > 0);
@@ -1118,14 +1137,11 @@ mod tests {
11181137
"#;
11191138
let ast_cost = static_cost_test(source, &ClarityVersion::Clarity3).unwrap();
11201139

1121-
// Should have 2 functions
11221140
assert_eq!(ast_cost.len(), 2);
11231141

1124-
// Check that both functions are present
11251142
assert!(ast_cost.contains_key("func1"));
11261143
assert!(ast_cost.contains_key("func2"));
11271144

1128-
// Check that costs are reasonable
11291145
let func1_cost = ast_cost.get("func1").unwrap();
11301146
let func2_cost = ast_cost.get("func2").unwrap();
11311147
assert!(func1_cost.min.runtime > 0);

clarity/src/vm/tests/analysis.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
// TODO: This needs work to get the dynamic vs static testing working
12
use std::collections::HashMap;
23

34
use rstest::rstest;

0 commit comments

Comments
 (0)