Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/cairo-lang-lowering/src/add_withdraw_gas/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ fn create_panic_block<'db>(
with_coupon: false,
outputs: vec![never_var],
location,
is_specialization_base_call: false,
})],
end: BlockEnd::Match {
info: MatchInfo::Enum(MatchEnumInfo {
Expand Down
3 changes: 3 additions & 0 deletions crates/cairo-lang-lowering/src/cache/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,7 @@ struct StatementCallCached {
with_coupon: bool,
outputs: Vec<usize>,
location: LocationIdCached,
is_specialization_base_call: bool,
}
impl StatementCallCached {
fn new<'db>(stmt: StatementCall<'db>, ctx: &mut CacheSavingContext<'db>) -> Self {
Expand All @@ -997,6 +998,7 @@ impl StatementCallCached {
with_coupon: stmt.with_coupon,
outputs: stmt.outputs.iter().map(|var| var.index()).collect(),
location: LocationIdCached::new(stmt.location, ctx),
is_specialization_base_call: stmt.is_specialization_base_call,
}
}
fn embed<'db>(self, ctx: &mut CacheLoadingContext<'db>) -> StatementCall<'db> {
Expand All @@ -1010,6 +1012,7 @@ impl StatementCallCached {
.map(|var_id| ctx.lowered_variables_id[var_id])
.collect(),
location: self.location.embed(ctx),
is_specialization_base_call: self.is_specialization_base_call,
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions crates/cairo-lang-lowering/src/destructs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ pub fn add_destructs<'db>(
with_coupon: false,
outputs: vec![output_var],
location: variables.variables[plain_destruct.var_id].location,
is_specialization_base_call: false,
})
}

Expand All @@ -415,6 +416,7 @@ pub fn add_destructs<'db>(
with_coupon: false,
outputs: vec![new_panic_var, output_var],
location,
is_specialization_base_call: false,
});
last_panic_var = new_panic_var;
}
Expand Down
16 changes: 13 additions & 3 deletions crates/cairo-lang-lowering/src/inline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,18 @@ pub fn priv_should_inline<'db>(
if db.priv_never_inline(function_id)? {
return Ok(false);
}

// Breaks cycles.
if db.concrete_in_cycle(function_id, DependencyType::Call, LoweringStage::Monomorphized)? {
// Prevents inlining of functions that may call themselves, by checking if the base of the
// function (the function without specialization) is in a call cycle (we cannot use the
// specialized version as it may call the base function with different specialization which may
// cause long inlining chains).
// TODO(Tomerstarkware): allow inlining of specialized recursive functions for one level of
// recursion.
let base = match function_id.long(db) {
ConcreteFunctionWithBodyLongId::Semantic(_)
| ConcreteFunctionWithBodyLongId::Generated(_) => function_id,
ConcreteFunctionWithBodyLongId::Specialized(specialized) => specialized.base,
};
if db.concrete_in_cycle(base, DependencyType::Call, LoweringStage::Monomorphized)? {
return Ok(false);
}

Expand Down Expand Up @@ -342,6 +351,7 @@ where
if let ConcreteFunctionWithBodyLongId::Specialized(specialized) =
calling_function_id.long(db)
&& specialized.base == called_func
&& stmt.is_specialization_base_call
{
// A specialized function should always inline its base.
return Ok(Some((stmt, called_func)));
Expand Down
1 change: 1 addition & 0 deletions crates/cairo-lang-lowering/src/lower/generators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ impl<'db> Call<'db> {
with_coupon,
outputs,
location: self.location,
is_specialization_base_call: false,
}));

CallResult { returns, extra_outputs }
Expand Down
12 changes: 7 additions & 5 deletions crates/cairo-lang-lowering/src/lower/specialized_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use cairo_lang_debug::DebugWithDb;
use cairo_lang_semantic::test_utils::setup_test_function;
use cairo_lang_test_utils::parse_test_file::TestRunnerResult;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::try_extract_matches;
use salsa::Setter;

use crate::db::{LoweringGroup, lowering_group_input};
Expand Down Expand Up @@ -56,17 +57,18 @@ fn test_specialized_function(
"Got diagnostics for the caller {semantic_diagnostics}\n{lowering_diagnostics}."
);
});
let Some(Statement::Call(call)) = lowered_caller
let Some(specialized_id) = lowered_caller
.blocks
.iter()
.flat_map(|(_, b)| b.statements.iter())
.rfind(|statement| matches!(statement, Statement::Call(_)))
.filter_map(|statement| {
try_extract_matches!(statement, Statement::Call)
.and_then(|call| call.function.body(db).ok().flatten())
})
.next_back()
else {
panic!("Could not find the last call in the caller function.");
};
let Ok(Some(specialized_id)) = call.function.body(db) else {
panic!("Expected function body, got: {}", call.function.full_path(db));
};
let lowered_specialized =
db.lowered_body(specialized_id, LoweringStage::Monomorphized).unwrap();
let lowered_formatter = LoweredFormatter::new(db, &lowered_caller.variables);
Expand Down
155 changes: 74 additions & 81 deletions crates/cairo-lang-lowering/src/lower/test_data/loop
Original file line number Diff line number Diff line change
Expand Up @@ -42,28 +42,27 @@ Final lowering:
Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin
blk0 (root):
Statements:
(v2: core::felt252) <- 5
(v3: core::RangeCheck, v4: core::gas::GasBuiltin, v5: core::panics::PanicResult::<(core::felt252, core::bool)>) <- test::foo[38-149](v0, v1, v2)
(v2: core::RangeCheck, v3: core::gas::GasBuiltin, v4: core::panics::PanicResult::<(core::felt252, core::bool)>) <- test::foo[38-149]{5, }(v0, v1)
End:
Match(match_enum(v5) {
PanicResult::Ok(v6) => blk1,
PanicResult::Err(v7) => blk2,
Match(match_enum(v4) {
PanicResult::Ok(v5) => blk1,
PanicResult::Err(v6) => blk2,
})

blk1:
Statements:
(v8: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v4)
(v9: core::felt252, v10: core::bool) <- struct_destructure(v6)
(v11: (core::bool,)) <- struct_construct(v10)
(v12: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Ok(v11)
(v7: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v3)
(v8: core::felt252, v9: core::bool) <- struct_destructure(v5)
(v10: (core::bool,)) <- struct_construct(v9)
(v11: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Ok(v10)
End:
Return(v3, v8, v12)
Return(v2, v7, v11)

blk2:
Statements:
(v13: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Err(v7)
(v12: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Err(v6)
End:
Return(v3, v4, v13)
Return(v2, v3, v12)


Generated loop lowering for source location:
Expand Down Expand Up @@ -198,28 +197,27 @@ Final lowering:
Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin
blk0 (root):
Statements:
(v2: core::felt252) <- 5
(v3: core::RangeCheck, v4: core::gas::GasBuiltin, v5: core::panics::PanicResult::<(core::felt252, core::bool)>) <- test::foo[38-130](v0, v1, v2)
(v2: core::RangeCheck, v3: core::gas::GasBuiltin, v4: core::panics::PanicResult::<(core::felt252, core::bool)>) <- test::foo[38-130]{5, }(v0, v1)
End:
Match(match_enum(v5) {
PanicResult::Ok(v6) => blk1,
PanicResult::Err(v7) => blk2,
Match(match_enum(v4) {
PanicResult::Ok(v5) => blk1,
PanicResult::Err(v6) => blk2,
})

blk1:
Statements:
(v8: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v4)
(v9: core::felt252, v10: core::bool) <- struct_destructure(v6)
(v11: (core::bool,)) <- struct_construct(v10)
(v12: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Ok(v11)
(v7: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v3)
(v8: core::felt252, v9: core::bool) <- struct_destructure(v5)
(v10: (core::bool,)) <- struct_construct(v9)
(v11: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Ok(v10)
End:
Return(v3, v8, v12)
Return(v2, v7, v11)

blk2:
Statements:
(v13: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Err(v7)
(v12: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Err(v6)
End:
Return(v3, v4, v13)
Return(v2, v3, v12)


Generated loop lowering for source location:
Expand Down Expand Up @@ -521,28 +519,27 @@ Final lowering:
Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin
blk0 (root):
Statements:
(v2: core::felt252) <- 5
(v3: core::RangeCheck, v4: core::gas::GasBuiltin, v5: core::panics::PanicResult::<(core::felt252, core::bool)>) <- test::foo[38-201](v0, v1, v2)
(v2: core::RangeCheck, v3: core::gas::GasBuiltin, v4: core::panics::PanicResult::<(core::felt252, core::bool)>) <- test::foo[38-201]{5, }(v0, v1)
End:
Match(match_enum(v5) {
PanicResult::Ok(v6) => blk1,
PanicResult::Err(v7) => blk2,
Match(match_enum(v4) {
PanicResult::Ok(v5) => blk1,
PanicResult::Err(v6) => blk2,
})

blk1:
Statements:
(v8: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v4)
(v9: core::felt252, v10: core::bool) <- struct_destructure(v6)
(v11: (core::bool,)) <- struct_construct(v10)
(v12: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Ok(v11)
(v7: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v3)
(v8: core::felt252, v9: core::bool) <- struct_destructure(v5)
(v10: (core::bool,)) <- struct_construct(v9)
(v11: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Ok(v10)
End:
Return(v3, v8, v12)
Return(v2, v7, v11)

blk2:
Statements:
(v13: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Err(v7)
(v12: core::panics::PanicResult::<(core::bool,)>) <- PanicResult::Err(v6)
End:
Return(v3, v4, v13)
Return(v2, v3, v12)


Generated loop lowering for source location:
Expand Down Expand Up @@ -1394,29 +1391,27 @@ Final lowering:
Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin
blk0 (root):
Statements:
(v2: core::felt252) <- 0
(v3: test::A) <- struct_construct(v2)
(v4: core::RangeCheck, v5: core::gas::GasBuiltin, v6: core::panics::PanicResult::<(test::A, core::felt252, ())>) <- test::foo[58-134](v0, v1, v2, v3)
(v2: core::RangeCheck, v3: core::gas::GasBuiltin, v4: core::panics::PanicResult::<(test::A, core::felt252, ())>) <- test::foo[58-134]{0, { 0: core::felt252 }, }(v0, v1)
End:
Match(match_enum(v6) {
PanicResult::Ok(v7) => blk1,
PanicResult::Err(v8) => blk2,
Match(match_enum(v4) {
PanicResult::Ok(v5) => blk1,
PanicResult::Err(v6) => blk2,
})

blk1:
Statements:
(v9: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v5)
(v10: ()) <- struct_construct()
(v11: ((),)) <- struct_construct(v10)
(v12: core::panics::PanicResult::<((),)>) <- PanicResult::Ok(v11)
(v7: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v3)
(v8: ()) <- struct_construct()
(v9: ((),)) <- struct_construct(v8)
(v10: core::panics::PanicResult::<((),)>) <- PanicResult::Ok(v9)
End:
Return(v4, v9, v12)
Return(v2, v7, v10)

blk2:
Statements:
(v13: core::panics::PanicResult::<((),)>) <- PanicResult::Err(v8)
(v11: core::panics::PanicResult::<((),)>) <- PanicResult::Err(v6)
End:
Return(v4, v5, v13)
Return(v2, v3, v11)


Generated loop lowering for source location:
Expand Down Expand Up @@ -1561,28 +1556,27 @@ Final lowering:
Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin
blk0 (root):
Statements:
(v2: core::integer::u8) <- 0
(v3: core::RangeCheck, v4: core::gas::GasBuiltin, v5: core::panics::PanicResult::<(core::integer::u8, ())>) <- test::MyImpl::impl_in_trait[54-175](v0, v1, v2)
(v2: core::RangeCheck, v3: core::gas::GasBuiltin, v4: core::panics::PanicResult::<(core::integer::u8, ())>) <- test::MyImpl::impl_in_trait[54-175]{0, }(v0, v1)
End:
Match(match_enum(v5) {
PanicResult::Ok(v6) => blk1,
PanicResult::Err(v7) => blk2,
Match(match_enum(v4) {
PanicResult::Ok(v5) => blk1,
PanicResult::Err(v6) => blk2,
})

blk1:
Statements:
(v8: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v4)
(v9: ()) <- struct_construct()
(v10: ((),)) <- struct_construct(v9)
(v11: core::panics::PanicResult::<((),)>) <- PanicResult::Ok(v10)
(v7: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v3)
(v8: ()) <- struct_construct()
(v9: ((),)) <- struct_construct(v8)
(v10: core::panics::PanicResult::<((),)>) <- PanicResult::Ok(v9)
End:
Return(v3, v8, v11)
Return(v2, v7, v10)

blk2:
Statements:
(v12: core::panics::PanicResult::<((),)>) <- PanicResult::Err(v7)
(v11: core::panics::PanicResult::<((),)>) <- PanicResult::Err(v6)
End:
Return(v3, v4, v12)
Return(v2, v3, v11)

//! > lowering_diagnostics

Expand Down Expand Up @@ -1657,45 +1651,44 @@ Final lowering:
Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin
blk0 (root):
Statements:
(v2: core::felt252) <- 5
(v3: core::RangeCheck, v4: core::gas::GasBuiltin, v5: core::panics::PanicResult::<(core::felt252, core::internal::LoopResult::<core::bool, core::integer::u32>)>) <- test::foo[37-253](v0, v1, v2)
(v2: core::RangeCheck, v3: core::gas::GasBuiltin, v4: core::panics::PanicResult::<(core::felt252, core::internal::LoopResult::<core::bool, core::integer::u32>)>) <- test::foo[37-253]{5, }(v0, v1)
End:
Match(match_enum(v5) {
PanicResult::Ok(v6) => blk1,
PanicResult::Err(v7) => blk4,
Match(match_enum(v4) {
PanicResult::Ok(v5) => blk1,
PanicResult::Err(v6) => blk4,
})

blk1:
Statements:
(v8: core::felt252, v9: core::internal::LoopResult::<core::bool, core::integer::u32>) <- struct_destructure(v6)
(v7: core::felt252, v8: core::internal::LoopResult::<core::bool, core::integer::u32>) <- struct_destructure(v5)
End:
Match(match_enum(v9) {
LoopResult::Normal(v10) => blk2,
LoopResult::EarlyReturn(v11) => blk3,
Match(match_enum(v8) {
LoopResult::Normal(v9) => blk2,
LoopResult::EarlyReturn(v10) => blk3,
})

blk2:
Statements:
(v12: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v4)
(v13: core::integer::u32) <- 1
(v14: (core::integer::u32,)) <- struct_construct(v13)
(v15: core::panics::PanicResult::<(core::integer::u32,)>) <- PanicResult::Ok(v14)
(v11: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v3)
(v12: core::integer::u32) <- 1
(v13: (core::integer::u32,)) <- struct_construct(v12)
(v14: core::panics::PanicResult::<(core::integer::u32,)>) <- PanicResult::Ok(v13)
End:
Return(v3, v12, v15)
Return(v2, v11, v14)

blk3:
Statements:
(v16: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v4)
(v17: (core::integer::u32,)) <- struct_construct(v11)
(v18: core::panics::PanicResult::<(core::integer::u32,)>) <- PanicResult::Ok(v17)
(v15: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v3)
(v16: (core::integer::u32,)) <- struct_construct(v10)
(v17: core::panics::PanicResult::<(core::integer::u32,)>) <- PanicResult::Ok(v16)
End:
Return(v3, v16, v18)
Return(v2, v15, v17)

blk4:
Statements:
(v19: core::panics::PanicResult::<(core::integer::u32,)>) <- PanicResult::Err(v7)
(v18: core::panics::PanicResult::<(core::integer::u32,)>) <- PanicResult::Err(v6)
End:
Return(v3, v4, v19)
Return(v2, v3, v18)


Generated loop lowering for source location:
Expand Down
Loading