Skip to content

Commit fc009fe

Browse files
committed
use Environment for static_cost
1 parent 1e75681 commit fc009fe

File tree

2 files changed

+176
-90
lines changed

2 files changed

+176
-90
lines changed

clarity/src/vm/costs/analysis.rs

Lines changed: 89 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use clarity_types::types::{CharType, SequenceData, TraitIdentifier};
66
use stacks_common::types::StacksEpochId;
77

88
use crate::vm::ast::build_ast;
9+
use crate::vm::contexts::Environment;
910
use crate::vm::costs::cost_functions::{linear, CostValues};
1011
use crate::vm::costs::costs_3::Costs3;
1112
use crate::vm::costs::ExecutionCost;
@@ -221,18 +222,17 @@ fn static_cost_native(
221222
let summing_cost = calculate_total_cost_with_branching(&cost_analysis_tree);
222223
Ok(summing_cost.into())
223224
}
224-
/// Parse Clarity source code and calculate its static execution cost for the specified function
225-
pub fn static_cost(
226-
source: &str,
225+
226+
pub fn static_cost_from_ast(
227+
contract_ast: &crate::vm::ast::ContractAST,
227228
clarity_version: &ClarityVersion,
228229
) -> Result<HashMap<String, StaticCost>, String> {
229-
let epoch = StacksEpochId::latest(); // XXX this should be matched with the clarity version
230-
let ast = make_ast(source, epoch, clarity_version)?;
230+
let exprs = &contract_ast.expressions;
231231

232-
if ast.expressions.is_empty() {
233-
return Err("No expressions found".to_string());
232+
if exprs.is_empty() {
233+
return Err("No expressions found in contract AST".to_string());
234234
}
235-
let exprs = &ast.expressions;
235+
236236
let user_args = UserArgumentsContext::new();
237237
let mut costs: HashMap<String, Option<StaticCost>> = HashMap::new();
238238

@@ -260,6 +260,40 @@ pub fn static_cost(
260260
.collect())
261261
}
262262

263+
/// Calculate static execution cost for functions using Environment context
264+
/// This replaces the old source-string based approach with Environment integration
265+
pub fn static_cost(
266+
env: &mut Environment,
267+
contract_identifier: &QualifiedContractIdentifier,
268+
) -> Result<HashMap<String, StaticCost>, String> {
269+
// Get the contract source from the environment's database
270+
let contract_source = env
271+
.global_context
272+
.database
273+
.get_contract_src(contract_identifier)
274+
.ok_or_else(|| "Contract source not found in database".to_string())?;
275+
276+
// Get the contract's clarity version from the environment
277+
let contract = env
278+
.global_context
279+
.database
280+
.get_contract(contract_identifier)
281+
.map_err(|e| format!("Failed to get contract: {:?}", e))?;
282+
283+
let clarity_version = contract.contract_context.get_clarity_version();
284+
285+
let epoch = env.global_context.epoch_id;
286+
let ast = make_ast(&contract_source, epoch, clarity_version)?;
287+
288+
static_cost_from_ast(&ast, clarity_version)
289+
}
290+
291+
// pub fn static_cost_tree(
292+
// source: &str,
293+
// clarity_version: &ClarityVersion,
294+
// ) -> Result<HashMap<String, CostAnalysisNode>, String> {
295+
// }
296+
263297
/// Extract function name from a symbolic expression
264298
fn extract_function_name(expr: &SymbolicExpression) -> Option<String> {
265299
if let Some(list) = expr.match_list() {
@@ -469,7 +503,6 @@ fn build_listlike_cost_analysis_tree(
469503
(CostExprNode::NativeFunction(native_function), cost)
470504
} else {
471505
// If not a native function, treat as user-defined function and look it up
472-
println!("in user-defined function");
473506
let expr_node = CostExprNode::UserFunction(function_name.clone());
474507
let cost = calculate_function_cost(function_name.to_string(), cost_map, clarity_version)?;
475508
(expr_node, cost)
@@ -499,6 +532,7 @@ fn calculate_function_cost(
499532
Ok(cost.clone())
500533
}
501534
Some(None) => {
535+
// Should be impossible but alas..
502536
// Function exists but cost not yet computed - this indicates a circular dependency
503537
// For now, return zero cost to avoid infinite recursion
504538
println!(
@@ -809,6 +843,15 @@ mod tests {
809843
static_cost_native(source, &cost_map, clarity_version)
810844
}
811845

846+
fn static_cost_test(
847+
source: &str,
848+
clarity_version: &ClarityVersion,
849+
) -> Result<HashMap<String, StaticCost>, String> {
850+
let epoch = StacksEpochId::latest();
851+
let ast = make_ast(source, epoch, clarity_version)?;
852+
static_cost_from_ast(&ast, clarity_version)
853+
}
854+
812855
fn build_test_ast(src: &str) -> crate::vm::ast::ContractAST {
813856
let contract_identifier = QualifiedContractIdentifier::transient();
814857
let mut cost_tracker = ();
@@ -1051,4 +1094,41 @@ mod tests {
10511094
assert_eq!(arg_type.as_str(), "uint");
10521095
}
10531096
}
1097+
1098+
#[test]
1099+
fn test_static_cost_simple_addition() {
1100+
let source = "(define-public (add (a uint) (b uint)) (+ a b))";
1101+
let ast_cost = static_cost_test(source, &ClarityVersion::Clarity3).unwrap();
1102+
1103+
// Should have one function
1104+
assert_eq!(ast_cost.len(), 1);
1105+
assert!(ast_cost.contains_key("add"));
1106+
1107+
// Check that the cost is reasonable (non-zero for addition)
1108+
let add_cost = ast_cost.get("add").unwrap();
1109+
assert!(add_cost.min.runtime > 0);
1110+
assert!(add_cost.max.runtime > 0);
1111+
}
1112+
1113+
#[test]
1114+
fn test_static_cost_multiple_functions() {
1115+
let source = r#"
1116+
(define-public (func1 (x uint)) (+ x 1))
1117+
(define-private (func2 (y uint)) (* y 2))
1118+
"#;
1119+
let ast_cost = static_cost_test(source, &ClarityVersion::Clarity3).unwrap();
1120+
1121+
// Should have 2 functions
1122+
assert_eq!(ast_cost.len(), 2);
1123+
1124+
// Check that both functions are present
1125+
assert!(ast_cost.contains_key("func1"));
1126+
assert!(ast_cost.contains_key("func2"));
1127+
1128+
// Check that costs are reasonable
1129+
let func1_cost = ast_cost.get("func1").unwrap();
1130+
let func2_cost = ast_cost.get("func2").unwrap();
1131+
assert!(func1_cost.min.runtime > 0);
1132+
assert!(func2_cost.min.runtime > 0);
1133+
}
10541134
}

clarity/src/vm/tests/analysis.rs

Lines changed: 87 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,12 @@
1-
// Copyright (C) 2013-2020 Blockstack PBC, a public benefit corporation
2-
// Copyright (C) 2020 Stacks Open Internet Foundation
3-
//
4-
// This program is free software: you can redistribute it and/or modify
5-
// it under the terms of the GNU General Public License as published by
6-
// the Free Software Foundation, either version 3 of the License, or
7-
// (at your option) any later version.
8-
//
9-
// This program is distributed in the hope that it will be useful,
10-
// but WITHOUT ANY WARRANTY; without even the implied warranty of
11-
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12-
// GNU General Public License for more details.
13-
//
14-
// You should have received a copy of the GNU General Public License
15-
// along with this program. If not, see <http://www.gnu.org/licenses/>.
16-
171
use std::collections::HashMap;
182

193
use rstest::rstest;
204
use stacks_common::types::StacksEpochId;
215

226
use crate::vm::contexts::OwnedEnvironment;
23-
use crate::vm::costs::analysis::{build_cost_analysis_tree, static_cost, UserArgumentsContext};
7+
use crate::vm::costs::analysis::{
8+
build_cost_analysis_tree, static_cost_from_ast, UserArgumentsContext,
9+
};
2410
use crate::vm::costs::ExecutionCost;
2511
use crate::vm::tests::{tl_env_factory, TopLevelMemoryEnvironmentGenerator};
2612
use crate::vm::types::{PrincipalData, QualifiedContractIdentifier};
@@ -38,17 +24,23 @@ fn test_simple_trait_implementation_costs(
3824
#[case] epoch: StacksEpochId,
3925
mut tl_env_factory: TopLevelMemoryEnvironmentGenerator,
4026
) {
41-
// Simple trait implementation - very brief function that basically does nothing
4227
let simple_impl = r#"(impl-trait .mytrait.mytrait)
4328
(define-public (somefunc (a uint) (b uint))
4429
(ok (+ a b))
4530
)"#;
4631

47-
// Set up environment with cost tracking - use regular environment but try to get actual costs
4832
let mut owned_env = tl_env_factory.get_env(epoch);
4933

50-
// Get static cost analysis
51-
let static_cost = static_cost(simple_impl, &version).unwrap();
34+
let epoch = StacksEpochId::Epoch21;
35+
let ast = crate::vm::ast::build_ast(
36+
&QualifiedContractIdentifier::transient(),
37+
simple_impl,
38+
&mut (),
39+
version,
40+
epoch,
41+
)
42+
.unwrap();
43+
let static_cost = static_cost_from_ast(&ast, &version).unwrap();
5244
// Deploy and execute the contract to get dynamic costs
5345
let contract_id = QualifiedContractIdentifier::local("simple-impl").unwrap();
5446
owned_env
@@ -71,71 +63,13 @@ fn test_simple_trait_implementation_costs(
7163
assert!(dynamic_cost.runtime <= cost.max.runtime);
7264
}
7365

74-
/// Helper function to execute a contract function and return the execution cost
75-
fn execute_contract_function_and_get_cost(
76-
env: &mut OwnedEnvironment,
77-
contract_id: &QualifiedContractIdentifier,
78-
function_name: &str,
79-
args: &[u64],
80-
version: ClarityVersion,
81-
) -> ExecutionCost {
82-
// Start with a fresh cost tracker
83-
let initial_cost = env.get_cost_total();
84-
85-
// Create a dummy sender
86-
let sender = PrincipalData::parse_qualified_contract_principal(
87-
"ST1PQHQKV0RJXZFY1DGX8MNSNYVE3VGZJSRTPGZGM.sender",
88-
)
89-
.unwrap();
90-
91-
// Build function call string
92-
let arg_str = args
93-
.iter()
94-
.map(|a| format!("u{}", a))
95-
.collect::<Vec<_>>()
96-
.join(" ");
97-
let function_call = format!("({} {})", function_name, arg_str);
98-
99-
// Parse the function call into a symbolic expression
100-
let ast = crate::vm::ast::parse(
101-
&QualifiedContractIdentifier::transient(),
102-
&function_call,
103-
version,
104-
StacksEpochId::Epoch21,
105-
)
106-
.expect("Failed to parse function call");
107-
108-
if !ast.is_empty() {
109-
let _result = env.execute_transaction(
110-
sender,
111-
None,
112-
contract_id.clone(),
113-
&function_call,
114-
&ast[0..1],
115-
);
116-
}
117-
118-
// Get the cost after execution
119-
let final_cost = env.get_cost_total();
120-
121-
// Return the difference
122-
ExecutionCost {
123-
write_length: final_cost.write_length - initial_cost.write_length,
124-
write_count: final_cost.write_count - initial_cost.write_count,
125-
read_length: final_cost.read_length - initial_cost.read_length,
126-
read_count: final_cost.read_count - initial_cost.read_count,
127-
runtime: final_cost.runtime - initial_cost.runtime,
128-
}
129-
}
130-
13166
#[rstest]
13267
#[case::clarity2(ClarityVersion::Clarity2, StacksEpochId::Epoch21)]
13368
fn test_complex_trait_implementation_costs(
13469
#[case] version: ClarityVersion,
13570
#[case] epoch: StacksEpochId,
13671
mut tl_env_factory: TopLevelMemoryEnvironmentGenerator,
13772
) {
138-
// Complex trait implementation with expensive operations but no external calls
13973
let complex_impl = r#"(define-public (somefunc (a uint) (b uint))
14074
(begin
14175
;; do something expensive
@@ -152,7 +86,16 @@ fn test_complex_trait_implementation_costs(
15286

15387
let mut owned_env = tl_env_factory.get_env(epoch);
15488

155-
let static_cost_result = static_cost(complex_impl, &version);
89+
let epoch = StacksEpochId::Epoch21;
90+
let ast = crate::vm::ast::build_ast(
91+
&QualifiedContractIdentifier::transient(),
92+
complex_impl,
93+
&mut (),
94+
version,
95+
epoch,
96+
)
97+
.unwrap();
98+
let static_cost_result = static_cost_from_ast(&ast, &version);
15699
match static_cost_result {
157100
Ok(static_cost) => {
158101
let contract_id = QualifiedContractIdentifier::local("complex-impl").unwrap();
@@ -228,11 +171,74 @@ fn test_dependent_function_calls() {
228171
)"#;
229172

230173
let contract_id = QualifiedContractIdentifier::transient();
231-
let function_map = static_cost(src, &ClarityVersion::Clarity3).unwrap();
174+
let epoch = StacksEpochId::Epoch32;
175+
let ast = crate::vm::ast::build_ast(
176+
&QualifiedContractIdentifier::transient(),
177+
src,
178+
&mut (),
179+
ClarityVersion::Clarity3,
180+
epoch,
181+
)
182+
.unwrap();
183+
let function_map = static_cost_from_ast(&ast, &ClarityVersion::Clarity3).unwrap();
232184

233185
let add_one_cost = function_map.get("add-one").unwrap();
234186
let somefunc_cost = function_map.get("somefunc").unwrap();
235187

236188
assert!(add_one_cost.min.runtime >= somefunc_cost.min.runtime);
237189
assert!(add_one_cost.max.runtime >= somefunc_cost.max.runtime);
238190
}
191+
192+
/// Helper function to execute a contract function and return the execution cost
193+
fn execute_contract_function_and_get_cost(
194+
env: &mut OwnedEnvironment,
195+
contract_id: &QualifiedContractIdentifier,
196+
function_name: &str,
197+
args: &[u64],
198+
version: ClarityVersion,
199+
) -> ExecutionCost {
200+
// Start with a fresh cost tracker
201+
let initial_cost = env.get_cost_total();
202+
203+
// Create a dummy sender
204+
let sender = PrincipalData::parse_qualified_contract_principal(
205+
"ST1PQHQKV0RJXZFY1DGX8MNSNYVE3VGZJSRTPGZGM.sender",
206+
)
207+
.unwrap();
208+
209+
let arg_str = args
210+
.iter()
211+
.map(|a| format!("u{}", a))
212+
.collect::<Vec<_>>()
213+
.join(" ");
214+
let function_call = format!("({} {})", function_name, arg_str);
215+
216+
let ast = crate::vm::ast::parse(
217+
&QualifiedContractIdentifier::transient(),
218+
&function_call,
219+
version,
220+
StacksEpochId::Epoch21,
221+
)
222+
.expect("Failed to parse function call");
223+
224+
if !ast.is_empty() {
225+
let _result = env.execute_transaction(
226+
sender,
227+
None,
228+
contract_id.clone(),
229+
&function_call,
230+
&ast[0..1],
231+
);
232+
}
233+
234+
// Get the cost after execution
235+
let final_cost = env.get_cost_total();
236+
237+
ExecutionCost {
238+
write_length: final_cost.write_length - initial_cost.write_length,
239+
write_count: final_cost.write_count - initial_cost.write_count,
240+
read_length: final_cost.read_length - initial_cost.read_length,
241+
read_count: final_cost.read_count - initial_cost.read_count,
242+
runtime: final_cost.runtime - initial_cost.runtime,
243+
}
244+
}

0 commit comments

Comments
 (0)