Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Callback to customize the base classes used in generated bindings (merged) #27

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions src/litgen/internal/adapted_types/adapted_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,13 +469,22 @@ def stub_lines(self) -> list[str]:

def str_parent_classes_python() -> str:
parents: list[str] = []
if not self.cpp_element().has_base_classes():
custom_derived = (
[] if not self.options.class_base_custom_derivation__callback
else self.options.class_base_custom_derivation__callback(self, True))

if not custom_derived and not self.cpp_element().has_base_classes():
return ""
for _access_type, base_class in self.cpp_element().base_classes():
class_python_scope = cpp_to_python.cpp_scope_to_pybind_scope_str(
self.options, base_class, include_self=True
)
parents.append(class_python_scope)

if custom_derived:
for custom_base in custom_derived:
parents.append(custom_base)
else:
for _access_type, base_class in self.cpp_element().base_classes():
class_python_scope = cpp_to_python.cpp_scope_to_pybind_scope_str(
self.options, base_class, include_self=True
)
parents.append(class_python_scope)
if len(parents) == 0:
return ""
else:
Expand Down Expand Up @@ -604,11 +613,19 @@ def make_pyclass_creation_code() -> str:

# fill py::class_ additional template params (base classes, nodelete, etc)
other_template_params_list = []
if self.cpp_element().has_base_classes():
custom_derived = (
[] if not self.options.class_base_custom_derivation__callback
else self.options.class_base_custom_derivation__callback(self, False))

if custom_derived:
for custom_base in custom_derived:
other_template_params_list.append(custom_base)
elif self.cpp_element().has_base_classes():
base_classes = self.cpp_element().base_classes()
for access_type, base_class in base_classes:
if access_type == CppAccessType.public or access_type == CppAccessType.protected:
other_template_params_list.append(base_class.cpp_scope_str(include_self=True))

if self.cpp_element().has_private_destructor() and options.bind_library == BindLibraryType.pybind11:
# nanobind does not support nodelete
other_template_params_list.append(f"std::unique_ptr<{qualified_struct_name}, py::nodelete>")
Expand Down
5 changes: 5 additions & 0 deletions src/litgen/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,11 @@ class LitgenOptions:
# the generated classes together
class_template_decorate_in_stub: bool = True

# This callback Callback to customize the base classes used in generated bindings
# First param is the AdoptedClass
# Second indicates context - True for python stub, False for CPP bindings
class_base_custom_derivation__callback: Callable[[Any, bool], list[str]] | None = None

# ------------------------------------------------------------------------------
# Adapt class members
# ------------------------------------------------------------------------------
Expand Down
83 changes: 83 additions & 0 deletions src/litgen/tests/option_custom_class_derive_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from __future__ import annotations
from codemanip import code_utils
import litgen

def test_custom_classes_base_option():
"""
Example of how the callback mechanism could be used in practice to handle reference return policies
"""

def handle_classes_base(cls, for_python_stub):

bases = []

elem = cls.cpp_element()

if elem.class_name == "SecondClass":
bases.append("FirstClass" if for_python_stub else "CustomNS::FirstClass")

return bases

"""
## First class is in another file, and won't be handled by litgen for another files

namespace CustomNS {
class FirstClass {
public:
FirstClass();
private:
int _value1;
};
}
"""

code = """
class SecondClass : CustomNS::FirstClass {
public:
SecondClass();
private:
int _value2;
};

class ThirdClass : SecondClass {
public:
ThirdClass();
private:
int _value3;
};
"""

options = litgen.LitgenOptions()
options.class_base_custom_derivation__callback = handle_classes_base
generated_code = litgen.generate_code(options, code)

code_utils.assert_are_codes_equal(
generated_code.pydef_code,
"""
auto pyClassSecondClass =
py::class_<SecondClass, CustomNS::FirstClass>
(m, "SecondClass", "")
.def(py::init<>())
;


auto pyClassThirdClass =
py::class_<ThirdClass, SecondClass>
(m, "ThirdClass", "")
.def(py::init<>())
;
""",
)

code_utils.assert_are_codes_equal(
generated_code.stub_code,
"""
class SecondClass(FirstClass):
def __init__(self) -> None:
pass

class ThirdClass(SecondClass):
def __init__(self) -> None:
pass
""",
)