Skip to content

Commit

Permalink
Run formatter (black) on whole codebase
Browse files Browse the repository at this point in the history
  • Loading branch information
pthom committed Dec 2, 2024
1 parent eae5946 commit b04f247
Show file tree
Hide file tree
Showing 15 changed files with 162 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,6 @@ def _nanobind_dtype_code_as_uint8(self, idx_param: int) -> str:
dtype_code = cpp_to_python.nanobind_cpp_type_to_dtype_code_as_uint8(raw_cpp_type)
return dtype_code


def _lambda_input_buffer_standard_check_part(self, idx_param: int) -> str:
_ = self
if self.options.bind_library == BindLibraryType.pybind11:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ def use_foo(foo = Foo()):


def _can_access_enum_with_type(
current_scope: CppScope, # Represent e.g. "A::B::C"
cpp_type_str: str, # A type name being checked (e.g. MyEnum, A::MyEnum)
cpp_enum: CppEnum, # The enum being checked (e.g. "A::MyEnum")
) -> bool:
current_scope: CppScope, # Represent e.g. "A::B::C"
cpp_type_str: str, # A type name being checked (e.g. MyEnum, A::MyEnum)
cpp_enum: CppEnum, # The enum being checked (e.g. "A::MyEnum")
) -> bool:
"""Check if the current scope can access the enum with the given type name"""
cpp_type_str = cpp_type_str.lstrip(":")
enum_name_with_scope = cpp_enum.cpp_scope_str(include_self=True)
Expand All @@ -78,7 +78,9 @@ def _immutable_functions_default(lg_context: LitgenContext, code_scope: CppScope
options = lg_context.options

def _fn_immutables_values(cpp_value_str: str) -> bool:
_fn_immutables_values_user = options.fn_params_adapt_mutable_param_with_default_value__fn_is_known_immutable_value
_fn_immutables_values_user = (
options.fn_params_adapt_mutable_param_with_default_value__fn_is_known_immutable_value
)
if _fn_immutables_values_user is not None:
return _fn_immutables_values_user(cpp_value_str)
else:
Expand All @@ -97,21 +99,28 @@ def _fn_immutables_types(cpp_type_str: str) -> bool:
return _ImmutableCallables(_fn_immutables_types, _fn_immutables_values)



def was_mutable_param_with_default_value_made_optional(lg_context: LitgenContext, cpp_param: CppParameter) -> bool:
options = lg_context.options
option_active = options.fn_params_adapt_mutable_param_with_default_value__to_autogenerated_named_ctor
if not option_active:
return False
immutable_callables = _immutable_functions_default(lg_context, cpp_param.cpp_scope(include_self=False))
r = cpp_param.seems_mutable_param_with_default_value(immutable_callables.fn_immutables_types, immutable_callables.fn_immutables_values)
r = cpp_param.seems_mutable_param_with_default_value(
immutable_callables.fn_immutables_types, immutable_callables.fn_immutables_values
)
return r


def adapt_mutable_param_with_default_value(adapted_function: AdaptedFunction) -> Optional[LambdaAdapter]:
options = adapted_function.options
is_autogenerated_named_ctor = "Auto-generated default constructor with named params" in adapted_function.cpp_element().cpp_element_comments.comment_on_previous_lines
apply_because_autogen = options.fn_params_adapt_mutable_param_with_default_value__to_autogenerated_named_ctor and is_autogenerated_named_ctor
is_autogenerated_named_ctor = (
"Auto-generated default constructor with named params"
in adapted_function.cpp_element().cpp_element_comments.comment_on_previous_lines
)
apply_because_autogen = (
options.fn_params_adapt_mutable_param_with_default_value__to_autogenerated_named_ctor
and is_autogenerated_named_ctor
)
match_regex = code_utils.does_match_regex(
options.fn_params_adapt_mutable_param_with_default_value__regex,
adapted_function.cpp_adapted_function.function_name,
Expand All @@ -121,12 +130,16 @@ def adapt_mutable_param_with_default_value(adapted_function: AdaptedFunction) ->

old_function_params: list[CppParameter] = adapted_function.cpp_adapted_function.parameter_list.parameters

immutable_callables = _immutable_functions_default(adapted_function.lg_context, adapted_function.cpp_element().cpp_scope(include_self=False))
immutable_callables = _immutable_functions_default(
adapted_function.lg_context, adapted_function.cpp_element().cpp_scope(include_self=False)
)

def needs_adapt() -> bool:
for old_adapted_param in adapted_function.adapted_parameters():
cpp_param = old_adapted_param.cpp_element()
if cpp_param.seems_mutable_param_with_default_value(immutable_callables.fn_immutables_types, immutable_callables.fn_immutables_values):
if cpp_param.seems_mutable_param_with_default_value(
immutable_callables.fn_immutables_types, immutable_callables.fn_immutables_values
):
return True
return False

Expand Down Expand Up @@ -161,7 +174,9 @@ def _fn_optional_to_const_noref(cpp_type: CppType) -> str:

for old_param in old_function_params:
was_replaced = False
if old_param.seems_mutable_param_with_default_value(immutable_callables.fn_immutables_types, immutable_callables.fn_immutables_values):
if old_param.seems_mutable_param_with_default_value(
immutable_callables.fn_immutables_types, immutable_callables.fn_immutables_values
):
was_replaced = True

param_name = old_param.decl.decl_name
Expand All @@ -171,8 +186,10 @@ def _fn_optional_to_const_noref(cpp_type: CppType) -> str:
# (where std::optional<T> is replaced by T)
param_or__name = f"{param_name}_or_default"
param_or__type_noref = _fn_optional_to_const_noref(param_type)
param_or__real_default = cpp_to_python.var_value_to_python(adapted_function.lg_context, old_param.decl.initial_value_code)
new_function_comment_lines[param_name] = param_or__real_default
param_or__real_default = cpp_to_python.var_value_to_python(
adapted_function.lg_context, old_param.decl.initial_value_code
)
new_function_comment_lines[param_name] = param_or__real_default

# Create new calling param (std::optional<T>)
new_param = copy.deepcopy(old_param)
Expand Down Expand Up @@ -218,7 +235,9 @@ def _fn_optional_to_const_noref(cpp_type: CppType) -> str:
comment = f" If {param_name} is None, then its default value will be: {param_default_value}"
cpp_comments.add_comment_on_previous_lines(comment)
else:
cpp_comments.add_comment_on_previous_lines(" If any of the params below is None, then its default value below will be used:")
cpp_comments.add_comment_on_previous_lines(
" If any of the params below is None, then its default value below will be used:"
)
for param_name, param_default_value in new_function_comment_lines.items():
comment = f" {param_name}: {param_default_value}"
cpp_comments.add_comment_on_previous_lines(comment)
Expand Down
12 changes: 8 additions & 4 deletions src/litgen/internal/adapt_function_params/apply_all_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def assert_parents_are_present() -> None:
assert_parents_are_present()

all_adapters_functions = [
adapt_c_buffers, # must be done at start
adapt_c_buffers, # must be done at start
adapt_exclude_params,
adapt_mutable_param_with_default_value, # must be done just after adapt_exclude_params
adapt_c_arrays,
Expand Down Expand Up @@ -125,15 +125,19 @@ def _apply_all_adapters_on_constructor(inout_adapted_function: AdaptedFunction)
cpp_wrapper_function = srcmlcpp_main.code_first_function_decl(
inout_adapted_function.options.srcmlcpp_options, ctor_wrapper_signature_code
)
cpp_wrapper_function.cpp_element_comments.comment_on_previous_lines = inout_adapted_function.cpp_element().cpp_element_comments.comment_on_previous_lines
cpp_wrapper_function.cpp_element_comments.comment_on_previous_lines = (
inout_adapted_function.cpp_element().cpp_element_comments.comment_on_previous_lines
)
cpp_wrapper_function.parent = inout_adapted_function.cpp_element().parent
ctor_adapted_wrapper_function = AdaptedFunction(
inout_adapted_function.lg_context,
cpp_wrapper_function,
is_overloaded=False,
initial_lambda_to_call="ctor_wrapper",
)
inout_adapted_function.cpp_element().cpp_element_comments.comment_on_previous_lines = cpp_wrapper_function.cpp_element_comments.comment_on_previous_lines
inout_adapted_function.cpp_element().cpp_element_comments.comment_on_previous_lines = (
cpp_wrapper_function.cpp_element_comments.comment_on_previous_lines
)

if ctor_adapted_wrapper_function.cpp_adapter_code is not None:
inout_adapted_function.cpp_adapter_code = (
Expand Down Expand Up @@ -164,7 +168,7 @@ def _make_adapted_lambda_code_end(adapted_function: AdaptedFunction, lambda_adap
_return_referenced = False

if hasattr(adapted_function.cpp_element(), "return_type"):
_return_referenced = '&' in adapted_function.cpp_element().return_type.modifiers
_return_referenced = "&" in adapted_function.cpp_element().return_type.modifiers

# Fill auto_r_equal_or_void
if _fn_return_type != "void":
Expand Down
6 changes: 4 additions & 2 deletions src/litgen/internal/adapted_types/adapted_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,7 +1308,6 @@ def pydef_code(self) -> str:
# (we keep a backup of the original, to be able to compare)
ctor_decl_adapted = adapted_ctor.cpp_adapted_function


if len(ctor_decl.parameter_list.parameters) == 0:
py = "py" if self.options.bind_library == BindLibraryType.pybind11 else "nb"
return f"{_i_}.def({py}::init<>()) // implicit default constructor \n"
Expand Down Expand Up @@ -1352,7 +1351,10 @@ def pydef_code(self) -> str:
replacements_lines.maybe_pyargs = ", ".join(adapted_ctor._pydef_pyarg_list())

def get_all_params_set_values() -> str:
from litgen.internal.adapt_function_params._adapt_mutable_param_with_default_value import was_mutable_param_with_default_value_made_optional
from litgen.internal.adapt_function_params._adapt_mutable_param_with_default_value import (
was_mutable_param_with_default_value_made_optional,
)

original_parameters = ctor_decl.parameter_list.parameters
modified_parameters = ctor_decl_adapted.parameter_list.parameters
# Remove first "self" parameter from modified_parameters (may have been added by AdaptedFunction, if using nanobind)
Expand Down
4 changes: 2 additions & 2 deletions src/litgen/internal/adapted_types/adapted_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,10 +942,10 @@ def find_return_policy_in_comments(rv_policy_token: str) -> str | None:

if (matches_regex_pointer and returns_pointer) or (matches_regex_reference and returns_reference):
self.return_value_policy = "reference"

if options.fn_return_force_policy_reference__callback:
options.fn_return_force_policy_reference__callback(self)

def _pydef_fill_call_policy_from_function_comment(self, call_policy_token: str) -> str | None:
function_comment = self.cpp_element().cpp_element_comments.comments_as_str()
if call_policy_token in function_comment:
Expand Down
98 changes: 49 additions & 49 deletions src/litgen/internal/cpp_to_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,48 +48,49 @@ def comment_pydef_one_line(options: LitgenOptions, title_cpp: str) -> str:


def _split_cpp_type_template_args(s: str) -> list[str]:
"""Split template at level 1,
ie
"AAA <BBB<CC>DD> EE" -> ["AAA ", "BBB<CC>DD", " EE"]
"""
idx_start = s.find('<')
if idx_start == -1:
# No '<' found, return the original string in a list
return [s]

# Initialize depth and index
depth = 0
idx = idx_start
idx_end = -1
while idx < len(s):
if s[idx] == '<':
depth += 1
elif s[idx] == '>':
depth -= 1
if depth == 0:
idx_end = idx
break
idx += 1

if idx_end == -1:
# No matching '>' found, return the original string in a list
return [s]

# Split the string into before, inside, and after
before = s[:idx_start]
inside = s[idx_start + 1: idx_end] # Exclude the outer '<' and '>'
after = s[idx_end + 1:]
# Reconstruct 'inside' with the outer '<' and '>' if needed
# For now, as per your requirement, we exclude them.

# Since you want 'inside' to include nested templates, we don't split further
return [before, inside, after]
"""Split template at level 1,
ie
"AAA <BBB<CC>DD> EE" -> ["AAA ", "BBB<CC>DD", " EE"]
"""
idx_start = s.find("<")
if idx_start == -1:
# No '<' found, return the original string in a list
return [s]

# Initialize depth and index
depth = 0
idx = idx_start
idx_end = -1
while idx < len(s):
if s[idx] == "<":
depth += 1
elif s[idx] == ">":
depth -= 1
if depth == 0:
idx_end = idx
break
idx += 1

if idx_end == -1:
# No matching '>' found, return the original string in a list
return [s]

# Split the string into before, inside, and after
before = s[:idx_start]
inside = s[idx_start + 1 : idx_end] # Exclude the outer '<' and '>'
after = s[idx_end + 1 :]
# Reconstruct 'inside' with the outer '<' and '>' if needed
# For now, as per your requirement, we exclude them.

# Since you want 'inside' to include nested templates, we don't split further
return [before, inside, after]


def _perform_cpp_type_replacements_recursively(cpp_type_str: str, type_replacements: RegexReplacementList) -> str:
# Preprocessing: Remove 'const', '&', and '*' tokens
import re

if '<' not in cpp_type_str:
if "<" not in cpp_type_str:
return type_replacements.apply(cpp_type_str)

# Split at the first level of template arguments
Expand All @@ -108,14 +109,14 @@ def _perform_cpp_type_replacements_recursively(cpp_type_str: str, type_replaceme
r = type_replacements.apply(r)

# Replace angle brackets with square brackets as last resort (this means some template regexes are missing)
r = r.replace('<', '[').replace('>', ']')
r = r.replace("<", "[").replace(">", "]")
# Normalize whitespace
r = re.sub(r'\s+', ' ', r).strip()
r = re.sub(r"\s+", " ", r).strip()

cpp_type_str = re.sub(r'\bconst\b', '', cpp_type_str)
cpp_type_str = re.sub(r'&', '', cpp_type_str)
cpp_type_str = re.sub(r'\*', '', cpp_type_str)
cpp_type_str = re.sub(r'\s+', ' ', cpp_type_str).strip()
cpp_type_str = re.sub(r"\bconst\b", "", cpp_type_str)
cpp_type_str = re.sub(r"&", "", cpp_type_str)
cpp_type_str = re.sub(r"\*", "", cpp_type_str)
cpp_type_str = re.sub(r"\s+", " ", cpp_type_str).strip()

return r

Expand All @@ -138,7 +139,9 @@ def type_to_python(lg_context: LitgenContext, cpp_type_str: str) -> str:

def normalize_whitespace(s: str) -> str:
import re
return re.sub(r'\s+', ' ', s).strip()

return re.sub(r"\s+", " ", s).strip()

r = normalize_whitespace(r)

# Fix for std::optional (issue origin unknown)
Expand Down Expand Up @@ -231,12 +234,9 @@ def var_value_to_python(lg_context: LitgenContext, default_value_cpp: str) -> st
for cpp_namespace_name in options.namespaces_root:
r = r.replace(cpp_namespace_name + ".", "")


# If this default value uses a bound template type, try to translate it
specialized_type_python_default_value = (
options.class_template_options.specialized_type_python_default_value(
default_value_cpp, lg_context.options.type_replacements
)
specialized_type_python_default_value = options.class_template_options.specialized_type_python_default_value(
default_value_cpp, lg_context.options.type_replacements
)
if specialized_type_python_default_value is not None:
r = specialized_type_python_default_value
Expand Down
4 changes: 2 additions & 2 deletions src/litgen/tests/brace_init_default_value_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ def test_fn_brace():
)
code_utils.assert_are_codes_equal(
generated_code.stub_code,
'''
"""
def f(v: V = V(1, 2)) -> None:
pass
'''
""",
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from srcmlcpp.cpp_types import CppFunctionDecl

import litgen
from litgen import litgen_generator


@dataclass
class AdaptedFunction2(CppFunctionDecl):
Expand Down Expand Up @@ -75,4 +75,4 @@ class MyClass {
}, py::arg("arr"))
;
""",
)
)
Loading

0 comments on commit b04f247

Please sign in to comment.