Skip to content

Commit

Permalink
_adapt_mutable_param_with_default_value: fixup
Browse files Browse the repository at this point in the history
Better handle enums
  • Loading branch information
pthom committed Nov 15, 2024
1 parent 929e086 commit 13a0147
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 117 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,68 +56,18 @@ def _can_access_enum_with_type(
cpp_type_str: str, # A type name being checked (e.g. MyEnum, A::MyEnum)
cpp_enum: CppEnum, # The enum being checked (e.g. "A::MyEnum")
) -> bool:
"""
Check if the current scope can access the enum with the given type name.
Args:
current_scope (CppScope): The current scope in the code.
cpp_type_str (str): The type name being checked.
enum_name_with_scope (str): The fully qualified name of the enum.
Returns:
bool: True if the enum can be accessed with the given type name from the current scope.
"""
# Remove leading "::" for uniformity
is_enum_class = cpp_enum.enum_type == "class"
if not is_enum_class:
return False

enum_name_with_scope = cpp_enum.cpp_scope_str(include_self=True)
"""Check if the current scope can access the enum with the given type name"""
cpp_type_str = cpp_type_str.lstrip(":")
enum_name_with_scope = cpp_enum.cpp_scope_str(include_self=True)

# If cpp_type_str is unqualified, we need to check each scope in the hierarchy
# Generate possible fully qualified names by prepending scopes from innermost to outermost
for scope in current_scope.scope_hierarchy_list:
# Use the qualified_name method to construct the fully qualified name
full_type_name = scope.qualified_name(cpp_type_str)
# Compare with the enum's fully qualified name
if full_type_name == enum_name_with_scope:
return True
return False


def _can_access_enum_with_value(
current_scope: CppScope, # Represent e.g. "A::B::C"
cpp_value_str: str, # A value name being checked (e.g. MyEnum::a, A::MyEnum::a)
cpp_enum: CppEnum # The enum being checked (e.g. "A::MyEnum")
) -> bool:

if cpp_enum.enum_type == "class":
def compute_cpp_type_str() -> str | None:
if "::" in cpp_value_str:
# split and take everything except the last element
r = "::".join(cpp_value_str.split("::")[:-1])
return r
return None
cpp_type_str = compute_cpp_type_str()
if cpp_type_str is None:
return False
return _can_access_enum_with_type(current_scope, cpp_type_str, cpp_enum)
else:
return False # C enum are too shady
# enum_name_with_scope = cpp_enum.cpp_scope_str(include_self=False)
# # Generate possible fully qualified names by prepending scopes from innermost to outermost
# for scope in current_scope.scope_hierarchy_list:
# # Use the qualified_name method to construct the fully qualified name
# full_value_name = scope.qualified_name(cpp_value_str)
# # Compare with the enum's fully qualified name
# if full_value_name.startswith(enum_name_with_scope + "_"):
# return True
# return False




@dataclass
class _ImmutableCallables:
fn_immutables_types: Callable[[str], bool]
Expand All @@ -128,12 +78,6 @@ def _immutable_functions_default(lg_context: LitgenContext, code_scope: CppScope
options = lg_context.options

def _fn_immutables_values(cpp_value_str: str) -> bool:
for cpp_enum in lg_context.encountered_cpp_enums:
if _can_access_enum_with_value(code_scope, cpp_value_str, cpp_enum):
return True
# if cpp_type_str.startswith(enum_type + "_"):
# return True

_fn_immutables_values_user = options.fn_params_adapt_mutable_param_with_default_value__fn_is_known_immutable_value
if _fn_immutables_values_user is not None:
return _fn_immutables_values_user(cpp_value_str)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,16 @@ def assert_parents_are_present() -> None:
assert_parents_are_present()

all_adapters_functions = [
adapt_exclude_params, # must be done at start
adapt_mutable_param_with_default_value, # must be done just after adapt_exclude_params
adapt_c_buffers,
adapt_exclude_params,
adapt_c_arrays,
adapt_const_char_pointer_with_default_null,
adapt_modifiable_immutable_to_return,
adapt_modifiable_immutable, # must be done *after* adapt_c_buffers
adapt_c_string_list,
adapt_c_string_list_no_count,
adapt_variadic_format,
adapt_mutable_param_with_default_value,
]
all_adapters_functions += inout_adapted_function.options.fn_custom_adapters

Expand Down
6 changes: 1 addition & 5 deletions src/litgen/tests/internal/context/litgen_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,8 @@ class MyEnumNonClass(enum.Enum):
def foo_inner(
x: int = Inner.FooValue(),
a: MyEnumClass = MyEnumClass.value_a,
b: Optional[MyEnumNonClass] = None
b: MyEnumNonClass = MyEnumNonClass.value_a
) -> int:
"""---
Python bindings defaults:
If b is None, then its default value will be: MyEnumNonClass.value_a
"""
pass
# <submodule inner>
Expand Down
70 changes: 29 additions & 41 deletions src/litgen/tests/litgen_generator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,55 +56,29 @@ def test_scoping_no_root_namespace():
py::overload_cast<N::EC>(N::Foo), py::arg("e") = N::EC::a);
pyNsN.def("foo",
[](const std::optional<const N::E> & e = std::nullopt) -> N::E
{
auto Foo_adapt_mutable_param_with_default_value = [](const std::optional<const N::E> & e = std::nullopt) -> N::E
{
const N::E& e_or_default = [&]() -> const N::E {
if (e.has_value())
return e.value();
else
return N::E_a;
}();
auto lambda_result = N::Foo(e_or_default);
return lambda_result;
};
return Foo_adapt_mutable_param_with_default_value(e);
},
py::arg("e") = py::none(),
"---\\nPython bindings defaults:\\n If e is None, then its default value will be: N.E.a");
py::overload_cast<N::E>(N::Foo), py::arg("e") = N::E_a);
pyNsN.def("foo",
[](const std::optional<const N::E> & e = std::nullopt, const std::optional<const N::S> & s = std::nullopt) -> N::S
[](N::E e = N::E_a, const std::optional<const N::S> & s = std::nullopt) -> N::S
{
auto Foo_adapt_mutable_param_with_default_value = [](const std::optional<const N::E> & e = std::nullopt, const std::optional<const N::S> & s = std::nullopt) -> N::S
auto Foo_adapt_mutable_param_with_default_value = [](N::E e = N::E_a, const std::optional<const N::S> & s = std::nullopt) -> N::S
{
const N::E& e_or_default = [&]() -> const N::E {
if (e.has_value())
return e.value();
else
return N::E_a;
}();
const N::S& s_or_default = [&]() -> const N::S {
if (s.has_value())
return s.value();
else
return N::S();
}();
auto lambda_result = N::Foo(e_or_default, s_or_default);
auto lambda_result = N::Foo(e, s_or_default);
return lambda_result;
};
return Foo_adapt_mutable_param_with_default_value(e, s);
},
py::arg("e") = py::none(), py::arg("s") = py::none(),
"---\\nPython bindings defaults:\\n If any of the params below is None, then its default value below will be used:\\n e: N.E.a\\n s: N.S()");
py::arg("e") = N::E_a, py::arg("s") = py::none(),
"---\\nPython bindings defaults:\\n If s is None, then its default value will be: N.S()");
} // </namespace N>
""",
)
Expand All @@ -131,24 +105,19 @@ def foo(e: EC = EC.a) -> EC:
pass
@staticmethod
@overload
def foo(e: Optional[E] = None) -> E:
"""---
Python bindings defaults:
If e is None, then its default value will be: N.E.a
"""
def foo(e: E = E.a) -> E:
pass
@staticmethod
@overload
def foo(e: Optional[E] = None, s: Optional[S] = None) -> S:
def foo(e: E = E.a, s: Optional[S] = None) -> S:
"""---
Python bindings defaults:
If any of the params below is None, then its default value below will be used:
e: N.E.a
s: N.S()
If s is None, then its default value will be: N.S()
"""
pass
# </submodule n>
''',
)

Expand Down Expand Up @@ -183,3 +152,22 @@ def __init__(self, e: EC = EC.a) -> None:
pass
'''
)


def test_naming() -> None:
code = """
//namespace CamelCase { // should be converted to snake_case in Python
enum Foo { a, b, c };
void UseFoo(Foo f = Foo::a);
//}
// should have this signature in Python:
// def use_foo(f: camel_case.Foo = camel_case.Foo.a) -> None:
// void UseFoo(CamelCase::Foo f = CamelCase::Foo::a);
"""
options = litgen.LitgenOptions()
options.fn_params_adapt_mutable_param_with_default_value__regex = r".*"
generated_code = litgen.generate_code(options, code)
print(generated_code.stub_code)
23 changes: 12 additions & 11 deletions src/srcmlcpp/cpp_types/functions/cpp_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,19 +235,11 @@ def seems_mutable_param_with_default_value(
"""Determines whether the parameter with a default value appears to be mutable."""

decl = self.decl

decl_type_str = self.decl.cpp_type.str_code()
if decl_type_str in _BASE_IMMUTABLE_TYPES:
return False
if fn_is_immutable_type is not None:
if fn_is_immutable_type(decl_type_str):
return False


cpp_type = decl.cpp_type

# Bail out if the parameter has no default value
initial_value_code = decl.initial_value_code.strip()
has_default_value = bool(initial_value_code)
if not has_default_value:
if len(initial_value_code) == 0:
return False

# Check for types we cannot handle or that suggest mutability
Expand All @@ -268,9 +260,18 @@ def seems_mutable_param_with_default_value(
if not (is_const_ref or is_value_type):
return False

# Check if the default value looks like a mutable object
if True: # just to have a block
decl_type_str = cpp_type.str_code()
if decl_type_str in _BASE_IMMUTABLE_TYPES:
return False
if fn_is_immutable_type is not None and fn_is_immutable_type(decl_type_str):
return False

# Check if the default value looks like a mutable object
r = _looks_like_mutable_default_value(initial_value_code, fn_is_immutable_type, fn_is_immutable_value)
return r

def __str__(self) -> str:
return self.str_code()

Expand Down

0 comments on commit 13a0147

Please sign in to comment.