Skip to content

Commit

Permalink
run black formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
pthom committed Nov 13, 2024
1 parent 19d467b commit b735d1b
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -261,14 +261,17 @@ def make_adapted_lambda_code_end_template_buffer(self) -> str:
throw std::runtime_error("Unsupported dtype");
};
"""
template_intro = _nanobind_buffer_type_to_letter_code + """
template_intro = (
_nanobind_buffer_type_to_letter_code
+ """
// Compute the letter code for the buffer type
uint8_t dtype_code_{template_buffer_name} = {template_buffer_name}.dtype().code;
size_t sizeof_item_{template_buffer_name} = {template_buffer_name}.dtype().bits / 8;
char {template_buffer_name}_type = _nanobind_buffer_type_to_letter_code(dtype_code_{template_buffer_name}, sizeof_item_{template_buffer_name});
// call the correct template version by casting
"""
)

template_loop_type = """
{maybe_else}if ({template_buffer_name}_type == '{pyarray_type_char}')
Expand Down
4 changes: 3 additions & 1 deletion src/litgen/internal/adapted_types/adapted_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,9 @@ def make_pyclass_creation_code() -> str:
)

if self.cpp_element().is_final():
replacements.maybe_py_is_final = ", py::is_final()" if options.bind_library == BindLibraryType.pybind11 else ", nb::is_final()"
replacements.maybe_py_is_final = (
", py::is_final()" if options.bind_library == BindLibraryType.pybind11 else ", nb::is_final()"
)
else:
replacements.maybe_py_is_final = ""

Expand Down
6 changes: 5 additions & 1 deletion src/litgen/internal/adapted_types/adapted_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,11 @@ def pydef_lines(self) -> list[str]:
# Enum decl first line
is_arithmetic = code_utils.does_match_regex(self.options.enum_make_arithmetic__regex, enum_name_cpp)
if is_arithmetic:
arithmetic_str = ", py::arithmetic()" if self.options.bind_library == BindLibraryType.pybind11 else ", nb::is_arithmetic()"
arithmetic_str = (
", py::arithmetic()"
if self.options.bind_library == BindLibraryType.pybind11
else ", nb::is_arithmetic()"
)
pydef_class_var_parent = cpp_to_python.cpp_scope_to_pybind_parent_var_name(self.options, self.cpp_element())
enum_var = f"auto pyEnum{enum_name_python} = "
py = "py" if self.options.bind_library == BindLibraryType.pybind11 else "nb"
Expand Down
13 changes: 9 additions & 4 deletions src/litgen/internal/adapted_types/adapted_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,7 +890,10 @@ def _pydef_pyarg_list(self) -> list[str]:
# Skip *args and **kwarg
param_type_cpp = adapted_decl.cpp_element().cpp_type.str_code()
param_type_cpp_simplified = (
param_type_cpp.replace("const ", "").replace("pybind11::", "py::").replace(" &", "").replace("nanobind::", "nb::")
param_type_cpp.replace("const ", "")
.replace("pybind11::", "py::")
.replace(" &", "")
.replace("nanobind::", "nb::")
)
if param_type_cpp_simplified in ["py::args", "py::kwargs", "nb::args", "nb::kwargs"]:
continue
Expand All @@ -899,7 +902,6 @@ def _pydef_pyarg_list(self) -> list[str]:
pyarg_strs.append(pyarg_str)
return pyarg_strs


def _pydef_fill_return_value_policy(self) -> None:
"""Parses the return_value_policy from the function end of line comment
For example:
Expand Down Expand Up @@ -969,7 +971,7 @@ def _pydef_fill_keep_alive_from_function_comment(self) -> str | None:
def _pydef_fill_call_guard_from_function_comment(self) -> str | None:
v_py = self._pydef_fill_call_policy_from_function_comment("py::call_guard")
v_nb = self._pydef_fill_call_policy_from_function_comment("nb::call_guard")
return self._replace_py_or_nb_namespace(v_py or v_nb)
return self._replace_py_or_nb_namespace(v_py or v_nb)

def _pydef_str_parent_cpp_scope(self) -> str:
if self.is_method():
Expand Down Expand Up @@ -1240,7 +1242,10 @@ def _stub_params_list_signature(self) -> list[str]:

# Handle *args and **kwargs
param_type_cpp_simplified = (
param_type_cpp.replace("const ", "").replace("pybind11::", "py::").replace(" &", "").replace("nanobind::", "nb::")
param_type_cpp.replace("const ", "")
.replace("pybind11::", "py::")
.replace(" &", "")
.replace("nanobind::", "nb::")
)
if param_type_cpp_simplified == "py::args" or param_type_cpp_simplified == "nb::args":
r.append("*args")
Expand Down
4 changes: 1 addition & 3 deletions src/litgen/internal/adapted_types/adapted_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ def _pydef_def_submodule_code(self) -> list[str]:
return []
self.lg_context.namespaces_pydef.register_namespace_creation(self._qualified_namespace_name())

submodule_code_template = (
'{py}::module_ {submodule_cpp_var} = {parent_module_cpp_var}.def_submodule("{module_name}", "{module_doc}");'
)
submodule_code_template = '{py}::module_ {submodule_cpp_var} = {parent_module_cpp_var}.def_submodule("{module_name}", "{module_doc}");'

replace_tokens = Munch()
replace_tokens.py = "py" if self.options.bind_library == BindLibraryType.pybind11 else "nb"
Expand Down
1 change: 1 addition & 0 deletions src/litgen/internal/cpp_to_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def py_array_type_to_cpp_type(py_array_type: str) -> str:
"long double": "Float",
}


def nanobind_cpp_type_to_dtype_code_as_uint8(cpp_type: str) -> str:
if cpp_type in _NANOBIND_CPP_TYPE_TO_DTYPE_CODE:
return "static_cast<uint8_t>(nb::dlpack::dtype_code::" + _NANOBIND_CPP_TYPE_TO_DTYPE_CODE[cpp_type] + ")"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def test_nanobind_buffer() -> None:
foo_adapt_c_buffers(buffer);
}, nb::arg("buffer"));
"""
""",
)


Expand Down Expand Up @@ -367,5 +367,5 @@ def test_template_buffer_nanobind():
return foo_adapt_c_buffers(buf, flag);
}, nb::arg("buf"), nb::arg("flag"));
"""
""",
)
Original file line number Diff line number Diff line change
Expand Up @@ -774,11 +774,10 @@ class Color4:
def __init__(self, _rgba: List[int]) -> None:
pass
rgba: np.ndarray # ndarray[type=uint8_t, size=4]
"""
""",
)



def test_adapted_ctor() -> None:
# The constructor for Color4 will be adapted to accept std::array<uint8_t, 4>
code = """
Expand Down

0 comments on commit b735d1b

Please sign in to comment.