Skip to content

Commit 17ff3bf

Browse files
committed
trait counting on map/filter/fold minimum to zero
1 parent 75d3c7d commit 17ff3bf

File tree

2 files changed

+25
-22
lines changed

2 files changed

+25
-22
lines changed

clarity/src/vm/costs/analysis.rs

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -226,18 +226,18 @@ type TraitCount = HashMap<String, MinMaxTraitCount>;
226226
/// Context passed to visitors during trait count analysis
227227
struct TraitCountContext {
228228
containing_fn_name: String,
229-
multiplier: u64,
229+
multiplier: (u64, u64),
230230
}
231231

232232
impl TraitCountContext {
233-
fn new(containing_fn_name: String, multiplier: u64) -> Self {
233+
fn new(containing_fn_name: String, multiplier: (u64, u64)) -> Self {
234234
Self {
235235
containing_fn_name,
236236
multiplier,
237237
}
238238
}
239239

240-
fn with_multiplier(&self, multiplier: u64) -> Self {
240+
fn with_multiplier(&self, multiplier: (u64, u64)) -> Self {
241241
Self {
242242
containing_fn_name: self.containing_fn_name.clone(),
243243
multiplier,
@@ -254,51 +254,54 @@ impl TraitCountContext {
254254

255255
/// Extract the list size multiplier from a list expression (for map/filter/fold operations)
256256
/// Expects a list in the form `(list <size>)` where size is an integer literal
257-
fn extract_list_multiplier(list: &[SymbolicExpression]) -> u64 {
257+
fn extract_list_multiplier(list: &[SymbolicExpression]) -> (u64, u64) {
258258
if list.is_empty() {
259-
return 1;
259+
return (1, 1);
260260
}
261261

262262
let is_list_atom = list[0]
263263
.match_atom()
264264
.map(|a| a.as_str() == "list")
265265
.unwrap_or(false);
266266
if !is_list_atom || list.len() < 2 {
267-
return 1;
267+
return (1, 1);
268268
}
269269

270270
match &list[1].expr {
271-
SymbolicExpressionType::LiteralValue(Value::Int(value)) => *value as u64,
272-
_ => 1,
271+
SymbolicExpressionType::LiteralValue(Value::Int(value)) => (0, *value as u64),
272+
_ => (1, 1),
273273
}
274274
}
275275

276276
/// Increment trait count for a function
277-
fn increment_trait_count(trait_counts: &mut TraitCount, fn_name: &str, multiplier: u64) {
277+
fn increment_trait_count(trait_counts: &mut TraitCount, fn_name: &str, multiplier: (u64, u64)) {
278278
trait_counts
279279
.entry(fn_name.to_string())
280280
.and_modify(|(min, max)| {
281-
*min += 1;
282-
*max += multiplier;
281+
*min += multiplier.0;
282+
*max += multiplier.1;
283283
})
284-
.or_insert((1, multiplier));
284+
.or_insert(multiplier);
285285
}
286286

287287
/// Propagate trait count from one function to another with a multiplier
288288
fn propagate_trait_count(
289289
trait_counts: &mut TraitCount,
290290
from_fn: &str,
291291
to_fn: &str,
292-
multiplier: u64,
292+
multiplier: (u64, u64),
293293
) {
294294
if let Some(called_trait_count) = trait_counts.get(from_fn).cloned() {
295295
trait_counts
296296
.entry(to_fn.to_string())
297297
.and_modify(|(min, max)| {
298-
*min += called_trait_count.0;
299-
*max += called_trait_count.1 * multiplier;
298+
*min += called_trait_count.0 * multiplier.0;
299+
*max += called_trait_count.1 * multiplier.1;
300300
})
301-
.or_insert((called_trait_count.0, called_trait_count.1 * multiplier));
301+
.or_insert((
302+
called_trait_count.0 * multiplier.0,
303+
called_trait_count.1 * multiplier.1,
304+
));
302305
}
303306
}
304307

@@ -409,7 +412,7 @@ impl TraitCountVisitor for TraitCountCollector {
409412
{
410413
extract_list_multiplier(list)
411414
} else {
412-
1
415+
(1, 1)
413416
};
414417
let new_context = context.with_multiplier(multiplier);
415418
for child in &node.children {
@@ -535,7 +538,7 @@ impl<'a> TraitCountVisitor for TraitCountPropagator<'a> {
535538
{
536539
extract_list_multiplier(list)
537540
} else {
538-
1
541+
(1, 1)
539542
};
540543

541544
// Process the function being called in map/filter/fold
@@ -630,7 +633,7 @@ pub(crate) fn get_trait_count(costs: &HashMap<String, CostAnalysisNode>) -> Opti
630633
// First pass: collect trait counts and trait names
631634
let mut collector = TraitCountCollector::new();
632635
for (name, cost_analysis_node) in costs.iter() {
633-
let context = TraitCountContext::new(name.clone(), 1);
636+
let context = TraitCountContext::new(name.clone(), (1, 1));
634637
collector.visit(cost_analysis_node, &context);
635638
}
636639

@@ -640,7 +643,7 @@ pub(crate) fn get_trait_count(costs: &HashMap<String, CostAnalysisNode>) -> Opti
640643
let mut propagator =
641644
TraitCountPropagator::new(&mut collector.trait_counts, &collector.trait_names);
642645
for (name, cost_analysis_node) in costs.iter() {
643-
let context = TraitCountContext::new(name.clone(), 1);
646+
let context = TraitCountContext::new(name.clone(), (1, 1));
644647
propagator.visit(cost_analysis_node, &context);
645648
}
646649

clarity/src/vm/tests/analysis.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ fn test_get_trait_count_direct() {
225225
// Expected result: {something: (1,10), send: (1,1)}
226226
let expected = {
227227
let mut map = HashMap::new();
228-
map.insert("something".to_string(), (1, 10));
228+
map.insert("something".to_string(), (0, 10));
229229
map.insert("send".to_string(), (1, 1));
230230
Some(map)
231231
};
@@ -271,7 +271,7 @@ fn test_trait_counting() {
271271
// Check that "something" function has trait count of (1, 10)
272272
let something_trait_count_map = static_cost.get("something").unwrap().1.clone().unwrap();
273273
let something_trait_count = something_trait_count_map.get("something").unwrap();
274-
assert_eq!(something_trait_count.0, 1);
274+
assert_eq!(something_trait_count.0, 0);
275275
assert_eq!(something_trait_count.1, 10);
276276
}
277277

0 commit comments

Comments
 (0)