Skip to content

Commit a916e5c

Browse files
author
Gilad Chase
committed
feat: use local_into_box
- usage subject to dedicated db-flag - after tracking local variables in `analyze_ap_changes`, into_box sierra-gen conditionally emits either `into_box` or `local_into_box` if the input is small enough to be worth it.
1 parent 5678893 commit a916e5c

File tree

10 files changed

+305
-10
lines changed

10 files changed

+305
-10
lines changed

crates/cairo-lang-filesystem/src/flag.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,20 @@ pub enum Flag {
2020
///
2121
/// Default is false as it makes panic unprovable.
2222
UnsafePanic(bool),
23+
/// Whether to enable the local_into_box optimization.
24+
///
25+
/// Default is true.
26+
LocalIntoBoxOptimization(bool),
2327
}
2428

2529
/// Returns the value of the `unsafe_panic` flag, or `false` if the flag is not set.
2630
pub fn flag_unsafe_panic(db: &dyn salsa::Database) -> bool {
2731
let flag = FlagId::new(db, FlagLongId("unsafe_panic".into()));
2832
if let Some(flag) = db.get_flag(flag) { *flag == Flag::UnsafePanic(true) } else { false }
2933
}
34+
35+
/// Returns whether the local_into_box optimization is enabled (default: true).
36+
pub fn flag_local_into_box_optimization(db: &dyn salsa::Database) -> bool {
37+
let flag = FlagId::new(db, FlagLongId("local_into_box_optimization".into()));
38+
db.get_flag(flag).is_none_or(|f| *f != Flag::LocalIntoBoxOptimization(false))
39+
}

crates/cairo-lang-sierra-generator/src/block_generator.rs

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
mod test;
44

55
use cairo_lang_diagnostics::Maybe;
6+
use cairo_lang_filesystem::flag::flag_local_into_box_optimization;
67
use cairo_lang_lowering::BlockId;
8+
use cairo_lang_lowering::db::LoweringGroup;
79
use cairo_lang_lowering::ids::LocationId;
810
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
911
use itertools::{chain, enumerate, zip_eq};
@@ -24,9 +26,9 @@ use crate::utils::{
2426
drop_libfunc_id, dup_libfunc_id, enable_ap_tracking_libfunc_id,
2527
enum_from_bounded_int_libfunc_id, enum_init_libfunc_id, get_concrete_libfunc_id,
2628
get_libfunc_signature, into_box_libfunc_id, jump_libfunc_id, jump_statement,
27-
match_enum_libfunc_id, rename_libfunc_id, return_statement, simple_basic_statement,
28-
snapshot_take_libfunc_id, struct_construct_libfunc_id, struct_deconstruct_libfunc_id,
29-
unbox_libfunc_id,
29+
local_into_box_libfunc_id, match_enum_libfunc_id, rename_libfunc_id, return_statement,
30+
simple_basic_statement, snapshot_take_libfunc_id, struct_construct_libfunc_id,
31+
struct_deconstruct_libfunc_id, unbox_libfunc_id,
3032
};
3133

3234
/// Generates Sierra code for the body of the given [lowering::Block].
@@ -564,11 +566,20 @@ fn generate_statement_into_box<'db>(
564566
statement_location: &StatementLocation,
565567
) -> Maybe<()> {
566568
let input = maybe_add_dup_statement(context, statement_location, 0, &statement.input)?;
569+
let ty = context.get_variable_sierra_type(statement.input.var_id)?;
570+
let db = context.get_db();
571+
let semantic_ty = context.get_lowered_variable(statement.input.var_id).ty;
572+
// When size < 3, into_box is cheaper.
573+
let use_local_into_box = flag_local_into_box_optimization(db)
574+
&& context.is_non_ap_based(statement.input.var_id)
575+
&& db.type_size(semantic_ty) >= 3;
576+
let libfunc_id = if use_local_into_box {
577+
local_into_box_libfunc_id(db, ty)
578+
} else {
579+
into_box_libfunc_id(db, ty)
580+
};
567581
let stmt = simple_basic_statement(
568-
into_box_libfunc_id(
569-
context.get_db(),
570-
context.get_variable_sierra_type(statement.input.var_id)?,
571-
),
582+
libfunc_id,
572583
&[input],
573584
&[context.get_sierra_variable(statement.output)],
574585
);

crates/cairo-lang-sierra-generator/src/block_generator_test.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ fn block_generator_test(
7979
function_id,
8080
&lifetime,
8181
crate::ap_tracking::ApTrackingConfiguration::default(),
82+
Default::default(),
8283
);
8384

8485
let mut expected_sierra_code = String::default();

crates/cairo-lang-sierra-generator/src/expr_generator_context.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use cairo_lang_sierra::extensions::uninitialized::UninitializedType;
66
use cairo_lang_sierra::program::{ConcreteTypeLongId, GenericArg};
77
use cairo_lang_utils::Intern;
88
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
9+
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
910
use lowering::ids::ConcreteFunctionWithBodyId;
1011
use lowering::{BlockId, Lowered};
1112
use salsa::Database;
@@ -26,6 +27,7 @@ pub struct ExprGeneratorContext<'db, 'a> {
2627
var_id_allocator: IdAllocator,
2728
label_id_allocator: IdAllocator,
2829
variables: OrderedHashMap<SierraGenVar, cairo_lang_sierra::ids::VarId>,
30+
non_ap_based_variables: UnorderedHashSet<VariableId>,
2931
/// Allocated Sierra variables and their locations.
3032
variable_locations: Vec<(cairo_lang_sierra::ids::VarId, LocationId<'db>)>,
3133
block_labels: OrderedHashMap<BlockId, pre_sierra::LabelId<'db>>,
@@ -48,6 +50,7 @@ impl<'db, 'a> ExprGeneratorContext<'db, 'a> {
4850
function_id: ConcreteFunctionWithBodyId<'db>,
4951
lifetime: &'a VariableLifetimeResult,
5052
ap_tracking_configuration: ApTrackingConfiguration,
53+
non_ap_based_variables: UnorderedHashSet<VariableId>,
5154
) -> Self {
5255
ExprGeneratorContext {
5356
db,
@@ -61,6 +64,7 @@ impl<'db, 'a> ExprGeneratorContext<'db, 'a> {
6164
block_labels: OrderedHashMap::default(),
6265
ap_tracking_enabled: true,
6366
ap_tracking_configuration,
67+
non_ap_based_variables,
6468
statements: vec![],
6569
curr_cairo_location: None,
6670
}
@@ -189,6 +193,16 @@ impl<'db, 'a> ExprGeneratorContext<'db, 'a> {
189193
&& self.ap_tracking_configuration.disable_ap_tracking.contains(block_id)
190194
}
191195

196+
/// Returns true if the variable is non-AP-based.
197+
pub fn is_non_ap_based(&self, var_id: VariableId) -> bool {
198+
self.non_ap_based_variables.contains(&var_id)
199+
}
200+
201+
/// Returns the lowered variable for the given variable id.
202+
pub fn get_lowered_variable(&self, var_id: VariableId) -> &lowering::Variable<'db> {
203+
&self.lowered.variables[var_id]
204+
}
205+
192206
/// Adds a statement for the expression.
193207
pub fn push_statement(&mut self, statement: pre_sierra::Statement<'db>) {
194208
self.statements.push(pre_sierra::StatementWithLocation {

crates/cairo-lang-sierra-generator/src/function_generator.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,12 @@ fn get_function_ap_change_and_code<'db>(
6565
analyze_ap_change_result: AnalyzeApChangesResult,
6666
) -> Maybe<pre_sierra::Function<'db>> {
6767
let root_block = lowered_function.blocks.root_block()?;
68-
let AnalyzeApChangesResult { known_ap_change, local_variables, ap_tracking_configuration } =
69-
analyze_ap_change_result;
68+
let AnalyzeApChangesResult {
69+
known_ap_change,
70+
local_variables,
71+
ap_tracking_configuration,
72+
non_ap_based_variables,
73+
} = analyze_ap_change_result;
7074

7175
// Get lifetime information.
7276
let lifetime = find_variable_lifetime(lowered_function, &local_variables)?;
@@ -77,6 +81,7 @@ fn get_function_ap_change_and_code<'db>(
7781
function_id,
7882
&lifetime,
7983
ap_tracking_configuration,
84+
non_ap_based_variables,
8085
);
8186

8287
// If the function starts with `revoke_ap_tracking` then we can avoid
@@ -159,6 +164,7 @@ pub fn priv_get_dummy_function<'db>(
159164
function_id,
160165
&lifetime,
161166
ap_tracking_configuration,
167+
Default::default(),
162168
);
163169

164170
// Generate a label for the function's body.

crates/cairo-lang-sierra-generator/src/function_generator_test.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ cairo_lang_test_utils::test_file_test!(
44
function_generator,
55
"src/function_generator_test_data",
66
{
7+
boxing: "boxing",
78
inline: "inline",
89
struct_: "struct",
910
match_: "match",
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
//! > Test local_into_box for large struct parameter (size >= 3).
2+
3+
//! > test_runner_name
4+
test_function_generator
5+
6+
//! > function_code
7+
fn foo(a: MyStruct) -> Box<MyStruct> {
8+
BoxTrait::new(a)
9+
}
10+
11+
//! > function_name
12+
foo
13+
14+
//! > module_code
15+
struct MyStruct {
16+
x: felt252,
17+
y: felt252,
18+
z: felt252,
19+
}
20+
21+
//! > semantic_diagnostics
22+
23+
//! > lowering_diagnostics
24+
25+
//! > sierra_gen_diagnostics
26+
27+
//! > sierra_code
28+
label_test::foo::0:
29+
local_into_box<test::MyStruct>([0]) -> ([1])
30+
store_temp<Box<test::MyStruct>>([1]) -> ([1])
31+
return([1])
32+
33+
//! > ==========================================================================
34+
35+
//! > Test into_box for small struct parameter (size < 3).
36+
37+
//! > test_runner_name
38+
test_function_generator
39+
40+
//! > function_code
41+
fn foo(a: SmallStruct) -> Box<SmallStruct> {
42+
BoxTrait::new(a)
43+
}
44+
45+
//! > function_name
46+
foo
47+
48+
//! > module_code
49+
struct SmallStruct {
50+
x: felt252,
51+
y: felt252,
52+
}
53+
54+
//! > semantic_diagnostics
55+
56+
//! > lowering_diagnostics
57+
58+
//! > sierra_gen_diagnostics
59+
60+
//! > sierra_code
61+
label_test::foo::0:
62+
into_box<test::SmallStruct>([0]) -> ([1])
63+
return([1])
64+
65+
//! > ==========================================================================
66+
67+
//! > Test chained boxing: inner uses local_into_box, outer uses into_box (Box has size 1).
68+
69+
//! > test_runner_name
70+
test_function_generator
71+
72+
//! > function_code
73+
fn foo(a: MyStruct) -> Box<Box<MyStruct>> {
74+
BoxTrait::new(BoxTrait::new(a))
75+
}
76+
77+
//! > function_name
78+
foo
79+
80+
//! > module_code
81+
struct MyStruct {
82+
x: felt252,
83+
y: felt252,
84+
z: felt252,
85+
}
86+
87+
//! > semantic_diagnostics
88+
89+
//! > lowering_diagnostics
90+
91+
//! > sierra_gen_diagnostics
92+
93+
//! > sierra_code
94+
label_test::foo::0:
95+
local_into_box<test::MyStruct>([0]) -> ([1])
96+
store_temp<Box<test::MyStruct>>([1]) -> ([1])
97+
into_box<Box<test::MyStruct>>([1]) -> ([2])
98+
return([2])
99+
100+
//! > ==========================================================================
101+
102+
//! > Test local_into_box for variable that becomes local due to revoke.
103+
104+
//! > test_runner_name
105+
test_function_generator
106+
107+
//! > function_code
108+
fn foo() -> Box<MyStruct> {
109+
let x = create_struct();
110+
revoke_ap();
111+
BoxTrait::new(x)
112+
}
113+
114+
//! > function_name
115+
foo
116+
117+
//! > module_code
118+
#[derive(Drop)]
119+
struct MyStruct {
120+
x: felt252,
121+
y: felt252,
122+
z: felt252,
123+
}
124+
125+
#[inline(never)]
126+
fn create_struct() -> MyStruct {
127+
MyStruct { x: 1, y: 2, z: 3 }
128+
}
129+
130+
fn revoke_ap() {
131+
revoke_ap()
132+
}
133+
134+
//! > semantic_diagnostics
135+
136+
//! > lowering_diagnostics
137+
138+
//! > sierra_gen_diagnostics
139+
140+
//! > sierra_code
141+
label_test::foo::0:
142+
alloc_local<test::MyStruct>() -> ([1])
143+
finalize_locals() -> ()
144+
disable_ap_tracking() -> ()
145+
function_call<user@test::create_struct>() -> ([0])
146+
store_local<test::MyStruct>([1], [0]) -> ([0])
147+
function_call<user@test::revoke_ap>() -> ()
148+
local_into_box<test::MyStruct>([0]) -> ([2])
149+
store_temp<Box<test::MyStruct>>([2]) -> ([2])
150+
return([2])
151+
152+
//! > ==========================================================================
153+
154+
//! > Test local_into_box for snapshot of parameter (alias is non-AP-based).
155+
156+
//! > test_runner_name
157+
test_function_generator
158+
159+
//! > function_code
160+
fn foo(a: MyStruct) -> Box<@MyStruct> {
161+
BoxTrait::new(@a)
162+
}
163+
164+
//! > function_name
165+
foo
166+
167+
//! > module_code
168+
#[derive(Drop)]
169+
struct MyStruct {
170+
x: felt252,
171+
y: felt252,
172+
z: felt252,
173+
}
174+
175+
//! > semantic_diagnostics
176+
177+
//! > lowering_diagnostics
178+
179+
//! > sierra_gen_diagnostics
180+
181+
//! > sierra_code
182+
label_test::foo::0:
183+
snapshot_take<test::MyStruct>([0]) -> ([1], [2])
184+
drop<test::MyStruct>([1]) -> ()
185+
local_into_box<test::MyStruct>([2]) -> ([3])
186+
store_temp<Box<test::MyStruct>>([3]) -> ([3])
187+
return([3])

crates/cairo-lang-sierra-generator/src/local_variables.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,14 @@ use crate::utils::{
3131

3232
/// Information returned by [analyze_ap_changes].
3333
pub struct AnalyzeApChangesResult {
34-
/// True if the function has a known_ap_change
34+
/// True if the function has a known_ap_change.
3535
pub known_ap_change: bool,
3636
/// The variables that should be stored in locals as they are revoked during the function.
3737
pub local_variables: OrderedHashSet<VariableId>,
3838
/// Information about where ap tracking should be enabled and disabled.
3939
pub ap_tracking_configuration: ApTrackingConfiguration,
40+
/// Variables that are known to be non-AP-based (expanded to include all aliases).
41+
pub non_ap_based_variables: UnorderedHashSet<VariableId>,
4042
}
4143

4244
/// Does ap change related analysis for a given function.
@@ -98,6 +100,14 @@ pub fn analyze_ap_changes<'db>(
98100
}
99101
}
100102

103+
// Expand non_ap_based to include all aliases.
104+
let non_ap_based_variables: UnorderedHashSet<_> = lowered_function
105+
.variables
106+
.iter()
107+
.map(|(id, _)| id)
108+
.filter(|v| ctx.non_ap_based.contains(ctx.peel_aliases(v)))
109+
.collect();
110+
101111
Ok(AnalyzeApChangesResult {
102112
known_ap_change: root_info.known_ap_change,
103113
local_variables: locals,
@@ -106,6 +116,7 @@ pub fn analyze_ap_changes<'db>(
106116
root_info.known_ap_change,
107117
need_ap_alignment,
108118
),
119+
non_ap_based_variables,
109120
})
110121
}
111122

0 commit comments

Comments
 (0)