Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions crates/cairo-lang-lowering/src/ids.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use cairo_lang_semantic::corelib::CorelibSemantic;
use cairo_lang_semantic::items::functions::{FunctionsSemantic, ImplGenericFunctionId};
use cairo_lang_semantic::items::imp::ImplLongId;
use cairo_lang_semantic::items::structure::StructSemantic;
use cairo_lang_semantic::{ConcreteTypeId, GenericArgumentId, TypeLongId};
use cairo_lang_semantic::{ConcreteTypeId, GenericArgumentId, TypeId, TypeLongId};
use cairo_lang_syntax::node::ast::ExprPtr;
use cairo_lang_syntax::node::kind::SyntaxKind;
use cairo_lang_syntax::node::{TypedStablePtr, ast};
Expand Down Expand Up @@ -491,19 +491,25 @@ impl<'db> SpecializedFunction<'db> {
}
}
SpecializationArg::Struct(specialization_args) => {
let ty = param.ty;
let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) = ty.long(db)
else {
unreachable!("Expected a concrete struct type");
};
let Ok(inner_param) = db.concrete_struct_members(*concrete_struct) else {
continue;
// Get element types based on the actual type.
let element_types: Vec<TypeId<'db>> = match param.ty.long(db) {
TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) => {
let Ok(members) = db.concrete_struct_members(*concrete_struct) else {
continue;
};
members.values().map(|member| member.ty).collect()
}
TypeLongId::Tuple(element_types) => element_types.clone(),
TypeLongId::FixedSizeArray { type_id, .. } => {
vec![*type_id; specialization_args.len()]
}
_ => unreachable!("Expected a struct, tuple, or fixed-size array type"),
};
for ((_, inner_param), arg) in
zip_eq(inner_param.iter().rev(), specialization_args.iter().rev())
for (elem_ty, arg) in
zip_eq(element_types.iter().rev(), specialization_args.iter().rev())
{
let lowered_param =
LoweredParam { ty: inner_param.ty, stable_ptr: param.stable_ptr };
LoweredParam { ty: *elem_ty, stable_ptr: param.stable_ptr };
stack.push((lowered_param, arg));
}
}
Expand Down
311 changes: 311 additions & 0 deletions crates/cairo-lang-lowering/src/lower/test_data/specialized
Original file line number Diff line number Diff line change
Expand Up @@ -880,3 +880,314 @@ Statements:
(v3: core::felt252) <- test::bar1(v0, v2)
End:
Return(v3)

//! > ==========================================================================

//! > Test tuple with const values.

//! > test_runner_name
test_specialized_function

//! > function_code
fn foo() {
bar(false, (1, 2, 3))
}

//! > function_name
foo

//! > module_code
fn bar(keep: bool, t: (felt252, felt252, felt252)) {
if keep {
bar_ext(t);
}
}

extern fn bar_ext(t: (felt252, felt252, felt252)) nopanic;

//! > semantic_diagnostics

//! > lowering_diagnostics

//! > caller_lowering
Parameters:
blk0 (root):
Statements:
() <- test::bar{bool::False({}), { 1, 2, 3 }, }()
End:
Return()

//! > specialized_lowering
Parameters:
blk0 (root):
Statements:
(v1: ()) <- struct_construct()
(v0: core::bool) <- bool::False(v1)
(v3: core::felt252) <- 1
(v4: core::felt252) <- 2
(v5: core::felt252) <- 3
(v2: (core::felt252, core::felt252, core::felt252)) <- struct_construct(v3, v4, v5)
(v6: ()) <- test::bar(v0, v2)
End:
Return(v6)

//! > ==========================================================================

//! > Test tuple with partial specialization.

//! > test_runner_name
test_specialized_function

//! > function_code
fn foo(b: felt252) {
bar(false, (1, b, 3))
}

//! > function_name
foo

//! > module_code
fn bar(keep: bool, t: (felt252, felt252, felt252)) {
if keep {
bar_ext(t);
}
}

extern fn bar_ext(t: (felt252, felt252, felt252)) nopanic;

//! > semantic_diagnostics

//! > lowering_diagnostics

//! > caller_lowering
Parameters: v0: core::felt252
blk0 (root):
Statements:
() <- test::bar{bool::False({}), { 1, NotSpecialized, 3 }, }(v0)
End:
Return()

//! > specialized_lowering
Parameters: v4: core::felt252
blk0 (root):
Statements:
(v1: ()) <- struct_construct()
(v0: core::bool) <- bool::False(v1)
(v3: core::felt252) <- 1
(v5: core::felt252) <- 3
(v2: (core::felt252, core::felt252, core::felt252)) <- struct_construct(v3, v4, v5)
(v6: ()) <- test::bar(v0, v2)
End:
Return(v6)

//! > ==========================================================================

//! > Test fixed size array with const values.

//! > test_runner_name
test_specialized_function

//! > function_code
fn foo() {
bar(false, [1, 2, 3])
}

//! > function_name
foo

//! > module_code
fn bar(keep: bool, arr: [felt252; 3]) {
if keep {
bar_ext(arr);
}
}

extern fn bar_ext(arr: [felt252; 3]) nopanic;

//! > semantic_diagnostics

//! > lowering_diagnostics

//! > caller_lowering
Parameters:
blk0 (root):
Statements:
() <- test::bar{bool::False({}), { 1, 2, 3 }, }()
End:
Return()

//! > specialized_lowering
Parameters:
blk0 (root):
Statements:
(v1: ()) <- struct_construct()
(v0: core::bool) <- bool::False(v1)
(v3: core::felt252) <- 1
(v4: core::felt252) <- 2
(v5: core::felt252) <- 3
(v2: [core::felt252; 3]) <- struct_construct(v3, v4, v5)
(v6: ()) <- test::bar(v0, v2)
End:
Return(v6)

//! > ==========================================================================

//! > Test fixed size array with partial specialization.

//! > test_runner_name
test_specialized_function

//! > function_code
fn foo(b: felt252) {
bar(false, [1, b, 3])
}

//! > function_name
foo

//! > module_code
fn bar(keep: bool, arr: [felt252; 3]) {
if keep {
bar_ext(arr);
}
}

extern fn bar_ext(arr: [felt252; 3]) nopanic;

//! > semantic_diagnostics

//! > lowering_diagnostics

//! > caller_lowering
Parameters: v0: core::felt252
blk0 (root):
Statements:
() <- test::bar{bool::False({}), { 1, NotSpecialized, 3 }, }(v0)
End:
Return()

//! > specialized_lowering
Parameters: v4: core::felt252
blk0 (root):
Statements:
(v1: ()) <- struct_construct()
(v0: core::bool) <- bool::False(v1)
(v3: core::felt252) <- 1
(v5: core::felt252) <- 3
(v2: [core::felt252; 3]) <- struct_construct(v3, v4, v5)
(v6: ()) <- test::bar(v0, v2)
End:
Return(v6)

//! > ==========================================================================

//! > Test nested tuple inside struct.

//! > test_runner_name
test_specialized_function

//! > function_code
fn foo(x: felt252) {
bar(false, S { t: (1, x), a: 3 })
}

//! > function_name
foo

//! > module_code
#[derive(Drop)]
struct S {
t: (felt252, felt252),
a: felt252,
}

fn bar(keep: bool, s: S) {
if keep {
bar_ext(s);
}
}

extern fn bar_ext(s: S) nopanic;

//! > semantic_diagnostics

//! > lowering_diagnostics

//! > caller_lowering
Parameters: v0: core::felt252
blk0 (root):
Statements:
() <- test::bar{bool::False({}), { { 1, NotSpecialized }, 3 }, }(v0)
End:
Return()

//! > specialized_lowering
Parameters: v6: core::felt252
blk0 (root):
Statements:
(v1: ()) <- struct_construct()
(v0: core::bool) <- bool::False(v1)
(v5: core::felt252) <- 1
(v3: (core::felt252, core::felt252)) <- struct_construct(v5, v6)
(v4: core::felt252) <- 3
(v2: test::S) <- struct_construct(v3, v4)
(v7: ()) <- test::bar(v0, v2)
End:
Return(v7)

//! > ==========================================================================

//! > Test fixed size array inside struct.

//! > test_runner_name
test_specialized_function

//! > function_code
fn foo(x: felt252) {
bar(false, S { arr: [1, x, 3], a: 4 })
}

//! > function_name
foo

//! > module_code
#[derive(Drop)]
struct S {
arr: [felt252; 3],
a: felt252,
}

fn bar(keep: bool, s: S) {
if keep {
bar_ext(s);
}
}

extern fn bar_ext(s: S) nopanic;

//! > semantic_diagnostics

//! > lowering_diagnostics

//! > caller_lowering
Parameters: v0: core::felt252
blk0 (root):
Statements:
() <- test::bar{bool::False({}), { { 1, NotSpecialized, 3 }, 4 }, }(v0)
End:
Return()

//! > specialized_lowering
Parameters: v6: core::felt252
blk0 (root):
Statements:
(v1: ()) <- struct_construct()
(v0: core::bool) <- bool::False(v1)
(v5: core::felt252) <- 1
(v7: core::felt252) <- 3
(v3: [core::felt252; 3]) <- struct_construct(v5, v6, v7)
(v4: core::felt252) <- 4
(v2: test::S) <- struct_construct(v3, v4)
(v8: ()) <- test::bar(v0, v2)
End:
Return(v8)
Loading