Skip to content

Commit

Permalink
added callback and test
Browse files Browse the repository at this point in the history
  • Loading branch information
jnastarot authored and pthom committed Dec 2, 2024
1 parent b04f247 commit f5fb5c0
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 7 deletions.
31 changes: 24 additions & 7 deletions src/litgen/internal/adapted_types/adapted_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,13 +469,22 @@ def stub_lines(self) -> list[str]:

def str_parent_classes_python() -> str:
parents: list[str] = []
if not self.cpp_element().has_base_classes():
custom_derived = (
[] if not self.options.class_base_custom_derivation__callback
else self.options.class_base_custom_derivation__callback(self, True))

if not custom_derived and not self.cpp_element().has_base_classes():
return ""
for _access_type, base_class in self.cpp_element().base_classes():
class_python_scope = cpp_to_python.cpp_scope_to_pybind_scope_str(
self.options, base_class, include_self=True
)
parents.append(class_python_scope)

if custom_derived:
for custom_base in custom_derived:
parents.append(custom_base)
else:
for _access_type, base_class in self.cpp_element().base_classes():
class_python_scope = cpp_to_python.cpp_scope_to_pybind_scope_str(
self.options, base_class, include_self=True
)
parents.append(class_python_scope)
if len(parents) == 0:
return ""
else:
Expand Down Expand Up @@ -604,11 +613,19 @@ def make_pyclass_creation_code() -> str:

# fill py::class_ additional template params (base classes, nodelete, etc)
other_template_params_list = []
if self.cpp_element().has_base_classes():
custom_derived = (
[] if not self.options.class_base_custom_derivation__callback
else self.options.class_base_custom_derivation__callback(self, False))

if custom_derived:
for custom_base in custom_derived:
other_template_params_list.append(custom_base)
elif self.cpp_element().has_base_classes():
base_classes = self.cpp_element().base_classes()
for access_type, base_class in base_classes:
if access_type == CppAccessType.public or access_type == CppAccessType.protected:
other_template_params_list.append(base_class.cpp_scope_str(include_self=True))

if self.cpp_element().has_private_destructor() and options.bind_library == BindLibraryType.pybind11:
# nanobind does not support nodelete
other_template_params_list.append(f"std::unique_ptr<{qualified_struct_name}, py::nodelete>")
Expand Down
5 changes: 5 additions & 0 deletions src/litgen/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,11 @@ class LitgenOptions:
# the generated classes together
class_template_decorate_in_stub: bool = True

# This callback Callback to customize the base classes used in generated bindings
# First param is the AdoptedClass
# Second indicates context - True for python stub, False for CPP bindings
class_base_custom_derivation__callback: Callable[[Any, bool], list[str]] | None = None

# ------------------------------------------------------------------------------
# Adapt class members
# ------------------------------------------------------------------------------
Expand Down
83 changes: 83 additions & 0 deletions src/litgen/tests/option_custom_class_derive_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from __future__ import annotations
from codemanip import code_utils
import litgen

def test_custom_classes_base_option():
"""
Example of how the callback mechanism could be used in practice to handle reference return policies
"""

def handle_classes_base(cls, for_python_stub):

bases = []

elem = cls.cpp_element()

if elem.class_name == "SecondClass":
bases.append("FirstClass" if for_python_stub else "CustomNS::FirstClass")

return bases

"""
## First class is in another file, and won't be handled by litgen for another files
namespace CustomNS {
class FirstClass {
public:
FirstClass();
private:
int _value1;
};
}
"""

code = """
class SecondClass : CustomNS::FirstClass {
public:
SecondClass();
private:
int _value2;
};
class ThirdClass : SecondClass {
public:
ThirdClass();
private:
int _value3;
};
"""

options = litgen.LitgenOptions()
options.class_base_custom_derivation__callback = handle_classes_base
generated_code = litgen.generate_code(options, code)

code_utils.assert_are_codes_equal(
generated_code.pydef_code,
"""
auto pyClassSecondClass =
py::class_<SecondClass, CustomNS::FirstClass>
(m, "SecondClass", "")
.def(py::init<>())
;
auto pyClassThirdClass =
py::class_<ThirdClass, SecondClass>
(m, "ThirdClass", "")
.def(py::init<>())
;
""",
)

code_utils.assert_are_codes_equal(
generated_code.stub_code,
"""
class SecondClass(FirstClass):
def __init__(self) -> None:
pass
class ThirdClass(SecondClass):
def __init__(self) -> None:
pass
""",
)

0 comments on commit f5fb5c0

Please sign in to comment.