Skip to content

Commit ffa5a5e

Browse files
allow shallow specialization of recursive functions
1 parent 7f4a9e8 commit ffa5a5e

File tree

27 files changed

+7407
-6959
lines changed

27 files changed

+7407
-6959
lines changed

crates/cairo-lang-lowering/src/add_withdraw_gas/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ fn create_panic_block<'db>(
123123
with_coupon: false,
124124
outputs: vec![never_var],
125125
location,
126+
is_specialization_base_call: false,
126127
})],
127128
end: BlockEnd::Match {
128129
info: MatchInfo::Enum(MatchEnumInfo {

crates/cairo-lang-lowering/src/cache/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,7 @@ struct StatementCallCached {
988988
with_coupon: bool,
989989
outputs: Vec<usize>,
990990
location: LocationIdCached,
991+
is_specialization_base_call: bool,
991992
}
992993
impl StatementCallCached {
993994
fn new<'db>(stmt: StatementCall<'db>, ctx: &mut CacheSavingContext<'db>) -> Self {
@@ -997,6 +998,7 @@ impl StatementCallCached {
997998
with_coupon: stmt.with_coupon,
998999
outputs: stmt.outputs.iter().map(|var| var.index()).collect(),
9991000
location: LocationIdCached::new(stmt.location, ctx),
1001+
is_specialization_base_call: stmt.is_specialization_base_call,
10001002
}
10011003
}
10021004
fn embed<'db>(self, ctx: &mut CacheLoadingContext<'db>) -> StatementCall<'db> {
@@ -1010,6 +1012,7 @@ impl StatementCallCached {
10101012
.map(|var_id| ctx.lowered_variables_id[var_id])
10111013
.collect(),
10121014
location: self.location.embed(ctx),
1015+
is_specialization_base_call: self.is_specialization_base_call,
10131016
}
10141017
}
10151018
}

crates/cairo-lang-lowering/src/destructs.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,7 @@ pub fn add_destructs<'db>(
389389
with_coupon: false,
390390
outputs: vec![output_var],
391391
location: variables.variables[plain_destruct.var_id].location,
392+
is_specialization_base_call: false,
392393
})
393394
}
394395

@@ -415,6 +416,7 @@ pub fn add_destructs<'db>(
415416
with_coupon: false,
416417
outputs: vec![new_panic_var, output_var],
417418
location,
419+
is_specialization_base_call: false,
418420
});
419421
last_panic_var = new_panic_var;
420422
}

crates/cairo-lang-lowering/src/inline/mod.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,16 @@ pub fn priv_should_inline<'db>(
6161
if db.priv_never_inline(function_id)? {
6262
return Ok(false);
6363
}
64-
65-
// Breaks cycles.
66-
if db.concrete_in_cycle(function_id, DependencyType::Call, LoweringStage::Monomorphized)? {
64+
// Prevents inlining of functions that may call themselves, by checking if the base of the
65+
// function (the function without specialization) is in a call cycle (we cannot use the
66+
// specialized version as it may call the base function with different specialization which may
67+
// cause long inlining chains).
68+
let base = match function_id.long(db) {
69+
ConcreteFunctionWithBodyLongId::Semantic(_)
70+
| ConcreteFunctionWithBodyLongId::Generated(_) => function_id,
71+
ConcreteFunctionWithBodyLongId::Specialized(specialized) => specialized.base,
72+
};
73+
if db.concrete_in_cycle(base, DependencyType::Call, LoweringStage::Monomorphized)? {
6774
return Ok(false);
6875
}
6976

@@ -342,6 +349,7 @@ where
342349
if let ConcreteFunctionWithBodyLongId::Specialized(specialized) =
343350
calling_function_id.long(db)
344351
&& specialized.base == called_func
352+
&& stmt.is_specialization_base_call
345353
{
346354
// A specialized function should always inline its base.
347355
return Ok(Some((stmt, called_func)));

crates/cairo-lang-lowering/src/lower/generators.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ impl<'db> Call<'db> {
9393
with_coupon,
9494
outputs,
9595
location: self.location,
96+
is_specialization_base_call: false,
9697
}));
9798

9899
CallResult { returns, extra_outputs }

crates/cairo-lang-lowering/src/lower/specialized_test.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use cairo_lang_debug::DebugWithDb;
22
use cairo_lang_semantic::test_utils::setup_test_function;
33
use cairo_lang_test_utils::parse_test_file::TestRunnerResult;
44
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
5+
use cairo_lang_utils::try_extract_matches;
56
use salsa::Setter;
67

78
use crate::db::{LoweringGroup, lowering_group_input};
@@ -56,17 +57,18 @@ fn test_specialized_function(
5657
"Got diagnostics for the caller {semantic_diagnostics}\n{lowering_diagnostics}."
5758
);
5859
});
59-
let Some(Statement::Call(call)) = lowered_caller
60+
let Some(specialized_id) = lowered_caller
6061
.blocks
6162
.iter()
6263
.flat_map(|(_, b)| b.statements.iter())
63-
.rfind(|statement| matches!(statement, Statement::Call(_)))
64+
.filter_map(|statement| {
65+
try_extract_matches!(statement, Statement::Call)
66+
.and_then(|call| call.function.body(db).ok().flatten())
67+
})
68+
.next_back()
6469
else {
6570
panic!("Could not find the last call in the caller function.");
6671
};
67-
let Ok(Some(specialized_id)) = call.function.body(db) else {
68-
panic!("Expected function body, got: {}", call.function.full_path(db));
69-
};
7072
let lowered_specialized =
7173
db.lowered_body(specialized_id, LoweringStage::Monomorphized).unwrap();
7274
let lowered_formatter = LoweredFormatter::new(db, &lowered_caller.variables);

crates/cairo-lang-lowering/src/lower/test_data/loop

Lines changed: 74 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -42,28 +42,27 @@ Final lowering:
4242
Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin
4343
blk0 (root):
4444
Statements:
45-
(v2: core::felt252) <- 5
46-
(v3: core::RangeCheck, v4: core::gas::GasBuiltin, v5: core::panics::PanicResult::<(core::felt252, core::bool)>) <- test::foo[38-149](v0, v1, v2)
45+
(v2: core::RangeCheck, v3: core::gas::GasBuiltin, v4: core::panics::PanicResult::<(core::felt252, core::bool)>) <- test::foo[38-149]{5, }(v0, v1)
4746
End:
48-
Match(match_enum(v5) {
49-
PanicResult::Ok(v6) => blk1,
50-
PanicResult::Err(v7) => blk2,
47+
Match(match_enum(v4) {
48+
PanicResult::Ok(v5) => blk1,
49+
PanicResult::Err(v6) => blk2,
5150
})
5251

5352
blk1:
5453
Statements:
55-
(v8: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v4)
56-
(v9: core::felt252, v10: core::bool) <- struct_destructure(v6)
57-
(v11: (core::bool,)) <- struct_construct(v10)
58-
(v12: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Ok(v11)
54+
(v7: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v3)
55+
(v8: core::felt252, v9: core::bool) <- struct_destructure(v5)
56+
(v10: (core::bool,)) <- struct_construct(v9)
57+
(v11: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Ok(v10)
5958
End:
60-
Return(v3, v8, v12)
59+
Return(v2, v7, v11)
6160

6261
blk2:
6362
Statements:
64-
(v13: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Err(v7)
63+
(v12: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Err(v6)
6564
End:
66-
Return(v3, v4, v13)
65+
Return(v2, v3, v12)
6766

6867

6968
Generated loop lowering for source location:
@@ -198,28 +197,27 @@ Final lowering:
198197
Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin
199198
blk0 (root):
200199
Statements:
201-
(v2: core::felt252) <- 5
202-
(v3: core::RangeCheck, v4: core::gas::GasBuiltin, v5: core::panics::PanicResult::<(core::felt252, core::bool)>) <- test::foo[38-130](v0, v1, v2)
200+
(v2: core::RangeCheck, v3: core::gas::GasBuiltin, v4: core::panics::PanicResult::<(core::felt252, core::bool)>) <- test::foo[38-130]{5, }(v0, v1)
203201
End:
204-
Match(match_enum(v5) {
205-
PanicResult::Ok(v6) => blk1,
206-
PanicResult::Err(v7) => blk2,
202+
Match(match_enum(v4) {
203+
PanicResult::Ok(v5) => blk1,
204+
PanicResult::Err(v6) => blk2,
207205
})
208206

209207
blk1:
210208
Statements:
211-
(v8: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v4)
212-
(v9: core::felt252, v10: core::bool) <- struct_destructure(v6)
213-
(v11: (core::bool,)) <- struct_construct(v10)
214-
(v12: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Ok(v11)
209+
(v7: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v3)
210+
(v8: core::felt252, v9: core::bool) <- struct_destructure(v5)
211+
(v10: (core::bool,)) <- struct_construct(v9)
212+
(v11: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Ok(v10)
215213
End:
216-
Return(v3, v8, v12)
214+
Return(v2, v7, v11)
217215

218216
blk2:
219217
Statements:
220-
(v13: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Err(v7)
218+
(v12: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Err(v6)
221219
End:
222-
Return(v3, v4, v13)
220+
Return(v2, v3, v12)
223221

224222

225223
Generated loop lowering for source location:
@@ -521,28 +519,27 @@ Final lowering:
521519
Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin
522520
blk0 (root):
523521
Statements:
524-
(v2: core::felt252) <- 5
525-
(v3: core::RangeCheck, v4: core::gas::GasBuiltin, v5: core::panics::PanicResult::<(core::felt252, core::bool)>) <- test::foo[38-201](v0, v1, v2)
522+
(v2: core::RangeCheck, v3: core::gas::GasBuiltin, v4: core::panics::PanicResult::<(core::felt252, core::bool)>) <- test::foo[38-201]{5, }(v0, v1)
526523
End:
527-
Match(match_enum(v5) {
528-
PanicResult::Ok(v6) => blk1,
529-
PanicResult::Err(v7) => blk2,
524+
Match(match_enum(v4) {
525+
PanicResult::Ok(v5) => blk1,
526+
PanicResult::Err(v6) => blk2,
530527
})
531528

532529
blk1:
533530
Statements:
534-
(v8: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v4)
535-
(v9: core::felt252, v10: core::bool) <- struct_destructure(v6)
536-
(v11: (core::bool,)) <- struct_construct(v10)
537-
(v12: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Ok(v11)
531+
(v7: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v3)
532+
(v8: core::felt252, v9: core::bool) <- struct_destructure(v5)
533+
(v10: (core::bool,)) <- struct_construct(v9)
534+
(v11: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Ok(v10)
538535
End:
539-
Return(v3, v8, v12)
536+
Return(v2, v7, v11)
540537

541538
blk2:
542539
Statements:
543-
(v13: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Err(v7)
540+
(v12: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Err(v6)
544541
End:
545-
Return(v3, v4, v13)
542+
Return(v2, v3, v12)
546543

547544

548545
Generated loop lowering for source location:
@@ -1394,29 +1391,27 @@ Final lowering:
13941391
Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin
13951392
blk0 (root):
13961393
Statements:
1397-
(v2: core::felt252) <- 0
1398-
(v3: test::A) <- struct_construct(v2)
1399-
(v4: core::RangeCheck, v5: core::gas::GasBuiltin, v6: core::panics::PanicResult::<(test::A, core::felt252, ())>) <- test::foo[58-134](v0, v1, v2, v3)
1394+
(v2: core::RangeCheck, v3: core::gas::GasBuiltin, v4: core::panics::PanicResult::<(test::A, core::felt252, ())>) <- test::foo[58-134]{0, { 0: core::felt252 }, }(v0, v1)
14001395
End:
1401-
Match(match_enum(v6) {
1402-
PanicResult::Ok(v7) => blk1,
1403-
PanicResult::Err(v8) => blk2,
1396+
Match(match_enum(v4) {
1397+
PanicResult::Ok(v5) => blk1,
1398+
PanicResult::Err(v6) => blk2,
14041399
})
14051400

14061401
blk1:
14071402
Statements:
1408-
(v9: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v5)
1409-
(v10: ()) <- struct_construct()
1410-
(v11: ((),)) <- struct_construct(v10)
1411-
(v12: core::panics::PanicResult::<((),)>) <- PanicResult::Ok(v11)
1403+
(v7: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v3)
1404+
(v8: ()) <- struct_construct()
1405+
(v9: ((),)) <- struct_construct(v8)
1406+
(v10: core::panics::PanicResult::<((),)>) <- PanicResult::Ok(v9)
14121407
End:
1413-
Return(v4, v9, v12)
1408+
Return(v2, v7, v10)
14141409

14151410
blk2:
14161411
Statements:
1417-
(v13: core::panics::PanicResult::<((),)>) <- PanicResult::Err(v8)
1412+
(v11: core::panics::PanicResult::<((),)>) <- PanicResult::Err(v6)
14181413
End:
1419-
Return(v4, v5, v13)
1414+
Return(v2, v3, v11)
14201415

14211416

14221417
Generated loop lowering for source location:
@@ -1561,28 +1556,27 @@ Final lowering:
15611556
Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin
15621557
blk0 (root):
15631558
Statements:
1564-
(v2: core::integer::u8) <- 0
1565-
(v3: core::RangeCheck, v4: core::gas::GasBuiltin, v5: core::panics::PanicResult::<(core::integer::u8, ())>) <- test::MyImpl::impl_in_trait[54-175](v0, v1, v2)
1559+
(v2: core::RangeCheck, v3: core::gas::GasBuiltin, v4: core::panics::PanicResult::<(core::integer::u8, ())>) <- test::MyImpl::impl_in_trait[54-175]{0, }(v0, v1)
15661560
End:
1567-
Match(match_enum(v5) {
1568-
PanicResult::Ok(v6) => blk1,
1569-
PanicResult::Err(v7) => blk2,
1561+
Match(match_enum(v4) {
1562+
PanicResult::Ok(v5) => blk1,
1563+
PanicResult::Err(v6) => blk2,
15701564
})
15711565

15721566
blk1:
15731567
Statements:
1574-
(v8: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v4)
1575-
(v9: ()) <- struct_construct()
1576-
(v10: ((),)) <- struct_construct(v9)
1577-
(v11: core::panics::PanicResult::<((),)>) <- PanicResult::Ok(v10)
1568+
(v7: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v3)
1569+
(v8: ()) <- struct_construct()
1570+
(v9: ((),)) <- struct_construct(v8)
1571+
(v10: core::panics::PanicResult::<((),)>) <- PanicResult::Ok(v9)
15781572
End:
1579-
Return(v3, v8, v11)
1573+
Return(v2, v7, v10)
15801574

15811575
blk2:
15821576
Statements:
1583-
(v12: core::panics::PanicResult::<((),)>) <- PanicResult::Err(v7)
1577+
(v11: core::panics::PanicResult::<((),)>) <- PanicResult::Err(v6)
15841578
End:
1585-
Return(v3, v4, v12)
1579+
Return(v2, v3, v11)
15861580

15871581
//! > lowering_diagnostics
15881582

@@ -1657,45 +1651,44 @@ Final lowering:
16571651
Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin
16581652
blk0 (root):
16591653
Statements:
1660-
(v2: core::felt252) <- 5
1661-
(v3: core::RangeCheck, v4: core::gas::GasBuiltin, v5: core::panics::PanicResult::<(core::felt252, core::internal::LoopResult::<core::bool, core::integer::u32>)>) <- test::foo[37-253](v0, v1, v2)
1654+
(v2: core::RangeCheck, v3: core::gas::GasBuiltin, v4: core::panics::PanicResult::<(core::felt252, core::internal::LoopResult::<core::bool, core::integer::u32>)>) <- test::foo[37-253]{5, }(v0, v1)
16621655
End:
1663-
Match(match_enum(v5) {
1664-
PanicResult::Ok(v6) => blk1,
1665-
PanicResult::Err(v7) => blk4,
1656+
Match(match_enum(v4) {
1657+
PanicResult::Ok(v5) => blk1,
1658+
PanicResult::Err(v6) => blk4,
16661659
})
16671660

16681661
blk1:
16691662
Statements:
1670-
(v8: core::felt252, v9: core::internal::LoopResult::<core::bool, core::integer::u32>) <- struct_destructure(v6)
1663+
(v7: core::felt252, v8: core::internal::LoopResult::<core::bool, core::integer::u32>) <- struct_destructure(v5)
16711664
End:
1672-
Match(match_enum(v9) {
1673-
LoopResult::Normal(v10) => blk2,
1674-
LoopResult::EarlyReturn(v11) => blk3,
1665+
Match(match_enum(v8) {
1666+
LoopResult::Normal(v9) => blk2,
1667+
LoopResult::EarlyReturn(v10) => blk3,
16751668
})
16761669

16771670
blk2:
16781671
Statements:
1679-
(v12: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v4)
1680-
(v13: core::integer::u32) <- 1
1681-
(v14: (core::integer::u32,)) <- struct_construct(v13)
1682-
(v15: core::panics::PanicResult::<(core::integer::u32,)>) <- PanicResult::Ok(v14)
1672+
(v11: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v3)
1673+
(v12: core::integer::u32) <- 1
1674+
(v13: (core::integer::u32,)) <- struct_construct(v12)
1675+
(v14: core::panics::PanicResult::<(core::integer::u32,)>) <- PanicResult::Ok(v13)
16831676
End:
1684-
Return(v3, v12, v15)
1677+
Return(v2, v11, v14)
16851678

16861679
blk3:
16871680
Statements:
1688-
(v16: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v4)
1689-
(v17: (core::integer::u32,)) <- struct_construct(v11)
1690-
(v18: core::panics::PanicResult::<(core::integer::u32,)>) <- PanicResult::Ok(v17)
1681+
(v15: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v3)
1682+
(v16: (core::integer::u32,)) <- struct_construct(v10)
1683+
(v17: core::panics::PanicResult::<(core::integer::u32,)>) <- PanicResult::Ok(v16)
16911684
End:
1692-
Return(v3, v16, v18)
1685+
Return(v2, v15, v17)
16931686

16941687
blk4:
16951688
Statements:
1696-
(v19: core::panics::PanicResult::<(core::integer::u32,)>) <- PanicResult::Err(v7)
1689+
(v18: core::panics::PanicResult::<(core::integer::u32,)>) <- PanicResult::Err(v6)
16971690
End:
1698-
Return(v3, v4, v19)
1691+
Return(v2, v3, v18)
16991692

17001693

17011694
Generated loop lowering for source location:

0 commit comments

Comments
 (0)