diff --git a/wrap/.github/workflows/linux-ci.yml b/wrap/.github/workflows/linux-ci.yml index 0ca9ba8f5b..34623385ed 100644 --- a/wrap/.github/workflows/linux-ci.yml +++ b/wrap/.github/workflows/linux-ci.yml @@ -10,7 +10,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.5, 3.6, 3.7, 3.8, 3.9] + python-version: [3.6, 3.7, 3.8, 3.9] steps: - name: Checkout diff --git a/wrap/.github/workflows/macos-ci.yml b/wrap/.github/workflows/macos-ci.yml index b0ccb3fbe9..3910d28d8a 100644 --- a/wrap/.github/workflows/macos-ci.yml +++ b/wrap/.github/workflows/macos-ci.yml @@ -10,7 +10,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.5, 3.6, 3.7, 3.8, 3.9] + python-version: [3.6, 3.7, 3.8, 3.9] steps: - name: Checkout diff --git a/wrap/.gitignore b/wrap/.gitignore index 8e2bafa7a0..9f79deafab 100644 --- a/wrap/.gitignore +++ b/wrap/.gitignore @@ -8,4 +8,4 @@ __pycache__/ # Files related to code coverage stats **/.coverage -gtwrap/matlab_wrapper.tpl +gtwrap/matlab_wrapper/matlab_wrapper.tpl diff --git a/wrap/CMakeLists.txt b/wrap/CMakeLists.txt index 9e03da0607..2a11a760d4 100644 --- a/wrap/CMakeLists.txt +++ b/wrap/CMakeLists.txt @@ -58,7 +58,7 @@ if(NOT DEFINED GTWRAP_INCLUDE_NAME) endif() configure_file(${PROJECT_SOURCE_DIR}/templates/matlab_wrapper.tpl.in - ${PROJECT_SOURCE_DIR}/gtwrap/matlab_wrapper.tpl) + ${PROJECT_SOURCE_DIR}/gtwrap/matlab_wrapper/matlab_wrapper.tpl) # Install the gtwrap python package as a directory so it can be found by CMake # for wrapping. diff --git a/wrap/DOCS.md b/wrap/DOCS.md index 8537ddd276..c8285baeff 100644 --- a/wrap/DOCS.md +++ b/wrap/DOCS.md @@ -192,12 +192,14 @@ The python wrapper supports keyword arguments for functions/methods. Hence, the - **DO NOT** re-define an overriden function already declared in the external (forward-declared) base class. This will cause an ambiguity problem in the Pybind header file. +- Splitting wrapper over multiple files + - The Pybind11 wrapper supports splitting the wrapping code over multiple files. + - To be able to use classes from another module, simply import the C++ header file in that wrapper file. + - Unfortunately, this means that aliases can no longer be used. + - Similarly, there can be multiple `preamble.h` and `specializations.h` files. Each of these should match the module file name. ### TODO -- Default values for arguments. - - WORKAROUND: make multiple versions of the same function for different configurations of default arguments. - Handle `gtsam::Rot3M` conversions to quaternions. - Parse return of const ref arguments. - Parse `std::string` variants and convert directly to special string. -- Add enum support. - Add generalized serialization support via `boost.serialization` with hooks to MATLAB save/load. diff --git a/wrap/README.md b/wrap/README.md index 442fc2f934..a04a2ef2d0 100644 --- a/wrap/README.md +++ b/wrap/README.md @@ -29,8 +29,10 @@ Using `wrap` in your project is straightforward from here. In your `CMakeLists.t ```cmake find_package(gtwrap) +set(interface_files ${PROJECT_SOURCE_DIR}/cpp/${PROJECT_NAME}.h) + pybind_wrap(${PROJECT_NAME}_py # target - ${PROJECT_SOURCE_DIR}/cpp/${PROJECT_NAME}.h # interface header file + "${interface_files}" # list of interface header files "${PROJECT_NAME}.cpp" # the generated cpp "${PROJECT_NAME}" # module_name "${PROJECT_MODULE_NAME}" # top namespace in the cpp file e.g. gtsam diff --git a/wrap/gtwrap/interface_parser/__init__.py b/wrap/gtwrap/interface_parser/__init__.py index 0f87eaaa9d..3be52d7d9f 100644 --- a/wrap/gtwrap/interface_parser/__init__.py +++ b/wrap/gtwrap/interface_parser/__init__.py @@ -12,7 +12,7 @@ import sys -import pyparsing +import pyparsing # type: ignore from .classes import * from .declaration import * diff --git a/wrap/gtwrap/interface_parser/classes.py b/wrap/gtwrap/interface_parser/classes.py index ea7a3b3c38..3e6a0fc3c7 100644 --- a/wrap/gtwrap/interface_parser/classes.py +++ b/wrap/gtwrap/interface_parser/classes.py @@ -12,7 +12,7 @@ from typing import Iterable, List, Union -from pyparsing import Literal, Optional, ZeroOrMore +from pyparsing import Literal, Optional, ZeroOrMore # type: ignore from .enum import Enum from .function import ArgumentList, ReturnType @@ -233,7 +233,7 @@ def __init__(self, self.static_methods = [] self.properties = [] self.operators = [] - self.enums = [] + self.enums: List[Enum] = [] for m in members: if isinstance(m, Constructor): self.ctors.append(m) @@ -274,7 +274,7 @@ def __init__(self, def __init__( self, - template: Template, + template: Union[Template, None], is_virtual: str, name: str, parent_class: list, @@ -292,16 +292,16 @@ def __init__( if parent_class: # If it is in an iterable, extract the parent class. if isinstance(parent_class, Iterable): - parent_class = parent_class[0] + parent_class = parent_class[0] # type: ignore # If the base class is a TemplatedType, # we want the instantiated Typename if isinstance(parent_class, TemplatedType): - parent_class = parent_class.typename + parent_class = parent_class.typename # type: ignore self.parent_class = parent_class else: - self.parent_class = '' + self.parent_class = '' # type: ignore self.ctors = ctors self.methods = methods diff --git a/wrap/gtwrap/interface_parser/declaration.py b/wrap/gtwrap/interface_parser/declaration.py index 292d6aeaa6..f47ee6e057 100644 --- a/wrap/gtwrap/interface_parser/declaration.py +++ b/wrap/gtwrap/interface_parser/declaration.py @@ -10,7 +10,7 @@ Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellaert """ -from pyparsing import CharsNotIn, Optional +from pyparsing import CharsNotIn, Optional # type: ignore from .tokens import (CLASS, COLON, INCLUDE, LOPBRACK, ROPBRACK, SEMI_COLON, VIRTUAL) diff --git a/wrap/gtwrap/interface_parser/enum.py b/wrap/gtwrap/interface_parser/enum.py index fca7080ef2..265e1ad612 100644 --- a/wrap/gtwrap/interface_parser/enum.py +++ b/wrap/gtwrap/interface_parser/enum.py @@ -10,7 +10,7 @@ Author: Varun Agrawal """ -from pyparsing import delimitedList +from pyparsing import delimitedList # type: ignore from .tokens import ENUM, IDENT, LBRACE, RBRACE, SEMI_COLON from .type import Typename diff --git a/wrap/gtwrap/interface_parser/function.py b/wrap/gtwrap/interface_parser/function.py index 3b9a5d4ada..995aba10e1 100644 --- a/wrap/gtwrap/interface_parser/function.py +++ b/wrap/gtwrap/interface_parser/function.py @@ -10,9 +10,9 @@ Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellaert """ -from typing import Iterable, List, Union +from typing import Any, Iterable, List, Union -from pyparsing import Optional, ParseResults, delimitedList +from pyparsing import Optional, ParseResults, delimitedList # type: ignore from .template import Template from .tokens import (COMMA, DEFAULT_ARG, EQUAL, IDENT, LOPBRACK, LPAREN, PAIR, @@ -42,12 +42,12 @@ def __init__(self, name: str, default: ParseResults = None): if isinstance(ctype, Iterable): - self.ctype = ctype[0] + self.ctype = ctype[0] # type: ignore else: self.ctype = ctype self.name = name self.default = default - self.parent = None # type: Union[ArgumentList, None] + self.parent: Union[ArgumentList, None] = None def __repr__(self) -> str: return self.to_cpp() @@ -70,7 +70,7 @@ def __init__(self, args_list: List[Argument]): arg.parent = self # The parent object which contains the argument list # E.g. Method, StaticMethod, Template, Constructor, GlobalFunction - self.parent = None + self.parent: Any = None @staticmethod def from_parse_result(parse_result: ParseResults): @@ -123,7 +123,7 @@ def __init__(self, type1: Union[Type, TemplatedType], type2: Type): self.type2 = type2 # The parent object which contains the return type # E.g. Method, StaticMethod, Template, Constructor, GlobalFunction - self.parent = None + self.parent: Any = None def is_void(self) -> bool: """ diff --git a/wrap/gtwrap/interface_parser/module.py b/wrap/gtwrap/interface_parser/module.py index 6412098b8a..7912c40d5b 100644 --- a/wrap/gtwrap/interface_parser/module.py +++ b/wrap/gtwrap/interface_parser/module.py @@ -12,7 +12,8 @@ # pylint: disable=unnecessary-lambda, unused-import, expression-not-assigned, no-else-return, protected-access, too-few-public-methods, too-many-arguments -from pyparsing import ParseResults, ZeroOrMore, cppStyleComment, stringEnd +from pyparsing import (ParseResults, ZeroOrMore, # type: ignore + cppStyleComment, stringEnd) from .classes import Class from .declaration import ForwardDeclaration, Include diff --git a/wrap/gtwrap/interface_parser/namespace.py b/wrap/gtwrap/interface_parser/namespace.py index 575d982371..9c135ffe8c 100644 --- a/wrap/gtwrap/interface_parser/namespace.py +++ b/wrap/gtwrap/interface_parser/namespace.py @@ -14,7 +14,7 @@ from typing import List, Union -from pyparsing import Forward, ParseResults, ZeroOrMore +from pyparsing import Forward, ParseResults, ZeroOrMore # type: ignore from .classes import Class, collect_namespaces from .declaration import ForwardDeclaration, Include @@ -93,7 +93,7 @@ def from_parse_result(t: ParseResults): return Namespace(t.name, content) def find_class_or_function( - self, typename: Typename) -> Union[Class, GlobalFunction]: + self, typename: Typename) -> Union[Class, GlobalFunction, ForwardDeclaration]: """ Find the Class or GlobalFunction object given its typename. We have to traverse the tree of namespaces. @@ -115,7 +115,7 @@ def find_class_or_function( return res[0] def top_level(self) -> "Namespace": - """Return the top leve namespace.""" + """Return the top level namespace.""" if self.name == '' or self.parent == '': return self else: diff --git a/wrap/gtwrap/interface_parser/template.py b/wrap/gtwrap/interface_parser/template.py index dc9d0ce44f..fd9de830ae 100644 --- a/wrap/gtwrap/interface_parser/template.py +++ b/wrap/gtwrap/interface_parser/template.py @@ -12,11 +12,11 @@ from typing import List -from pyparsing import Optional, ParseResults, delimitedList +from pyparsing import Optional, ParseResults, delimitedList # type: ignore from .tokens import (EQUAL, IDENT, LBRACE, LOPBRACK, RBRACE, ROPBRACK, SEMI_COLON, TEMPLATE, TYPEDEF) -from .type import Typename, TemplatedType +from .type import TemplatedType, Typename class Template: diff --git a/wrap/gtwrap/interface_parser/tokens.py b/wrap/gtwrap/interface_parser/tokens.py index 4eba95900a..0f8d38d868 100644 --- a/wrap/gtwrap/interface_parser/tokens.py +++ b/wrap/gtwrap/interface_parser/tokens.py @@ -10,9 +10,9 @@ Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellaert """ -from pyparsing import (Keyword, Literal, OneOrMore, Or, QuotedString, Suppress, - Word, alphanums, alphas, nestedExpr, nums, - originalTextFor, printables) +from pyparsing import (Keyword, Literal, OneOrMore, Or, # type: ignore + QuotedString, Suppress, Word, alphanums, alphas, + nestedExpr, nums, originalTextFor, printables) # rule for identifiers (e.g. variable names) IDENT = Word(alphas + '_', alphanums + '_') ^ Word(nums) diff --git a/wrap/gtwrap/interface_parser/type.py b/wrap/gtwrap/interface_parser/type.py index b9f2bd8f74..0b9be65017 100644 --- a/wrap/gtwrap/interface_parser/type.py +++ b/wrap/gtwrap/interface_parser/type.py @@ -14,7 +14,8 @@ from typing import Iterable, List, Union -from pyparsing import Forward, Optional, Or, ParseResults, delimitedList +from pyparsing import (Forward, Optional, Or, ParseResults, # type: ignore + delimitedList) from .tokens import (BASIS_TYPES, CONST, IDENT, LOPBRACK, RAW_POINTER, REF, ROPBRACK, SHARED_POINTER) @@ -48,7 +49,7 @@ class Typename: def __init__(self, t: ParseResults, - instantiations: Union[tuple, list, str, ParseResults] = ()): + instantiations: Iterable[ParseResults] = ()): self.name = t[-1] # the name is the last element in this list self.namespaces = t[:-1] diff --git a/wrap/gtwrap/interface_parser/variable.py b/wrap/gtwrap/interface_parser/variable.py index fcb02666f7..3779cf74fa 100644 --- a/wrap/gtwrap/interface_parser/variable.py +++ b/wrap/gtwrap/interface_parser/variable.py @@ -10,7 +10,9 @@ Author: Varun Agrawal, Gerry Chen """ -from pyparsing import Optional, ParseResults +from typing import List + +from pyparsing import Optional, ParseResults # type: ignore from .tokens import DEFAULT_ARG, EQUAL, IDENT, SEMI_COLON from .type import TemplatedType, Type @@ -40,7 +42,7 @@ class Hello { t.default[0] if isinstance(t.default, ParseResults) else None)) def __init__(self, - ctype: Type, + ctype: List[Type], name: str, default: ParseResults = None, parent=''): diff --git a/wrap/gtwrap/matlab_wrapper/__init__.py b/wrap/gtwrap/matlab_wrapper/__init__.py new file mode 100644 index 0000000000..f10338c1c7 --- /dev/null +++ b/wrap/gtwrap/matlab_wrapper/__init__.py @@ -0,0 +1,3 @@ +"""Package to wrap C++ code to Matlab via MEX.""" + +from .wrapper import MatlabWrapper diff --git a/wrap/gtwrap/matlab_wrapper/mixins.py b/wrap/gtwrap/matlab_wrapper/mixins.py new file mode 100644 index 0000000000..061cea2833 --- /dev/null +++ b/wrap/gtwrap/matlab_wrapper/mixins.py @@ -0,0 +1,222 @@ +"""Mixins for reducing the amount of boilerplate in the main wrapper class.""" + +import gtwrap.interface_parser as parser +import gtwrap.template_instantiator as instantiator + + +class CheckMixin: + """Mixin to provide various checks.""" + # Data types that are primitive types + not_ptr_type = ['int', 'double', 'bool', 'char', 'unsigned char', 'size_t'] + # Ignore the namespace for these datatypes + ignore_namespace = ['Matrix', 'Vector', 'Point2', 'Point3'] + # Methods that should be ignored + ignore_methods = ['pickle'] + # Methods that should not be wrapped directly + whitelist = ['serializable', 'serialize'] + # Datatypes that do not need to be checked in methods + not_check_type: list = [] + + def _has_serialization(self, cls): + for m in cls.methods: + if m.name in self.whitelist: + return True + return False + + def is_shared_ptr(self, arg_type): + """ + Determine if the `interface_parser.Type` should be treated as a + shared pointer in the wrapper. + """ + return arg_type.is_shared_ptr or ( + arg_type.typename.name not in self.not_ptr_type + and arg_type.typename.name not in self.ignore_namespace + and arg_type.typename.name != 'string') + + def is_ptr(self, arg_type): + """ + Determine if the `interface_parser.Type` should be treated as a + raw pointer in the wrapper. + """ + return arg_type.is_ptr or ( + arg_type.typename.name not in self.not_ptr_type + and arg_type.typename.name not in self.ignore_namespace + and arg_type.typename.name != 'string') + + def is_ref(self, arg_type): + """ + Determine if the `interface_parser.Type` should be treated as a + reference in the wrapper. + """ + return arg_type.typename.name not in self.ignore_namespace and \ + arg_type.typename.name not in self.not_ptr_type and \ + arg_type.is_ref + + +class FormatMixin: + """Mixin to provide formatting utilities.""" + def _clean_class_name(self, instantiated_class): + """Reformatted the C++ class name to fit Matlab defined naming + standards + """ + if len(instantiated_class.ctors) != 0: + return instantiated_class.ctors[0].name + + return instantiated_class.name + + def _format_type_name(self, + type_name, + separator='::', + include_namespace=True, + constructor=False, + method=False): + """ + Args: + type_name: an interface_parser.Typename to reformat + separator: the statement to add between namespaces and typename + include_namespace: whether to include namespaces when reformatting + constructor: if the typename will be in a constructor + method: if the typename will be in a method + + Raises: + constructor and method cannot both be true + """ + if constructor and method: + raise ValueError( + 'Constructor and method parameters cannot both be True') + + formatted_type_name = '' + name = type_name.name + + if include_namespace: + for namespace in type_name.namespaces: + if name not in self.ignore_namespace and namespace != '': + formatted_type_name += namespace + separator + + if constructor: + formatted_type_name += self.data_type.get(name) or name + elif method: + formatted_type_name += self.data_type_param.get(name) or name + else: + formatted_type_name += name + + if separator == "::": # C++ + templates = [] + for idx in range(len(type_name.instantiations)): + template = '{}'.format( + self._format_type_name(type_name.instantiations[idx], + include_namespace=include_namespace, + constructor=constructor, + method=method)) + templates.append(template) + + if len(templates) > 0: # If there are no templates + formatted_type_name += '<{}>'.format(','.join(templates)) + + else: + for idx in range(len(type_name.instantiations)): + formatted_type_name += '{}'.format( + self._format_type_name(type_name.instantiations[idx], + separator=separator, + include_namespace=False, + constructor=constructor, + method=method)) + + return formatted_type_name + + def _format_return_type(self, + return_type, + include_namespace=False, + separator="::"): + """Format return_type. + + Args: + return_type: an interface_parser.ReturnType to reformat + include_namespace: whether to include namespaces when reformatting + """ + return_wrap = '' + + if self._return_count(return_type) == 1: + return_wrap = self._format_type_name( + return_type.type1.typename, + separator=separator, + include_namespace=include_namespace) + else: + return_wrap = 'pair< {type1}, {type2} >'.format( + type1=self._format_type_name( + return_type.type1.typename, + separator=separator, + include_namespace=include_namespace), + type2=self._format_type_name( + return_type.type2.typename, + separator=separator, + include_namespace=include_namespace)) + + return return_wrap + + def _format_class_name(self, instantiated_class, separator=''): + """Format a template_instantiator.InstantiatedClass name.""" + if instantiated_class.parent == '': + parent_full_ns = [''] + else: + parent_full_ns = instantiated_class.parent.full_namespaces() + # class_name = instantiated_class.parent.name + # + # if class_name != '': + # class_name += separator + # + # class_name += instantiated_class.name + parentname = "".join([separator + x + for x in parent_full_ns]) + separator + + class_name = parentname[2 * len(separator):] + + class_name += instantiated_class.name + + return class_name + + def _format_static_method(self, static_method, separator=''): + """Example: + + gtsamPoint3.staticFunction + """ + method = '' + + if isinstance(static_method, parser.StaticMethod): + method += "".join([separator + x for x in static_method.parent.namespaces()]) + \ + separator + static_method.parent.name + separator + + return method[2 * len(separator):] + + def _format_instance_method(self, instance_method, separator=''): + """Example: + + gtsamPoint3.staticFunction + """ + method = '' + + if isinstance(instance_method, instantiator.InstantiatedMethod): + method_list = [ + separator + x + for x in instance_method.parent.parent.full_namespaces() + ] + method += "".join(method_list) + separator + + method += instance_method.parent.name + separator + method += instance_method.original.name + method += "<" + instance_method.instantiations.to_cpp() + ">" + + return method[2 * len(separator):] + + def _format_global_method(self, static_method, separator=''): + """Example: + + gtsamPoint3.staticFunction + """ + method = '' + + if isinstance(static_method, parser.GlobalFunction): + method += "".join([separator + x for x in static_method.parent.full_namespaces()]) + \ + separator + + return method[2 * len(separator):] diff --git a/wrap/gtwrap/matlab_wrapper/templates.py b/wrap/gtwrap/matlab_wrapper/templates.py new file mode 100644 index 0000000000..7aaf8f487b --- /dev/null +++ b/wrap/gtwrap/matlab_wrapper/templates.py @@ -0,0 +1,166 @@ +import textwrap + + +class WrapperTemplate: + """Class to encapsulate string templates for use in wrapper generation""" + boost_headers = textwrap.dedent(""" + #include + #include + #include + """) + + typdef_collectors = textwrap.dedent('''\ + typedef std::set*> Collector_{class_name}; + static Collector_{class_name} collector_{class_name}; + ''') + + delete_obj = textwrap.indent(textwrap.dedent('''\ + {{ for(Collector_{class_name}::iterator iter = collector_{class_name}.begin(); + iter != collector_{class_name}.end(); ) {{ + delete *iter; + collector_{class_name}.erase(iter++); + anyDeleted = true; + }} }} + '''), + prefix=' ') + + delete_all_objects = textwrap.dedent(''' + void _deleteAllObjects() + {{ + mstream mout; + std::streambuf *outbuf = std::cout.rdbuf(&mout);\n + bool anyDeleted = false; + {delete_objs} + if(anyDeleted) + cout << + "WARNING: Wrap modules with variables in the workspace have been reloaded due to\\n" + "calling destructors, call \'clear all\' again if you plan to now recompile a wrap\\n" + "module, so that your recompiled module is used instead of the old one." << endl; + std::cout.rdbuf(outbuf); + }} + ''') + + rtti_register = textwrap.dedent('''\ + void _{module_name}_RTTIRegister() {{ + const mxArray *alreadyCreated = mexGetVariablePtr("global", "gtsam_{module_name}_rttiRegistry_created"); + if(!alreadyCreated) {{ + std::map types; + + {rtti_classes} + + mxArray *registry = mexGetVariable("global", "gtsamwrap_rttiRegistry"); + if(!registry) + registry = mxCreateStructMatrix(1, 1, 0, NULL); + typedef std::pair StringPair; + for(const StringPair& rtti_matlab: types) {{ + int fieldId = mxAddField(registry, rtti_matlab.first.c_str()); + if(fieldId < 0) {{ + mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + }} + mxArray *matlabName = mxCreateString(rtti_matlab.second.c_str()); + mxSetFieldByNumber(registry, 0, fieldId, matlabName); + }} + if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) {{ + mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + }} + mxDestroyArray(registry); + + mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); + if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) {{ + mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + }} + mxDestroyArray(newAlreadyCreated); + }} + }} + ''') + + collector_function_upcast_from_void = textwrap.dedent('''\ + void {class_name}_upcastFromVoid_{id}(int nargout, mxArray *out[], int nargin, const mxArray *in[]) {{ + mexAtExit(&_deleteAllObjects); + typedef boost::shared_ptr<{cpp_name}> Shared; + boost::shared_ptr *asVoid = *reinterpret_cast**> (mxGetData(in[0])); + out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); + Shared *self = new Shared(boost::static_pointer_cast<{cpp_name}>(*asVoid)); + *reinterpret_cast(mxGetData(out[0])) = self; + }}\n + ''') + + class_serialize_method = textwrap.dedent('''\ + function varargout = string_serialize(this, varargin) + % STRING_SERIALIZE usage: string_serialize() : returns string + % Doxygen can be found at https://gtsam.org/doxygen/ + if length(varargin) == 0 + varargout{{1}} = {wrapper}({wrapper_id}, this, varargin{{:}}); + else + error('Arguments do not match any overload of function {class_name}.string_serialize'); + end + end\n + function sobj = saveobj(obj) + % SAVEOBJ Saves the object to a matlab-readable format + sobj = obj.string_serialize(); + end + ''') + + collector_function_serialize = textwrap.indent(textwrap.dedent("""\ + typedef boost::shared_ptr<{full_name}> Shared; + checkArguments("string_serialize",nargout,nargin-1,0); + Shared obj = unwrap_shared_ptr<{full_name}>(in[0], "ptr_{namespace}{class_name}"); + ostringstream out_archive_stream; + boost::archive::text_oarchive out_archive(out_archive_stream); + out_archive << *obj; + out[0] = wrap< string >(out_archive_stream.str()); + """), + prefix=' ') + + collector_function_deserialize = textwrap.indent(textwrap.dedent("""\ + typedef boost::shared_ptr<{full_name}> Shared; + checkArguments("{namespace}{class_name}.string_deserialize",nargout,nargin,1); + string serialized = unwrap< string >(in[0]); + istringstream in_archive_stream(serialized); + boost::archive::text_iarchive in_archive(in_archive_stream); + Shared output(new {full_name}()); + in_archive >> *output; + out[0] = wrap_shared_ptr(output,"{namespace}.{class_name}", false); + """), + prefix=' ') + + mex_function = textwrap.dedent(''' + void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) + {{ + mstream mout; + std::streambuf *outbuf = std::cout.rdbuf(&mout);\n + _{module_name}_RTTIRegister();\n + int id = unwrap(in[0]);\n + try {{ + switch(id) {{ + {cases} }} + }} catch(const std::exception& e) {{ + mexErrMsgTxt(("Exception from gtsam:\\n" + std::string(e.what()) + "\\n").c_str()); + }}\n + std::cout.rdbuf(outbuf); + }} + ''') + + collector_function_shared_return = textwrap.indent(textwrap.dedent('''\ + {{ + boost::shared_ptr<{name}> shared({shared_obj}); + out[{id}] = wrap_shared_ptr(shared,"{name}"); + }}{new_line}'''), + prefix=' ') + + matlab_deserialize = textwrap.indent(textwrap.dedent("""\ + function varargout = string_deserialize(varargin) + % STRING_DESERIALIZE usage: string_deserialize() : returns {class_name} + % Doxygen can be found at https://gtsam.org/doxygen/ + if length(varargin) == 1 + varargout{{1}} = {wrapper}({id}, varargin{{:}}); + else + error('Arguments do not match any overload of function {class_name}.string_deserialize'); + end + end\n + function obj = loadobj(sobj) + % LOADOBJ Saves the object to a matlab-readable format + obj = {class_name}.string_deserialize(sobj); + end + """), + prefix=' ') diff --git a/wrap/gtwrap/matlab_wrapper.py b/wrap/gtwrap/matlab_wrapper/wrapper.py similarity index 68% rename from wrap/gtwrap/matlab_wrapper.py rename to wrap/gtwrap/matlab_wrapper/wrapper.py index de6221bbcf..b040d27311 100755 --- a/wrap/gtwrap/matlab_wrapper.py +++ b/wrap/gtwrap/matlab_wrapper/wrapper.py @@ -7,16 +7,19 @@ import os import os.path as osp -import sys import textwrap from functools import partial, reduce from typing import Dict, Iterable, List, Union +from loguru import logger + import gtwrap.interface_parser as parser import gtwrap.template_instantiator as instantiator +from gtwrap.matlab_wrapper.mixins import CheckMixin, FormatMixin +from gtwrap.matlab_wrapper.templates import WrapperTemplate -class MatlabWrapper(object): +class MatlabWrapper(CheckMixin, FormatMixin): """ Wrap the given C++ code into Matlab. Attributes @@ -25,89 +28,75 @@ class MatlabWrapper(object): top_module_namespace: C++ namespace for the top module (default '') ignore_classes: A list of classes to ignore (default []) """ - # Map the data type to its Matlab class. - # Found in Argument.cpp in old wrapper - data_type = { - 'string': 'char', - 'char': 'char', - 'unsigned char': 'unsigned char', - 'Vector': 'double', - 'Matrix': 'double', - 'int': 'numeric', - 'size_t': 'numeric', - 'bool': 'logical' - } - # Map the data type into the type used in Matlab methods. - # Found in matlab.h in old wrapper - data_type_param = { - 'string': 'char', - 'char': 'char', - 'unsigned char': 'unsigned char', - 'size_t': 'int', - 'int': 'int', - 'double': 'double', - 'Point2': 'double', - 'Point3': 'double', - 'Vector': 'double', - 'Matrix': 'double', - 'bool': 'bool' - } - # Methods that should not be wrapped directly - whitelist = ['serializable', 'serialize'] - # Methods that should be ignored - ignore_methods = ['pickle'] - # Datatypes that do not need to be checked in methods - not_check_type = [] # type: list - # Data types that are primitive types - not_ptr_type = ['int', 'double', 'bool', 'char', 'unsigned char', 'size_t'] - # Ignore the namespace for these datatypes - ignore_namespace = ['Matrix', 'Vector', 'Point2', 'Point3'] - # The amount of times the wrapper has created a call to geometry_wrapper - wrapper_id = 0 - # Map each wrapper id to what its collector function namespace, class, type, and string format - wrapper_map = {} - # Set of all the includes in the namespace - includes = {} # type: Dict[parser.Include, int] - # Set of all classes in the namespace - classes = [ - ] # type: List[Union[parser.Class, instantiator.InstantiatedClass]] - classes_elems = { - } # type: Dict[Union[parser.Class, instantiator.InstantiatedClass], int] - # Id for ordering global functions in the wrapper - global_function_id = 0 - # Files and their content - content = [] # type: List[str] - - # Ensure the template file is always picked up from the correct directory. - dir_path = osp.dirname(osp.realpath(__file__)) - with open(osp.join(dir_path, "matlab_wrapper.tpl")) as f: - wrapper_file_header = f.read() - def __init__(self, module_name, top_module_namespace='', ignore_classes=()): + super().__init__() + self.module_name = module_name self.top_module_namespace = top_module_namespace self.ignore_classes = ignore_classes self.verbose = False - def _debug(self, message): - if not self.verbose: - return - print(message, file=sys.stderr) - - def _add_include(self, include): - self.includes[include] = 0 - - def _add_class(self, instantiated_class): + # Map the data type to its Matlab class. + # Found in Argument.cpp in old wrapper + self.data_type = { + 'string': 'char', + 'char': 'char', + 'unsigned char': 'unsigned char', + 'Vector': 'double', + 'Matrix': 'double', + 'int': 'numeric', + 'size_t': 'numeric', + 'bool': 'logical' + } + # Map the data type into the type used in Matlab methods. + # Found in matlab.h in old wrapper + self.data_type_param = { + 'string': 'char', + 'char': 'char', + 'unsigned char': 'unsigned char', + 'size_t': 'int', + 'int': 'int', + 'double': 'double', + 'Point2': 'double', + 'Point3': 'double', + 'Vector': 'double', + 'Matrix': 'double', + 'bool': 'bool' + } + # The amount of times the wrapper has created a call to geometry_wrapper + self.wrapper_id = 0 + # Map each wrapper id to its collector function namespace, class, type, and string format + self.wrapper_map: Dict = {} + # Set of all the includes in the namespace + self.includes: List[parser.Include] = [] + # Set of all classes in the namespace + self.classes: List[Union[parser.Class, + instantiator.InstantiatedClass]] = [] + self.classes_elems: Dict[Union[parser.Class, + instantiator.InstantiatedClass], + int] = {} + # Id for ordering global functions in the wrapper + self.global_function_id = 0 + # Files and their content + self.content: List[str] = [] + + # Ensure the template file is always picked up from the correct directory. + dir_path = osp.dirname(osp.realpath(__file__)) + with open(osp.join(dir_path, "matlab_wrapper.tpl")) as f: + self.wrapper_file_headers = f.read() + + def add_class(self, instantiated_class): + """Add `instantiated_class` to the list of classes.""" if self.classes_elems.get(instantiated_class) is None: self.classes_elems[instantiated_class] = 0 self.classes.append(instantiated_class) def _update_wrapper_id(self, collector_function=None, id_diff=0): - """Get and define wrapper ids. - + """ + Get and define wrapper ids. Generates the map of id -> collector function. Args: @@ -150,34 +139,6 @@ def _insert_spaces(self, x, y): """ return x + '\n' + ('' if y == '' else ' ') + y - def _is_shared_ptr(self, arg_type): - """ - Determine if the `interface_parser.Type` should be treated as a - shared pointer in the wrapper. - """ - return arg_type.is_shared_ptr or ( - arg_type.typename.name not in self.not_ptr_type - and arg_type.typename.name not in self.ignore_namespace - and arg_type.typename.name != 'string') - - def _is_ptr(self, arg_type): - """ - Determine if the `interface_parser.Type` should be treated as a - raw pointer in the wrapper. - """ - return arg_type.is_ptr or ( - arg_type.typename.name not in self.not_ptr_type - and arg_type.typename.name not in self.ignore_namespace - and arg_type.typename.name != 'string') - - def _is_ref(self, arg_type): - """Determine if the interface_parser.Type should be treated as a - reference in the wrapper. - """ - return arg_type.typename.name not in self.ignore_namespace and \ - arg_type.typename.name not in self.not_ptr_type and \ - arg_type.is_ref - def _group_methods(self, methods): """Group overloaded methods together""" method_map = {} @@ -190,181 +151,10 @@ def _group_methods(self, methods): method_map[method.name] = len(method_out) method_out.append([method]) else: - self._debug("[_group_methods] Merging {} with {}".format( - method_index, method.name)) method_out[method_index].append(method) return method_out - def _clean_class_name(self, instantiated_class): - """Reformatted the C++ class name to fit Matlab defined naming - standards - """ - if len(instantiated_class.ctors) != 0: - return instantiated_class.ctors[0].name - - return instantiated_class.name - - @classmethod - def _format_type_name(cls, - type_name, - separator='::', - include_namespace=True, - constructor=False, - method=False): - """ - Args: - type_name: an interface_parser.Typename to reformat - separator: the statement to add between namespaces and typename - include_namespace: whether to include namespaces when reformatting - constructor: if the typename will be in a constructor - method: if the typename will be in a method - - Raises: - constructor and method cannot both be true - """ - if constructor and method: - raise Exception( - 'Constructor and method parameters cannot both be True') - - formatted_type_name = '' - name = type_name.name - - if include_namespace: - for namespace in type_name.namespaces: - if name not in cls.ignore_namespace and namespace != '': - formatted_type_name += namespace + separator - - #self._debug("formatted_ns: {}, ns: {}".format(formatted_type_name, type_name.namespaces)) - if constructor: - formatted_type_name += cls.data_type.get(name) or name - elif method: - formatted_type_name += cls.data_type_param.get(name) or name - else: - formatted_type_name += name - - if separator == "::": # C++ - templates = [] - for idx in range(len(type_name.instantiations)): - template = '{}'.format( - cls._format_type_name(type_name.instantiations[idx], - include_namespace=include_namespace, - constructor=constructor, - method=method)) - templates.append(template) - - if len(templates) > 0: # If there are no templates - formatted_type_name += '<{}>'.format(','.join(templates)) - - else: - for idx in range(len(type_name.instantiations)): - formatted_type_name += '{}'.format( - cls._format_type_name(type_name.instantiations[idx], - separator=separator, - include_namespace=False, - constructor=constructor, - method=method)) - - return formatted_type_name - - @classmethod - def _format_return_type(cls, - return_type, - include_namespace=False, - separator="::"): - """Format return_type. - - Args: - return_type: an interface_parser.ReturnType to reformat - include_namespace: whether to include namespaces when reformatting - """ - return_wrap = '' - - if cls._return_count(return_type) == 1: - return_wrap = cls._format_type_name( - return_type.type1.typename, - separator=separator, - include_namespace=include_namespace) - else: - return_wrap = 'pair< {type1}, {type2} >'.format( - type1=cls._format_type_name( - return_type.type1.typename, - separator=separator, - include_namespace=include_namespace), - type2=cls._format_type_name( - return_type.type2.typename, - separator=separator, - include_namespace=include_namespace)) - - return return_wrap - - def _format_class_name(self, instantiated_class, separator=''): - """Format a template_instantiator.InstantiatedClass name.""" - if instantiated_class.parent == '': - parent_full_ns = [''] - else: - parent_full_ns = instantiated_class.parent.full_namespaces() - # class_name = instantiated_class.parent.name - # - # if class_name != '': - # class_name += separator - # - # class_name += instantiated_class.name - parentname = "".join([separator + x - for x in parent_full_ns]) + separator - - class_name = parentname[2 * len(separator):] - - class_name += instantiated_class.name - - return class_name - - def _format_static_method(self, static_method, separator=''): - """Example: - - gtsamPoint3.staticFunction - """ - method = '' - - if isinstance(static_method, parser.StaticMethod): - method += "".join([separator + x for x in static_method.parent.namespaces()]) + \ - separator + static_method.parent.name + separator - - return method[2 * len(separator):] - - def _format_instance_method(self, instance_method, separator=''): - """Example: - - gtsamPoint3.staticFunction - """ - method = '' - - if isinstance(instance_method, instantiator.InstantiatedMethod): - method_list = [ - separator + x - for x in instance_method.parent.parent.full_namespaces() - ] - method += "".join(method_list) + separator - - method += instance_method.parent.name + separator - method += instance_method.original.name - method += "<" + instance_method.instantiations.to_cpp() + ">" - - return method[2 * len(separator):] - - def _format_global_method(self, static_method, separator=''): - """Example: - - gtsamPoint3.staticFunction - """ - method = '' - - if isinstance(static_method, parser.GlobalFunction): - method += "".join([separator + x for x in static_method.parent.full_namespaces()]) + \ - separator - - return method[2 * len(separator):] - def _wrap_args(self, args): """Wrap an interface_parser.ArgumentList into a list of arguments. @@ -520,7 +310,7 @@ def _wrapper_unwrap_arguments(self, args, arg_id=0, constructor=False): if params != '': params += ',' - if self._is_ref(arg.ctype): # and not constructor: + if self.is_ref(arg.ctype): # and not constructor: ctype_camel = self._format_type_name(arg.ctype.typename, separator='') body_args += textwrap.indent(textwrap.dedent('''\ @@ -531,7 +321,7 @@ def _wrapper_unwrap_arguments(self, args, arg_id=0, constructor=False): id=arg_id)), prefix=' ') - elif (self._is_shared_ptr(arg.ctype) or self._is_ptr(arg.ctype)) and \ + elif (self.is_shared_ptr(arg.ctype) or self.is_ptr(arg.ctype)) and \ arg.ctype.typename.name not in self.ignore_namespace: if arg.ctype.is_shared_ptr: call_type = arg.ctype.is_shared_ptr @@ -665,22 +455,13 @@ def class_comment(self, instantiated_class): return comment - def generate_matlab_wrapper(self): - """Generate the C++ file for the wrapper.""" - file_name = self._wrapper_name() + '.cpp' - - wrapper_file = self.wrapper_file_header - - return file_name, wrapper_file - def wrap_method(self, methods): - """Wrap methods in the body of a class.""" + """ + Wrap methods in the body of a class. + """ if not isinstance(methods, list): methods = [methods] - # for method in methods: - # output = '' - return '' def wrap_methods(self, methods, global_funcs=False, global_ns=None): @@ -697,10 +478,6 @@ def wrap_methods(self, methods, global_funcs=False, global_ns=None): continue if global_funcs: - self._debug("[wrap_methods] wrapping: {}..{}={}".format( - method[0].parent.name, method[0].name, - type(method[0].parent.name))) - method_text = self.wrap_global_function(method) self.content.append(("".join([ '+' + x + '/' for x in global_ns.full_namespaces()[1:] @@ -838,11 +615,6 @@ def wrap_class_constructors(self, namespace_name, inst_class, parent_name, base_obj = '' - if has_parent: - self._debug("class: {} ns: {}".format( - parent_name, - self._format_class_name(inst_class.parent, separator="."))) - if has_parent: base_obj = ' obj = obj@{parent_name}(uint64(5139824614673773682), base_ptr);'.format( parent_name=parent_name) @@ -850,9 +622,6 @@ def wrap_class_constructors(self, namespace_name, inst_class, parent_name, if base_obj: base_obj = '\n' + base_obj - self._debug("class: {}, name: {}".format( - inst_class.name, self._format_class_name(inst_class, - separator="."))) methods_wrap += textwrap.indent(textwrap.dedent('''\ else error('Arguments do not match any overload of {class_name_doc} constructor'); @@ -1101,27 +870,12 @@ def wrap_static_methods(self, namespace_name, instantiated_class, prefix=" ") if serialize: - method_text += textwrap.indent(textwrap.dedent("""\ - function varargout = string_deserialize(varargin) - % STRING_DESERIALIZE usage: string_deserialize() : returns {class_name} - % Doxygen can be found at https://gtsam.org/doxygen/ - if length(varargin) == 1 - varargout{{1}} = {wrapper}({id}, varargin{{:}}); - else - error('Arguments do not match any overload of function {class_name}.string_deserialize'); - end - end\n - function obj = loadobj(sobj) - % LOADOBJ Saves the object to a matlab-readable format - obj = {class_name}.string_deserialize(sobj); - end - """).format( + method_text += WrapperTemplate.matlab_deserialize.format( class_name=namespace_name + '.' + instantiated_class.name, wrapper=self._wrapper_name(), id=self._update_wrapper_id( (namespace_name, instantiated_class, 'string_deserialize', - 'deserialize'))), - prefix=' ') + 'deserialize'))) return method_text @@ -1213,33 +967,32 @@ def wrap_instantiated_class(self, instantiated_class, namespace_name=''): return file_name + '.m', content_text - def wrap_namespace(self, namespace, parent=()): + def wrap_namespace(self, namespace): """Wrap a namespace by wrapping all of its components. Args: namespace: the interface_parser.namespace instance of the namespace parent: parent namespace """ - test_output = '' namespaces = namespace.full_namespaces() inner_namespace = namespace.name != '' wrapped = [] - self._debug("wrapping ns: {}, parent: {}".format( - namespace.full_namespaces(), parent)) - matlab_wrapper = self.generate_matlab_wrapper() - self.content.append((matlab_wrapper[0], matlab_wrapper[1])) + cpp_filename = self._wrapper_name() + '.cpp' + self.content.append((cpp_filename, self.wrapper_file_headers)) current_scope = [] namespace_scope = [] for element in namespace.content: if isinstance(element, parser.Include): - self._add_include(element) + self.includes.append(element) + elif isinstance(element, parser.Namespace): - self.wrap_namespace(element, namespaces) + self.wrap_namespace(element) + elif isinstance(element, instantiator.InstantiatedClass): - self._add_class(element) + self.add_class(element) if inner_namespace: class_text = self.wrap_instantiated_class( @@ -1265,7 +1018,7 @@ def wrap_namespace(self, namespace, parent=()): if isinstance(func, parser.GlobalFunction) ] - test_output += self.wrap_methods(all_funcs, True, global_ns=namespace) + self.wrap_methods(all_funcs, True, global_ns=namespace) return wrapped @@ -1277,16 +1030,12 @@ def wrap_collector_function_shared_return(self, """Wrap the collector function which returns a shared pointer.""" new_line = '\n' if new_line else '' - return textwrap.indent(textwrap.dedent('''\ - {{ - boost::shared_ptr<{name}> shared({shared_obj}); - out[{id}] = wrap_shared_ptr(shared,"{name}"); - }}{new_line}''').format(name=self._format_type_name( - return_type_name, include_namespace=False), - shared_obj=shared_obj, - id=func_id, - new_line=new_line), - prefix=' ') + return WrapperTemplate.collector_function_shared_return.format( + name=self._format_type_name(return_type_name, + include_namespace=False), + shared_obj=shared_obj, + id=func_id, + new_line=new_line) def wrap_collector_function_return_types(self, return_type, func_id): """ @@ -1296,7 +1045,7 @@ def wrap_collector_function_return_types(self, return_type, func_id): pair_value = 'first' if func_id == 0 else 'second' new_line = '\n' if func_id == 0 else '' - if self._is_shared_ptr(return_type) or self._is_ptr(return_type): + if self.is_shared_ptr(return_type) or self.is_ptr(return_type): shared_obj = 'pairResult.' + pair_value if not (return_type.is_shared_ptr or return_type.is_ptr): @@ -1355,16 +1104,12 @@ def wrap_collector_function_return(self, method): method_name = self._format_static_method(method, '::') method_name += method.name - if "MeasureRange" in method_name: - self._debug("method: {}, method: {}, inst: {}".format( - method_name, method.name, method.parent.to_cpp())) - obj = ' ' if return_1_name == 'void' else '' obj += '{}{}({})'.format(obj_start, method_name, params) if return_1_name != 'void': if return_count == 1: - if self._is_shared_ptr(return_1) or self._is_ptr(return_1): + if self.is_shared_ptr(return_1) or self.is_ptr(return_1): sep_method_name = partial(self._format_type_name, return_1.typename, include_namespace=True) @@ -1377,12 +1122,6 @@ def wrap_collector_function_return(self, method): shared_obj = '{obj},"{method_name_sep}"'.format( obj=obj, method_name_sep=sep_method_name('.')) else: - self._debug("Non-PTR: {}, {}".format( - return_1, type(return_1))) - self._debug("Inner type is: {}, {}".format( - return_1.typename.name, sep_method_name('.'))) - self._debug("Inner type instantiations: {}".format( - return_1.typename.instantiations)) method_name_sep_dot = sep_method_name('.') shared_obj_template = 'boost::make_shared<{method_name_sep_col}>({obj}),' \ '"{method_name_sep_dot}"' @@ -1417,16 +1156,8 @@ def wrap_collector_function_upcast_from_void(self, class_name, func_id, """ Add function to upcast type from void type. """ - return textwrap.dedent('''\ - void {class_name}_upcastFromVoid_{id}(int nargout, mxArray *out[], int nargin, const mxArray *in[]) {{ - mexAtExit(&_deleteAllObjects); - typedef boost::shared_ptr<{cpp_name}> Shared; - boost::shared_ptr *asVoid = *reinterpret_cast**> (mxGetData(in[0])); - out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); - Shared *self = new Shared(boost::static_pointer_cast<{cpp_name}>(*asVoid)); - *reinterpret_cast(mxGetData(out[0])) = self; - }}\n - ''').format(class_name=class_name, cpp_name=cpp_name, id=func_id) + return WrapperTemplate.collector_function_upcast_from_void.format( + class_name=class_name, cpp_name=cpp_name, id=func_id) def generate_collector_function(self, func_id): """ @@ -1610,158 +1341,109 @@ def mex_function(self): else: next_case = None - mex_function = textwrap.dedent(''' - void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) - {{ - mstream mout; - std::streambuf *outbuf = std::cout.rdbuf(&mout);\n - _{module_name}_RTTIRegister();\n - int id = unwrap(in[0]);\n - try {{ - switch(id) {{ - {cases} }} - }} catch(const std::exception& e) {{ - mexErrMsgTxt(("Exception from gtsam:\\n" + std::string(e.what()) + "\\n").c_str()); - }}\n - std::cout.rdbuf(outbuf); - }} - ''').format(module_name=self.module_name, cases=cases) + mex_function = WrapperTemplate.mex_function.format( + module_name=self.module_name, cases=cases) return mex_function - def generate_wrapper(self, namespace): - """Generate the c++ wrapper.""" - # Includes - wrapper_file = self.wrapper_file_header + textwrap.dedent(""" - #include - #include - #include \n - """) - - assert namespace + def get_class_name(self, cls): + """Get the name of the class `cls` taking template instantiations into account.""" + if cls.instantiations: + class_name_sep = cls.name + else: + class_name_sep = cls.to_cpp() - includes_list = sorted(list(self.includes.keys()), - key=lambda include: include.header) + class_name = self._format_class_name(cls) - # Check the number of includes. - # If no includes, do nothing, if 1 then just append newline. - # if more than one, concatenate them with newlines. - if len(includes_list) == 0: - pass - elif len(includes_list) == 1: - wrapper_file += (str(includes_list[0]) + '\n') - else: - wrapper_file += reduce(lambda x, y: str(x) + '\n' + str(y), - includes_list) - wrapper_file += '\n' + return class_name, class_name_sep - typedef_instances = '\n' - typedef_collectors = '' + def generate_preamble(self): + """ + Generate the preamble of the wrapper file, which includes + the Boost exports, typedefs for collectors, and + the _deleteAllObjects and _RTTIRegister functions. + """ + delete_objs = '' + typedef_instances = [] boost_class_export_guid = '' - delete_objs = textwrap.dedent('''\ - void _deleteAllObjects() - { - mstream mout; - std::streambuf *outbuf = std::cout.rdbuf(&mout);\n - bool anyDeleted = false; - ''') - rtti_reg_start = textwrap.dedent('''\ - void _{module_name}_RTTIRegister() {{ - const mxArray *alreadyCreated = mexGetVariablePtr("global", "gtsam_{module_name}_rttiRegistry_created"); - if(!alreadyCreated) {{ - std::map types; - ''').format(module_name=self.module_name) - rtti_reg_mid = '' - rtti_reg_end = textwrap.indent( - textwrap.dedent(''' - mxArray *registry = mexGetVariable("global", "gtsamwrap_rttiRegistry"); - if(!registry) - registry = mxCreateStructMatrix(1, 1, 0, NULL); - typedef std::pair StringPair; - for(const StringPair& rtti_matlab: types) { - int fieldId = mxAddField(registry, rtti_matlab.first.c_str()); - if(fieldId < 0) - mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); - mxArray *matlabName = mxCreateString(rtti_matlab.second.c_str()); - mxSetFieldByNumber(registry, 0, fieldId, matlabName); - } - if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) - mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); - mxDestroyArray(registry); - '''), - prefix=' ') + ' \n' + textwrap.dedent('''\ - mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) - mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); - mxDestroyArray(newAlreadyCreated); - } - } - ''') - ptr_ctor_frag = '' + typedef_collectors = '' + rtti_classes = '' for cls in self.classes: - uninstantiated_name = "::".join( - cls.namespaces()[1:]) + "::" + cls.name - self._debug("Cls: {} -> {}".format(cls.name, uninstantiated_name)) - + # Check if class is in ignore list. + # If so, then skip + uninstantiated_name = "::".join(cls.namespaces()[1:] + [cls.name]) if uninstantiated_name in self.ignore_classes: - self._debug("Ignoring: {} -> {}".format( - cls.name, uninstantiated_name)) continue - def _has_serialization(cls): - for m in cls.methods: - if m.name in self.whitelist: - return True - return False + class_name, class_name_sep = self.get_class_name(cls) + # If a class has instantiations, then declare the typedef for each instance if cls.instantiations: cls_insts = '' - for i, inst in enumerate(cls.instantiations): if i != 0: cls_insts += ', ' cls_insts += self._format_type_name(inst) - typedef_instances += 'typedef {original_class_name} {class_name_sep};\n' \ + typedef_instances.append('typedef {original_class_name} {class_name_sep};' \ .format(original_class_name=cls.to_cpp(), - class_name_sep=cls.name) + class_name_sep=cls.name)) - class_name_sep = cls.name - class_name = self._format_class_name(cls) + # Get the Boost exports for serialization + if cls.original.namespaces() and self._has_serialization(cls): + boost_class_export_guid += 'BOOST_CLASS_EXPORT_GUID({}, "{}");\n'.format( + class_name_sep, class_name) - if len(cls.original.namespaces()) > 1 and _has_serialization( - cls): - boost_class_export_guid += 'BOOST_CLASS_EXPORT_GUID({}, "{}");\n'.format( - class_name_sep, class_name) - else: - class_name_sep = cls.to_cpp() - class_name = self._format_class_name(cls) - - if len(cls.original.namespaces()) > 1 and _has_serialization( - cls): - boost_class_export_guid += 'BOOST_CLASS_EXPORT_GUID({}, "{}");\n'.format( - class_name_sep, class_name) - - typedef_collectors += textwrap.dedent('''\ - typedef std::set*> Collector_{class_name}; - static Collector_{class_name} collector_{class_name}; - ''').format(class_name_sep=class_name_sep, class_name=class_name) - delete_objs += textwrap.indent(textwrap.dedent('''\ - {{ for(Collector_{class_name}::iterator iter = collector_{class_name}.begin(); - iter != collector_{class_name}.end(); ) {{ - delete *iter; - collector_{class_name}.erase(iter++); - anyDeleted = true; - }} }} - ''').format(class_name=class_name), - prefix=' ') + # Typedef and declare the collector objects. + typedef_collectors += WrapperTemplate.typdef_collectors.format( + class_name_sep=class_name_sep, class_name=class_name) + + # Generate the _deleteAllObjects method + delete_objs += WrapperTemplate.delete_obj.format( + class_name=class_name) if cls.is_virtual: - rtti_reg_mid += ' types.insert(std::make_pair(typeid({}).name(), "{}"));\n' \ + class_name, class_name_sep = self.get_class_name(cls) + rtti_classes += ' types.insert(std::make_pair(typeid({}).name(), "{}"));\n' \ .format(class_name_sep, class_name) + # Generate the typedef instances string + typedef_instances = "\n".join(typedef_instances) + + # Generate the full deleteAllObjects function + delete_all_objs = WrapperTemplate.delete_all_objects.format( + delete_objs=delete_objs) + + # Generate the full RTTIRegister function + rtti_register = WrapperTemplate.rtti_register.format( + module_name=self.module_name, rtti_classes=rtti_classes) + + return typedef_instances, boost_class_export_guid, \ + typedef_collectors, delete_all_objs, rtti_register + + def generate_wrapper(self, namespace): + """Generate the c++ wrapper.""" + assert namespace, "Namespace if empty" + + # Generate the header includes + includes_list = sorted(self.includes, + key=lambda include: include.header) + includes = textwrap.dedent("""\ + {wrapper_file_headers} + {boost_headers} + {includes_list} + """).format(wrapper_file_headers=self.wrapper_file_headers.strip(), + boost_headers=WrapperTemplate.boost_headers, + includes_list='\n'.join(map(str, includes_list))) + + preamble = self.generate_preamble() + typedef_instances, boost_class_export_guid, \ + typedef_collectors, delete_all_objs, \ + rtti_register = preamble + + ptr_ctor_frag = '' set_next_case = False for idx in range(self.wrapper_id): @@ -1784,24 +1466,20 @@ def _has_serialization(cls): ptr_ctor_frag += self.wrap_collector_function_upcast_from_void( id_val[1].name, idx, id_val[1].to_cpp()) - wrapper_file += textwrap.dedent('''\ + wrapper_file = textwrap.dedent('''\ + {includes} {typedef_instances} {boost_class_export_guid} {typedefs_collectors} - {delete_objs} if(anyDeleted) - cout << - "WARNING: Wrap modules with variables in the workspace have been reloaded due to\\n" - "calling destructors, call \'clear all\' again if you plan to now recompile a wrap\\n" - "module, so that your recompiled module is used instead of the old one." << endl; - std::cout.rdbuf(outbuf); - }}\n + {delete_all_objs} {rtti_register} {pointer_constructor_fragment}{mex_function}''') \ - .format(typedef_instances=typedef_instances, + .format(includes=includes, + typedef_instances=typedef_instances, boost_class_export_guid=boost_class_export_guid, typedefs_collectors=typedef_collectors, - delete_objs=delete_objs, - rtti_register=rtti_reg_start + rtti_reg_mid + rtti_reg_end, + delete_all_objs=delete_all_objs, + rtti_register=rtti_register, pointer_constructor_fragment=ptr_ctor_frag, mex_function=self.mex_function()) @@ -1815,23 +1493,10 @@ def wrap_class_serialize_method(self, namespace_name, inst_class): wrapper_id = self._update_wrapper_id( (namespace_name, inst_class, 'string_serialize', 'serialize')) - return textwrap.dedent('''\ - function varargout = string_serialize(this, varargin) - % STRING_SERIALIZE usage: string_serialize() : returns string - % Doxygen can be found at https://gtsam.org/doxygen/ - if length(varargin) == 0 - varargout{{1}} = {wrapper}({wrapper_id}, this, varargin{{:}}); - else - error('Arguments do not match any overload of function {class_name}.string_serialize'); - end - end\n - function sobj = saveobj(obj) - % SAVEOBJ Saves the object to a matlab-readable format - sobj = obj.string_serialize(); - end - ''').format(wrapper=self._wrapper_name(), - wrapper_id=wrapper_id, - class_name=namespace_name + '.' + class_name) + return WrapperTemplate.class_serialize_method.format( + wrapper=self._wrapper_name(), + wrapper_id=wrapper_id, + class_name=namespace_name + '.' + class_name) def wrap_collector_function_serialize(self, class_name, @@ -1840,18 +1505,8 @@ def wrap_collector_function_serialize(self, """ Wrap the serizalize collector function. """ - return textwrap.indent(textwrap.dedent("""\ - typedef boost::shared_ptr<{full_name}> Shared; - checkArguments("string_serialize",nargout,nargin-1,0); - Shared obj = unwrap_shared_ptr<{full_name}>(in[0], "ptr_{namespace}{class_name}"); - ostringstream out_archive_stream; - boost::archive::text_oarchive out_archive(out_archive_stream); - out_archive << *obj; - out[0] = wrap< string >(out_archive_stream.str()); - """).format(class_name=class_name, - full_name=full_name, - namespace=namespace), - prefix=' ') + return WrapperTemplate.collector_function_serialize.format( + class_name=class_name, full_name=full_name, namespace=namespace) def wrap_collector_function_deserialize(self, class_name, @@ -1860,87 +1515,85 @@ def wrap_collector_function_deserialize(self, """ Wrap the deserizalize collector function. """ - return textwrap.indent(textwrap.dedent("""\ - typedef boost::shared_ptr<{full_name}> Shared; - checkArguments("{namespace}{class_name}.string_deserialize",nargout,nargin,1); - string serialized = unwrap< string >(in[0]); - istringstream in_archive_stream(serialized); - boost::archive::text_iarchive in_archive(in_archive_stream); - Shared output(new {full_name}()); - in_archive >> *output; - out[0] = wrap_shared_ptr(output,"{namespace}.{class_name}", false); - """).format(class_name=class_name, - full_name=full_name, - namespace=namespace), - prefix=' ') + return WrapperTemplate.collector_function_deserialize.format( + class_name=class_name, full_name=full_name, namespace=namespace) - def wrap(self, content): - """High level function to wrap the project.""" - # Parse the contents of the interface file - parsed_result = parser.Module.parseString(content) - # Instantiate the module - module = instantiator.instantiate_namespace(parsed_result) - self.wrap_namespace(module) - self.generate_wrapper(module) + def generate_content(self, cc_content, path): + """ + Generate files and folders from matlab wrapper content. - return self.content + Args: + cc_content: The content to generate formatted as + (file_name, file_content) or + (folder_name, [(file_name, file_content)]) + path: The path to the files parent folder within the main folder + """ + for c in cc_content: + if isinstance(c, list): + if len(c) == 0: + continue + path_to_folder = osp.join(path, c[0][0]) + + if not osp.isdir(path_to_folder): + try: + os.makedirs(path_to_folder, exist_ok=True) + except OSError: + pass + + for sub_content in c: + self.generate_content(sub_content[1], path_to_folder) + + elif isinstance(c[1], list): + path_to_folder = osp.join(path, c[0]) + + if not osp.isdir(path_to_folder): + try: + os.makedirs(path_to_folder, exist_ok=True) + except OSError: + pass + for sub_content in c[1]: + path_to_file = osp.join(path_to_folder, sub_content[0]) + with open(path_to_file, 'w') as f: + f.write(sub_content[1]) + else: + path_to_file = osp.join(path, c[0]) -def generate_content(cc_content, path, verbose=False): - """ - Generate files and folders from matlab wrapper content. + if not osp.isdir(path_to_file): + try: + os.mkdir(path) + except OSError: + pass - Args: - cc_content: The content to generate formatted as - (file_name, file_content) or - (folder_name, [(file_name, file_content)]) - path: The path to the files parent folder within the main folder - """ - def _debug(message): - if not verbose: - return - print(message, file=sys.stderr) - - for c in cc_content: - if isinstance(c, list): - if len(c) == 0: - continue - _debug("c object: {}".format(c[0][0])) - path_to_folder = osp.join(path, c[0][0]) - - if not os.path.isdir(path_to_folder): - try: - os.makedirs(path_to_folder, exist_ok=True) - except OSError: - pass - - for sub_content in c: - _debug("sub object: {}".format(sub_content[1][0][0])) - generate_content(sub_content[1], path_to_folder) - - elif isinstance(c[1], list): - path_to_folder = osp.join(path, c[0]) - - _debug("[generate_content_global]: {}".format(path_to_folder)) - if not os.path.isdir(path_to_folder): - try: - os.makedirs(path_to_folder, exist_ok=True) - except OSError: - pass - for sub_content in c[1]: - path_to_file = osp.join(path_to_folder, sub_content[0]) - _debug("[generate_global_method]: {}".format(path_to_file)) with open(path_to_file, 'w') as f: - f.write(sub_content[1]) - else: - path_to_file = osp.join(path, c[0]) + f.write(c[1]) - _debug("[generate_content]: {}".format(path_to_file)) - if not os.path.isdir(path_to_file): - try: - os.mkdir(path) - except OSError: - pass + def wrap(self, files, path): + """High level function to wrap the project.""" + modules = {} + for file in files: + with open(file, 'r') as f: + content = f.read() - with open(path_to_file, 'w') as f: - f.write(c[1]) + # Parse the contents of the interface file + parsed_result = parser.Module.parseString(content) + # print(parsed_result) + + # Instantiate the module + module = instantiator.instantiate_namespace(parsed_result) + + if module.name in modules: + modules[module. + name].content[0].content += module.content[0].content + else: + modules[module.name] = module + + for module in modules.values(): + # Wrap the full namespace + self.wrap_namespace(module) + self.generate_wrapper(module) + + # Generate the corresponding .m and .cpp files + self.generate_content(self.content, path) + + return self.content diff --git a/wrap/gtwrap/template_instantiator.py b/wrap/gtwrap/template_instantiator.py index c474246489..87729cfa6d 100644 --- a/wrap/gtwrap/template_instantiator.py +++ b/wrap/gtwrap/template_instantiator.py @@ -4,7 +4,7 @@ import itertools from copy import deepcopy -from typing import List +from typing import Iterable, List import gtwrap.interface_parser as parser @@ -29,12 +29,13 @@ def instantiate_type(ctype: parser.Type, ctype = deepcopy(ctype) # Check if the return type has template parameters - if len(ctype.typename.instantiations) > 0: + if ctype.typename.instantiations: for idx, instantiation in enumerate(ctype.typename.instantiations): if instantiation.name in template_typenames: template_idx = template_typenames.index(instantiation.name) - ctype.typename.instantiations[idx] = instantiations[ - template_idx] + ctype.typename.instantiations[ + idx] = instantiations[ # type: ignore + template_idx] return ctype @@ -212,7 +213,9 @@ class A { void func(X x, Y y); } """ - def __init__(self, original, instantiations: List[parser.Typename] = ''): + def __init__(self, + original, + instantiations: Iterable[parser.Typename] = ()): self.original = original self.instantiations = instantiations self.template = '' @@ -278,7 +281,7 @@ def __init__(self, original: parser.Class, instantiations=(), new_name=''): self.original = original self.instantiations = instantiations - self.template = '' + self.template = None self.is_virtual = original.is_virtual self.parent_class = original.parent_class self.parent = original.parent @@ -318,7 +321,7 @@ def __init__(self, original: parser.Class, instantiations=(), new_name=''): self.methods = [] for method in instantiated_methods: if not method.template: - self.methods.append(InstantiatedMethod(method, '')) + self.methods.append(InstantiatedMethod(method, ())) else: instantiations = [] # Get all combinations of template parameters @@ -342,9 +345,9 @@ def __init__(self, original: parser.Class, instantiations=(), new_name=''): ) def __repr__(self): - return "{virtual} class {name} [{cpp_class}]: {parent_class}\n"\ - "{ctors}\n{static_methods}\n{methods}".format( - virtual="virtual" if self.is_virtual else '', + return "{virtual}Class {cpp_class} : {parent_class}\n"\ + "{ctors}\n{static_methods}\n{methods}\n{operators}".format( + virtual="virtual " if self.is_virtual else '', name=self.name, cpp_class=self.to_cpp(), parent_class=self.parent, diff --git a/wrap/scripts/matlab_wrap.py b/wrap/scripts/matlab_wrap.py index be6043947f..0f6664a635 100644 --- a/wrap/scripts/matlab_wrap.py +++ b/wrap/scripts/matlab_wrap.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 - """ Helper script to wrap C++ to Matlab. This script is installed via CMake to the user's binary directory @@ -7,19 +6,24 @@ """ import argparse -import os import sys -from gtwrap.matlab_wrapper import MatlabWrapper, generate_content +from gtwrap.matlab_wrapper import MatlabWrapper if __name__ == "__main__": arg_parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter) - arg_parser.add_argument("--src", type=str, required=True, + arg_parser.add_argument("--src", + type=str, + required=True, help="Input interface .h file.") - arg_parser.add_argument("--module_name", type=str, required=True, + arg_parser.add_argument("--module_name", + type=str, + required=True, help="Name of the C++ class being wrapped.") - arg_parser.add_argument("--out", type=str, required=True, + arg_parser.add_argument("--out", + type=str, + required=True, help="Name of the output folder.") arg_parser.add_argument( "--top_module_namespaces", @@ -33,28 +37,22 @@ "`.Class` of the corresponding C++ `ns1::ns2::ns3::Class`" ", and `from import ns4` gives you access to a Python " "`ns4.Class` of the C++ `ns1::ns2::ns3::ns4::Class`. ") - arg_parser.add_argument("--ignore", - nargs='*', - type=str, - help="A space-separated list of classes to ignore. " - "Class names must include their full namespaces.") + arg_parser.add_argument( + "--ignore", + nargs='*', + type=str, + help="A space-separated list of classes to ignore. " + "Class names must include their full namespaces.") args = arg_parser.parse_args() top_module_namespaces = args.top_module_namespaces.split("::") if top_module_namespaces[0]: top_module_namespaces = [''] + top_module_namespaces - with open(args.src, 'r') as f: - content = f.read() - - if not os.path.exists(args.src): - os.mkdir(args.src) - - print("Ignoring classes: {}".format(args.ignore), file=sys.stderr) + print("[MatlabWrapper] Ignoring classes: {}".format(args.ignore), file=sys.stderr) wrapper = MatlabWrapper(module_name=args.module_name, top_module_namespace=top_module_namespaces, ignore_classes=args.ignore) - cc_content = wrapper.wrap(content) - - generate_content(cc_content, args.out) + sources = args.src.split(';') + cc_content = wrapper.wrap(sources, path=args.out) diff --git a/wrap/tests/expected/matlab/+gtsam/Class1.m b/wrap/tests/expected/matlab/+gtsam/Class1.m new file mode 100644 index 0000000000..00dd5ca746 --- /dev/null +++ b/wrap/tests/expected/matlab/+gtsam/Class1.m @@ -0,0 +1,36 @@ +%class Class1, see Doxygen page for details +%at https://gtsam.org/doxygen/ +% +%-------Constructors------- +%Class1() +% +classdef Class1 < handle + properties + ptr_gtsamClass1 = 0 + end + methods + function obj = Class1(varargin) + if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) + my_ptr = varargin{2}; + multiple_files_wrapper(0, my_ptr); + elseif nargin == 0 + my_ptr = multiple_files_wrapper(1); + else + error('Arguments do not match any overload of gtsam.Class1 constructor'); + end + obj.ptr_gtsamClass1 = my_ptr; + end + + function delete(obj) + multiple_files_wrapper(2, obj.ptr_gtsamClass1); + end + + function display(obj), obj.print(''); end + %DISPLAY Calls print on the object + function disp(obj), obj.display; end + %DISP Calls print on the object + end + + methods(Static = true) + end +end diff --git a/wrap/tests/expected/matlab/+gtsam/Class2.m b/wrap/tests/expected/matlab/+gtsam/Class2.m new file mode 100644 index 0000000000..93279e1560 --- /dev/null +++ b/wrap/tests/expected/matlab/+gtsam/Class2.m @@ -0,0 +1,36 @@ +%class Class2, see Doxygen page for details +%at https://gtsam.org/doxygen/ +% +%-------Constructors------- +%Class2() +% +classdef Class2 < handle + properties + ptr_gtsamClass2 = 0 + end + methods + function obj = Class2(varargin) + if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) + my_ptr = varargin{2}; + multiple_files_wrapper(3, my_ptr); + elseif nargin == 0 + my_ptr = multiple_files_wrapper(4); + else + error('Arguments do not match any overload of gtsam.Class2 constructor'); + end + obj.ptr_gtsamClass2 = my_ptr; + end + + function delete(obj) + multiple_files_wrapper(5, obj.ptr_gtsamClass2); + end + + function display(obj), obj.print(''); end + %DISPLAY Calls print on the object + function disp(obj), obj.display; end + %DISP Calls print on the object + end + + methods(Static = true) + end +end diff --git a/wrap/tests/expected/matlab/+gtsam/ClassA.m b/wrap/tests/expected/matlab/+gtsam/ClassA.m new file mode 100644 index 0000000000..3210e93c60 --- /dev/null +++ b/wrap/tests/expected/matlab/+gtsam/ClassA.m @@ -0,0 +1,36 @@ +%class ClassA, see Doxygen page for details +%at https://gtsam.org/doxygen/ +% +%-------Constructors------- +%ClassA() +% +classdef ClassA < handle + properties + ptr_gtsamClassA = 0 + end + methods + function obj = ClassA(varargin) + if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) + my_ptr = varargin{2}; + multiple_files_wrapper(6, my_ptr); + elseif nargin == 0 + my_ptr = multiple_files_wrapper(7); + else + error('Arguments do not match any overload of gtsam.ClassA constructor'); + end + obj.ptr_gtsamClassA = my_ptr; + end + + function delete(obj) + multiple_files_wrapper(8, obj.ptr_gtsamClassA); + end + + function display(obj), obj.print(''); end + %DISPLAY Calls print on the object + function disp(obj), obj.display; end + %DISP Calls print on the object + end + + methods(Static = true) + end +end diff --git a/wrap/tests/expected/matlab/class_wrapper.cpp b/wrap/tests/expected/matlab/class_wrapper.cpp index e644ac00f2..fab9c14506 100644 --- a/wrap/tests/expected/matlab/class_wrapper.cpp +++ b/wrap/tests/expected/matlab/class_wrapper.cpp @@ -7,7 +7,6 @@ #include - typedef Fun FunDouble; typedef PrimitiveRef PrimitiveRefDouble; typedef MyVector<3> MyVector3; @@ -16,7 +15,6 @@ typedef MultipleTemplates MultipleTemplatesIntDouble; typedef MultipleTemplates MultipleTemplatesIntFloat; typedef MyFactor MyFactorPosePoint2; - typedef std::set*> Collector_FunRange; static Collector_FunRange collector_FunRange; typedef std::set*> Collector_FunDouble; @@ -38,6 +36,7 @@ static Collector_ForwardKinematics collector_ForwardKinematics; typedef std::set*> Collector_MyFactorPosePoint2; static Collector_MyFactorPosePoint2 collector_MyFactorPosePoint2; + void _deleteAllObjects() { mstream mout; @@ -104,6 +103,7 @@ void _deleteAllObjects() collector_MyFactorPosePoint2.erase(iter++); anyDeleted = true; } } + if(anyDeleted) cout << "WARNING: Wrap modules with variables in the workspace have been reloaded due to\n" @@ -117,24 +117,29 @@ void _class_RTTIRegister() { if(!alreadyCreated) { std::map types; + + mxArray *registry = mexGetVariable("global", "gtsamwrap_rttiRegistry"); if(!registry) registry = mxCreateStructMatrix(1, 1, 0, NULL); typedef std::pair StringPair; for(const StringPair& rtti_matlab: types) { int fieldId = mxAddField(registry, rtti_matlab.first.c_str()); - if(fieldId < 0) + if(fieldId < 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } mxArray *matlabName = mxCreateString(rtti_matlab.second.c_str()); mxSetFieldByNumber(registry, 0, fieldId, matlabName); } - if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) + if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } mxDestroyArray(registry); - + mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) + if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } mxDestroyArray(newAlreadyCreated); } } diff --git a/wrap/tests/expected/matlab/functions_wrapper.cpp b/wrap/tests/expected/matlab/functions_wrapper.cpp index ae7f49c410..d0f0f8ca67 100644 --- a/wrap/tests/expected/matlab/functions_wrapper.cpp +++ b/wrap/tests/expected/matlab/functions_wrapper.cpp @@ -5,38 +5,11 @@ #include #include -#include -typedef Fun FunDouble; -typedef PrimitiveRef PrimitiveRefDouble; -typedef MyVector<3> MyVector3; -typedef MyVector<12> MyVector12; -typedef MultipleTemplates MultipleTemplatesIntDouble; -typedef MultipleTemplates MultipleTemplatesIntFloat; -typedef MyFactor MyFactorPosePoint2; -typedef std::set*> Collector_FunRange; -static Collector_FunRange collector_FunRange; -typedef std::set*> Collector_FunDouble; -static Collector_FunDouble collector_FunDouble; -typedef std::set*> Collector_Test; -static Collector_Test collector_Test; -typedef std::set*> Collector_PrimitiveRefDouble; -static Collector_PrimitiveRefDouble collector_PrimitiveRefDouble; -typedef std::set*> Collector_MyVector3; -static Collector_MyVector3 collector_MyVector3; -typedef std::set*> Collector_MyVector12; -static Collector_MyVector12 collector_MyVector12; -typedef std::set*> Collector_MultipleTemplatesIntDouble; -static Collector_MultipleTemplatesIntDouble collector_MultipleTemplatesIntDouble; -typedef std::set*> Collector_MultipleTemplatesIntFloat; -static Collector_MultipleTemplatesIntFloat collector_MultipleTemplatesIntFloat; -typedef std::set*> Collector_ForwardKinematics; -static Collector_ForwardKinematics collector_ForwardKinematics; -typedef std::set*> Collector_MyFactorPosePoint2; -static Collector_MyFactorPosePoint2 collector_MyFactorPosePoint2; + void _deleteAllObjects() { @@ -44,66 +17,7 @@ void _deleteAllObjects() std::streambuf *outbuf = std::cout.rdbuf(&mout); bool anyDeleted = false; - { for(Collector_FunRange::iterator iter = collector_FunRange.begin(); - iter != collector_FunRange.end(); ) { - delete *iter; - collector_FunRange.erase(iter++); - anyDeleted = true; - } } - { for(Collector_FunDouble::iterator iter = collector_FunDouble.begin(); - iter != collector_FunDouble.end(); ) { - delete *iter; - collector_FunDouble.erase(iter++); - anyDeleted = true; - } } - { for(Collector_Test::iterator iter = collector_Test.begin(); - iter != collector_Test.end(); ) { - delete *iter; - collector_Test.erase(iter++); - anyDeleted = true; - } } - { for(Collector_PrimitiveRefDouble::iterator iter = collector_PrimitiveRefDouble.begin(); - iter != collector_PrimitiveRefDouble.end(); ) { - delete *iter; - collector_PrimitiveRefDouble.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyVector3::iterator iter = collector_MyVector3.begin(); - iter != collector_MyVector3.end(); ) { - delete *iter; - collector_MyVector3.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyVector12::iterator iter = collector_MyVector12.begin(); - iter != collector_MyVector12.end(); ) { - delete *iter; - collector_MyVector12.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MultipleTemplatesIntDouble::iterator iter = collector_MultipleTemplatesIntDouble.begin(); - iter != collector_MultipleTemplatesIntDouble.end(); ) { - delete *iter; - collector_MultipleTemplatesIntDouble.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MultipleTemplatesIntFloat::iterator iter = collector_MultipleTemplatesIntFloat.begin(); - iter != collector_MultipleTemplatesIntFloat.end(); ) { - delete *iter; - collector_MultipleTemplatesIntFloat.erase(iter++); - anyDeleted = true; - } } - { for(Collector_ForwardKinematics::iterator iter = collector_ForwardKinematics.begin(); - iter != collector_ForwardKinematics.end(); ) { - delete *iter; - collector_ForwardKinematics.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyFactorPosePoint2::iterator iter = collector_MyFactorPosePoint2.begin(); - iter != collector_MyFactorPosePoint2.end(); ) { - delete *iter; - collector_MyFactorPosePoint2.erase(iter++); - anyDeleted = true; - } } + if(anyDeleted) cout << "WARNING: Wrap modules with variables in the workspace have been reloaded due to\n" @@ -117,24 +31,29 @@ void _functions_RTTIRegister() { if(!alreadyCreated) { std::map types; + + mxArray *registry = mexGetVariable("global", "gtsamwrap_rttiRegistry"); if(!registry) registry = mxCreateStructMatrix(1, 1, 0, NULL); typedef std::pair StringPair; for(const StringPair& rtti_matlab: types) { int fieldId = mxAddField(registry, rtti_matlab.first.c_str()); - if(fieldId < 0) + if(fieldId < 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } mxArray *matlabName = mxCreateString(rtti_matlab.second.c_str()); mxSetFieldByNumber(registry, 0, fieldId, matlabName); } - if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) + if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } mxDestroyArray(registry); - + mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) + if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } mxDestroyArray(newAlreadyCreated); } } diff --git a/wrap/tests/expected/matlab/geometry_wrapper.cpp b/wrap/tests/expected/matlab/geometry_wrapper.cpp index 4d8a7c7893..81631390c9 100644 --- a/wrap/tests/expected/matlab/geometry_wrapper.cpp +++ b/wrap/tests/expected/matlab/geometry_wrapper.cpp @@ -5,112 +5,25 @@ #include #include -#include #include #include -typedef Fun FunDouble; -typedef PrimitiveRef PrimitiveRefDouble; -typedef MyVector<3> MyVector3; -typedef MyVector<12> MyVector12; -typedef MultipleTemplates MultipleTemplatesIntDouble; -typedef MultipleTemplates MultipleTemplatesIntFloat; -typedef MyFactor MyFactorPosePoint2; BOOST_CLASS_EXPORT_GUID(gtsam::Point2, "gtsamPoint2"); BOOST_CLASS_EXPORT_GUID(gtsam::Point3, "gtsamPoint3"); -typedef std::set*> Collector_FunRange; -static Collector_FunRange collector_FunRange; -typedef std::set*> Collector_FunDouble; -static Collector_FunDouble collector_FunDouble; -typedef std::set*> Collector_Test; -static Collector_Test collector_Test; -typedef std::set*> Collector_PrimitiveRefDouble; -static Collector_PrimitiveRefDouble collector_PrimitiveRefDouble; -typedef std::set*> Collector_MyVector3; -static Collector_MyVector3 collector_MyVector3; -typedef std::set*> Collector_MyVector12; -static Collector_MyVector12 collector_MyVector12; -typedef std::set*> Collector_MultipleTemplatesIntDouble; -static Collector_MultipleTemplatesIntDouble collector_MultipleTemplatesIntDouble; -typedef std::set*> Collector_MultipleTemplatesIntFloat; -static Collector_MultipleTemplatesIntFloat collector_MultipleTemplatesIntFloat; -typedef std::set*> Collector_ForwardKinematics; -static Collector_ForwardKinematics collector_ForwardKinematics; -typedef std::set*> Collector_MyFactorPosePoint2; -static Collector_MyFactorPosePoint2 collector_MyFactorPosePoint2; typedef std::set*> Collector_gtsamPoint2; static Collector_gtsamPoint2 collector_gtsamPoint2; typedef std::set*> Collector_gtsamPoint3; static Collector_gtsamPoint3 collector_gtsamPoint3; + void _deleteAllObjects() { mstream mout; std::streambuf *outbuf = std::cout.rdbuf(&mout); bool anyDeleted = false; - { for(Collector_FunRange::iterator iter = collector_FunRange.begin(); - iter != collector_FunRange.end(); ) { - delete *iter; - collector_FunRange.erase(iter++); - anyDeleted = true; - } } - { for(Collector_FunDouble::iterator iter = collector_FunDouble.begin(); - iter != collector_FunDouble.end(); ) { - delete *iter; - collector_FunDouble.erase(iter++); - anyDeleted = true; - } } - { for(Collector_Test::iterator iter = collector_Test.begin(); - iter != collector_Test.end(); ) { - delete *iter; - collector_Test.erase(iter++); - anyDeleted = true; - } } - { for(Collector_PrimitiveRefDouble::iterator iter = collector_PrimitiveRefDouble.begin(); - iter != collector_PrimitiveRefDouble.end(); ) { - delete *iter; - collector_PrimitiveRefDouble.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyVector3::iterator iter = collector_MyVector3.begin(); - iter != collector_MyVector3.end(); ) { - delete *iter; - collector_MyVector3.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyVector12::iterator iter = collector_MyVector12.begin(); - iter != collector_MyVector12.end(); ) { - delete *iter; - collector_MyVector12.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MultipleTemplatesIntDouble::iterator iter = collector_MultipleTemplatesIntDouble.begin(); - iter != collector_MultipleTemplatesIntDouble.end(); ) { - delete *iter; - collector_MultipleTemplatesIntDouble.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MultipleTemplatesIntFloat::iterator iter = collector_MultipleTemplatesIntFloat.begin(); - iter != collector_MultipleTemplatesIntFloat.end(); ) { - delete *iter; - collector_MultipleTemplatesIntFloat.erase(iter++); - anyDeleted = true; - } } - { for(Collector_ForwardKinematics::iterator iter = collector_ForwardKinematics.begin(); - iter != collector_ForwardKinematics.end(); ) { - delete *iter; - collector_ForwardKinematics.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyFactorPosePoint2::iterator iter = collector_MyFactorPosePoint2.begin(); - iter != collector_MyFactorPosePoint2.end(); ) { - delete *iter; - collector_MyFactorPosePoint2.erase(iter++); - anyDeleted = true; - } } { for(Collector_gtsamPoint2::iterator iter = collector_gtsamPoint2.begin(); iter != collector_gtsamPoint2.end(); ) { delete *iter; @@ -123,6 +36,7 @@ void _deleteAllObjects() collector_gtsamPoint3.erase(iter++); anyDeleted = true; } } + if(anyDeleted) cout << "WARNING: Wrap modules with variables in the workspace have been reloaded due to\n" @@ -136,24 +50,29 @@ void _geometry_RTTIRegister() { if(!alreadyCreated) { std::map types; + + mxArray *registry = mexGetVariable("global", "gtsamwrap_rttiRegistry"); if(!registry) registry = mxCreateStructMatrix(1, 1, 0, NULL); typedef std::pair StringPair; for(const StringPair& rtti_matlab: types) { int fieldId = mxAddField(registry, rtti_matlab.first.c_str()); - if(fieldId < 0) + if(fieldId < 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } mxArray *matlabName = mxCreateString(rtti_matlab.second.c_str()); mxSetFieldByNumber(registry, 0, fieldId, matlabName); } - if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) + if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } mxDestroyArray(registry); - + mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) + if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } mxDestroyArray(newAlreadyCreated); } } diff --git a/wrap/tests/expected/matlab/inheritance_wrapper.cpp b/wrap/tests/expected/matlab/inheritance_wrapper.cpp index 077df48305..8e61ac8c61 100644 --- a/wrap/tests/expected/matlab/inheritance_wrapper.cpp +++ b/wrap/tests/expected/matlab/inheritance_wrapper.cpp @@ -5,47 +5,11 @@ #include #include -#include -#include -#include - -typedef Fun FunDouble; -typedef PrimitiveRef PrimitiveRefDouble; -typedef MyVector<3> MyVector3; -typedef MyVector<12> MyVector12; -typedef MultipleTemplates MultipleTemplatesIntDouble; -typedef MultipleTemplates MultipleTemplatesIntFloat; -typedef MyFactor MyFactorPosePoint2; + + typedef MyTemplate MyTemplatePoint2; typedef MyTemplate MyTemplateMatrix; -BOOST_CLASS_EXPORT_GUID(gtsam::Point2, "gtsamPoint2"); -BOOST_CLASS_EXPORT_GUID(gtsam::Point3, "gtsamPoint3"); - -typedef std::set*> Collector_FunRange; -static Collector_FunRange collector_FunRange; -typedef std::set*> Collector_FunDouble; -static Collector_FunDouble collector_FunDouble; -typedef std::set*> Collector_Test; -static Collector_Test collector_Test; -typedef std::set*> Collector_PrimitiveRefDouble; -static Collector_PrimitiveRefDouble collector_PrimitiveRefDouble; -typedef std::set*> Collector_MyVector3; -static Collector_MyVector3 collector_MyVector3; -typedef std::set*> Collector_MyVector12; -static Collector_MyVector12 collector_MyVector12; -typedef std::set*> Collector_MultipleTemplatesIntDouble; -static Collector_MultipleTemplatesIntDouble collector_MultipleTemplatesIntDouble; -typedef std::set*> Collector_MultipleTemplatesIntFloat; -static Collector_MultipleTemplatesIntFloat collector_MultipleTemplatesIntFloat; -typedef std::set*> Collector_ForwardKinematics; -static Collector_ForwardKinematics collector_ForwardKinematics; -typedef std::set*> Collector_MyFactorPosePoint2; -static Collector_MyFactorPosePoint2 collector_MyFactorPosePoint2; -typedef std::set*> Collector_gtsamPoint2; -static Collector_gtsamPoint2 collector_gtsamPoint2; -typedef std::set*> Collector_gtsamPoint3; -static Collector_gtsamPoint3 collector_gtsamPoint3; typedef std::set*> Collector_MyBase; static Collector_MyBase collector_MyBase; typedef std::set*> Collector_MyTemplatePoint2; @@ -55,84 +19,13 @@ static Collector_MyTemplateMatrix collector_MyTemplateMatrix; typedef std::set*> Collector_ForwardKinematicsFactor; static Collector_ForwardKinematicsFactor collector_ForwardKinematicsFactor; + void _deleteAllObjects() { mstream mout; std::streambuf *outbuf = std::cout.rdbuf(&mout); bool anyDeleted = false; - { for(Collector_FunRange::iterator iter = collector_FunRange.begin(); - iter != collector_FunRange.end(); ) { - delete *iter; - collector_FunRange.erase(iter++); - anyDeleted = true; - } } - { for(Collector_FunDouble::iterator iter = collector_FunDouble.begin(); - iter != collector_FunDouble.end(); ) { - delete *iter; - collector_FunDouble.erase(iter++); - anyDeleted = true; - } } - { for(Collector_Test::iterator iter = collector_Test.begin(); - iter != collector_Test.end(); ) { - delete *iter; - collector_Test.erase(iter++); - anyDeleted = true; - } } - { for(Collector_PrimitiveRefDouble::iterator iter = collector_PrimitiveRefDouble.begin(); - iter != collector_PrimitiveRefDouble.end(); ) { - delete *iter; - collector_PrimitiveRefDouble.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyVector3::iterator iter = collector_MyVector3.begin(); - iter != collector_MyVector3.end(); ) { - delete *iter; - collector_MyVector3.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyVector12::iterator iter = collector_MyVector12.begin(); - iter != collector_MyVector12.end(); ) { - delete *iter; - collector_MyVector12.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MultipleTemplatesIntDouble::iterator iter = collector_MultipleTemplatesIntDouble.begin(); - iter != collector_MultipleTemplatesIntDouble.end(); ) { - delete *iter; - collector_MultipleTemplatesIntDouble.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MultipleTemplatesIntFloat::iterator iter = collector_MultipleTemplatesIntFloat.begin(); - iter != collector_MultipleTemplatesIntFloat.end(); ) { - delete *iter; - collector_MultipleTemplatesIntFloat.erase(iter++); - anyDeleted = true; - } } - { for(Collector_ForwardKinematics::iterator iter = collector_ForwardKinematics.begin(); - iter != collector_ForwardKinematics.end(); ) { - delete *iter; - collector_ForwardKinematics.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyFactorPosePoint2::iterator iter = collector_MyFactorPosePoint2.begin(); - iter != collector_MyFactorPosePoint2.end(); ) { - delete *iter; - collector_MyFactorPosePoint2.erase(iter++); - anyDeleted = true; - } } - { for(Collector_gtsamPoint2::iterator iter = collector_gtsamPoint2.begin(); - iter != collector_gtsamPoint2.end(); ) { - delete *iter; - collector_gtsamPoint2.erase(iter++); - anyDeleted = true; - } } - { for(Collector_gtsamPoint3::iterator iter = collector_gtsamPoint3.begin(); - iter != collector_gtsamPoint3.end(); ) { - delete *iter; - collector_gtsamPoint3.erase(iter++); - anyDeleted = true; - } } { for(Collector_MyBase::iterator iter = collector_MyBase.begin(); iter != collector_MyBase.end(); ) { delete *iter; @@ -157,6 +50,7 @@ void _deleteAllObjects() collector_ForwardKinematicsFactor.erase(iter++); anyDeleted = true; } } + if(anyDeleted) cout << "WARNING: Wrap modules with variables in the workspace have been reloaded due to\n" @@ -169,49 +63,54 @@ void _inheritance_RTTIRegister() { const mxArray *alreadyCreated = mexGetVariablePtr("global", "gtsam_inheritance_rttiRegistry_created"); if(!alreadyCreated) { std::map types; + types.insert(std::make_pair(typeid(MyBase).name(), "MyBase")); types.insert(std::make_pair(typeid(MyTemplatePoint2).name(), "MyTemplatePoint2")); types.insert(std::make_pair(typeid(MyTemplateMatrix).name(), "MyTemplateMatrix")); types.insert(std::make_pair(typeid(ForwardKinematicsFactor).name(), "ForwardKinematicsFactor")); + mxArray *registry = mexGetVariable("global", "gtsamwrap_rttiRegistry"); if(!registry) registry = mxCreateStructMatrix(1, 1, 0, NULL); typedef std::pair StringPair; for(const StringPair& rtti_matlab: types) { int fieldId = mxAddField(registry, rtti_matlab.first.c_str()); - if(fieldId < 0) + if(fieldId < 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } mxArray *matlabName = mxCreateString(rtti_matlab.second.c_str()); mxSetFieldByNumber(registry, 0, fieldId, matlabName); } - if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) + if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } mxDestroyArray(registry); - + mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) + if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } mxDestroyArray(newAlreadyCreated); } } -void gtsamPoint2_collectorInsertAndMakeBase_0(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyBase_collectorInsertAndMakeBase_0(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); - typedef boost::shared_ptr Shared; + typedef boost::shared_ptr Shared; Shared *self = *reinterpret_cast (mxGetData(in[0])); - collector_gtsamPoint2.insert(self); + collector_MyBase.insert(self); } -void MyBase_collectorInsertAndMakeBase_0(int nargout, mxArray *out[], int nargin, const mxArray *in[]) -{ +void MyBase_upcastFromVoid_1(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; - - Shared *self = *reinterpret_cast (mxGetData(in[0])); - collector_MyBase.insert(self); + boost::shared_ptr *asVoid = *reinterpret_cast**> (mxGetData(in[0])); + out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); + Shared *self = new Shared(boost::static_pointer_cast(*asVoid)); + *reinterpret_cast(mxGetData(out[0])) = self; } void MyBase_deconstructor_2(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -227,19 +126,6 @@ void MyBase_deconstructor_2(int nargout, mxArray *out[], int nargin, const mxArr } } -void gtsamPoint2_deconstructor_3(int nargout, mxArray *out[], int nargin, const mxArray *in[]) -{ - typedef boost::shared_ptr Shared; - checkArguments("delete_gtsamPoint2",nargout,nargin,1); - Shared *self = *reinterpret_cast(mxGetData(in[0])); - Collector_gtsamPoint2::iterator item; - item = collector_gtsamPoint2.find(self); - if(item != collector_gtsamPoint2.end()) { - delete self; - collector_gtsamPoint2.erase(item); - } -} - void MyTemplatePoint2_collectorInsertAndMakeBase_3(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); @@ -253,6 +139,15 @@ void MyTemplatePoint2_collectorInsertAndMakeBase_3(int nargout, mxArray *out[], *reinterpret_cast(mxGetData(out[0])) = new SharedBase(*self); } +void MyTemplatePoint2_upcastFromVoid_4(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { + mexAtExit(&_deleteAllObjects); + typedef boost::shared_ptr> Shared; + boost::shared_ptr *asVoid = *reinterpret_cast**> (mxGetData(in[0])); + out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); + Shared *self = new Shared(boost::static_pointer_cast>(*asVoid)); + *reinterpret_cast(mxGetData(out[0])) = self; +} + void MyTemplatePoint2_constructor_5(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); @@ -399,20 +294,6 @@ void MyTemplatePoint2_Level_18(int nargout, mxArray *out[], int nargin, const mx out[0] = wrap_shared_ptr(boost::make_shared>(MyTemplate::Level(K)),"MyTemplatePoint2", false); } -void gtsamPoint3_constructor_19(int nargout, mxArray *out[], int nargin, const mxArray *in[]) -{ - mexAtExit(&_deleteAllObjects); - typedef boost::shared_ptr Shared; - - double x = unwrap< double >(in[0]); - double y = unwrap< double >(in[1]); - double z = unwrap< double >(in[2]); - Shared *self = new Shared(new gtsam::Point3(x,y,z)); - collector_gtsamPoint3.insert(self); - out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); - *reinterpret_cast (mxGetData(out[0])) = self; -} - void MyTemplateMatrix_collectorInsertAndMakeBase_19(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); @@ -426,6 +307,15 @@ void MyTemplateMatrix_collectorInsertAndMakeBase_19(int nargout, mxArray *out[], *reinterpret_cast(mxGetData(out[0])) = new SharedBase(*self); } +void MyTemplateMatrix_upcastFromVoid_20(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { + mexAtExit(&_deleteAllObjects); + typedef boost::shared_ptr> Shared; + boost::shared_ptr *asVoid = *reinterpret_cast**> (mxGetData(in[0])); + out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); + Shared *self = new Shared(boost::static_pointer_cast>(*asVoid)); + *reinterpret_cast(mxGetData(out[0])) = self; +} + void MyTemplateMatrix_constructor_21(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); @@ -572,14 +462,6 @@ void MyTemplateMatrix_Level_34(int nargout, mxArray *out[], int nargin, const mx out[0] = wrap_shared_ptr(boost::make_shared>(MyTemplate::Level(K)),"MyTemplateMatrix", false); } -void Test_return_vector2_35(int nargout, mxArray *out[], int nargin, const mxArray *in[]) -{ - checkArguments("return_vector2",nargout,nargin-1,1); - auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); - Vector value = unwrap< Vector >(in[1]); - out[0] = wrap< Vector >(obj->return_vector2(value)); -} - void ForwardKinematicsFactor_collectorInsertAndMakeBase_35(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); @@ -593,6 +475,15 @@ void ForwardKinematicsFactor_collectorInsertAndMakeBase_35(int nargout, mxArray *reinterpret_cast(mxGetData(out[0])) = new SharedBase(*self); } +void ForwardKinematicsFactor_upcastFromVoid_36(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { + mexAtExit(&_deleteAllObjects); + typedef boost::shared_ptr Shared; + boost::shared_ptr *asVoid = *reinterpret_cast**> (mxGetData(in[0])); + out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); + Shared *self = new Shared(boost::static_pointer_cast(*asVoid)); + *reinterpret_cast(mxGetData(out[0])) = self; +} + void ForwardKinematicsFactor_deconstructor_37(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef boost::shared_ptr Shared; @@ -619,19 +510,19 @@ void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) try { switch(id) { case 0: - gtsamPoint2_collectorInsertAndMakeBase_0(nargout, out, nargin-1, in+1); + MyBase_collectorInsertAndMakeBase_0(nargout, out, nargin-1, in+1); break; case 1: - MyBase_collectorInsertAndMakeBase_0(nargout, out, nargin-1, in+1); + MyBase_upcastFromVoid_1(nargout, out, nargin-1, in+1); break; case 2: MyBase_deconstructor_2(nargout, out, nargin-1, in+1); break; case 3: - gtsamPoint2_deconstructor_3(nargout, out, nargin-1, in+1); + MyTemplatePoint2_collectorInsertAndMakeBase_3(nargout, out, nargin-1, in+1); break; case 4: - MyTemplatePoint2_collectorInsertAndMakeBase_3(nargout, out, nargin-1, in+1); + MyTemplatePoint2_upcastFromVoid_4(nargout, out, nargin-1, in+1); break; case 5: MyTemplatePoint2_constructor_5(nargout, out, nargin-1, in+1); @@ -676,10 +567,10 @@ void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) MyTemplatePoint2_Level_18(nargout, out, nargin-1, in+1); break; case 19: - gtsamPoint3_constructor_19(nargout, out, nargin-1, in+1); + MyTemplateMatrix_collectorInsertAndMakeBase_19(nargout, out, nargin-1, in+1); break; case 20: - MyTemplateMatrix_collectorInsertAndMakeBase_19(nargout, out, nargin-1, in+1); + MyTemplateMatrix_upcastFromVoid_20(nargout, out, nargin-1, in+1); break; case 21: MyTemplateMatrix_constructor_21(nargout, out, nargin-1, in+1); @@ -724,10 +615,10 @@ void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) MyTemplateMatrix_Level_34(nargout, out, nargin-1, in+1); break; case 35: - Test_return_vector2_35(nargout, out, nargin-1, in+1); + ForwardKinematicsFactor_collectorInsertAndMakeBase_35(nargout, out, nargin-1, in+1); break; case 36: - ForwardKinematicsFactor_collectorInsertAndMakeBase_35(nargout, out, nargin-1, in+1); + ForwardKinematicsFactor_upcastFromVoid_36(nargout, out, nargin-1, in+1); break; case 37: ForwardKinematicsFactor_deconstructor_37(nargout, out, nargin-1, in+1); diff --git a/wrap/tests/expected/matlab/multiple_files_wrapper.cpp b/wrap/tests/expected/matlab/multiple_files_wrapper.cpp new file mode 100644 index 0000000000..66ab7ff73d --- /dev/null +++ b/wrap/tests/expected/matlab/multiple_files_wrapper.cpp @@ -0,0 +1,229 @@ +#include +#include + +#include +#include +#include + + + + + +typedef std::set*> Collector_gtsamClass1; +static Collector_gtsamClass1 collector_gtsamClass1; +typedef std::set*> Collector_gtsamClass2; +static Collector_gtsamClass2 collector_gtsamClass2; +typedef std::set*> Collector_gtsamClassA; +static Collector_gtsamClassA collector_gtsamClassA; + + +void _deleteAllObjects() +{ + mstream mout; + std::streambuf *outbuf = std::cout.rdbuf(&mout); + + bool anyDeleted = false; + { for(Collector_gtsamClass1::iterator iter = collector_gtsamClass1.begin(); + iter != collector_gtsamClass1.end(); ) { + delete *iter; + collector_gtsamClass1.erase(iter++); + anyDeleted = true; + } } + { for(Collector_gtsamClass2::iterator iter = collector_gtsamClass2.begin(); + iter != collector_gtsamClass2.end(); ) { + delete *iter; + collector_gtsamClass2.erase(iter++); + anyDeleted = true; + } } + { for(Collector_gtsamClassA::iterator iter = collector_gtsamClassA.begin(); + iter != collector_gtsamClassA.end(); ) { + delete *iter; + collector_gtsamClassA.erase(iter++); + anyDeleted = true; + } } + + if(anyDeleted) + cout << + "WARNING: Wrap modules with variables in the workspace have been reloaded due to\n" + "calling destructors, call 'clear all' again if you plan to now recompile a wrap\n" + "module, so that your recompiled module is used instead of the old one." << endl; + std::cout.rdbuf(outbuf); +} + +void _multiple_files_RTTIRegister() { + const mxArray *alreadyCreated = mexGetVariablePtr("global", "gtsam_multiple_files_rttiRegistry_created"); + if(!alreadyCreated) { + std::map types; + + + + mxArray *registry = mexGetVariable("global", "gtsamwrap_rttiRegistry"); + if(!registry) + registry = mxCreateStructMatrix(1, 1, 0, NULL); + typedef std::pair StringPair; + for(const StringPair& rtti_matlab: types) { + int fieldId = mxAddField(registry, rtti_matlab.first.c_str()); + if(fieldId < 0) { + mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } + mxArray *matlabName = mxCreateString(rtti_matlab.second.c_str()); + mxSetFieldByNumber(registry, 0, fieldId, matlabName); + } + if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) { + mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } + mxDestroyArray(registry); + + mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); + if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { + mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } + mxDestroyArray(newAlreadyCreated); + } +} + +void gtsamClass1_collectorInsertAndMakeBase_0(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef boost::shared_ptr Shared; + + Shared *self = *reinterpret_cast (mxGetData(in[0])); + collector_gtsamClass1.insert(self); +} + +void gtsamClass1_constructor_1(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef boost::shared_ptr Shared; + + Shared *self = new Shared(new gtsam::Class1()); + collector_gtsamClass1.insert(self); + out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); + *reinterpret_cast (mxGetData(out[0])) = self; +} + +void gtsamClass1_deconstructor_2(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + typedef boost::shared_ptr Shared; + checkArguments("delete_gtsamClass1",nargout,nargin,1); + Shared *self = *reinterpret_cast(mxGetData(in[0])); + Collector_gtsamClass1::iterator item; + item = collector_gtsamClass1.find(self); + if(item != collector_gtsamClass1.end()) { + delete self; + collector_gtsamClass1.erase(item); + } +} + +void gtsamClass2_collectorInsertAndMakeBase_3(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef boost::shared_ptr Shared; + + Shared *self = *reinterpret_cast (mxGetData(in[0])); + collector_gtsamClass2.insert(self); +} + +void gtsamClass2_constructor_4(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef boost::shared_ptr Shared; + + Shared *self = new Shared(new gtsam::Class2()); + collector_gtsamClass2.insert(self); + out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); + *reinterpret_cast (mxGetData(out[0])) = self; +} + +void gtsamClass2_deconstructor_5(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + typedef boost::shared_ptr Shared; + checkArguments("delete_gtsamClass2",nargout,nargin,1); + Shared *self = *reinterpret_cast(mxGetData(in[0])); + Collector_gtsamClass2::iterator item; + item = collector_gtsamClass2.find(self); + if(item != collector_gtsamClass2.end()) { + delete self; + collector_gtsamClass2.erase(item); + } +} + +void gtsamClassA_collectorInsertAndMakeBase_6(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef boost::shared_ptr Shared; + + Shared *self = *reinterpret_cast (mxGetData(in[0])); + collector_gtsamClassA.insert(self); +} + +void gtsamClassA_constructor_7(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef boost::shared_ptr Shared; + + Shared *self = new Shared(new gtsam::ClassA()); + collector_gtsamClassA.insert(self); + out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); + *reinterpret_cast (mxGetData(out[0])) = self; +} + +void gtsamClassA_deconstructor_8(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + typedef boost::shared_ptr Shared; + checkArguments("delete_gtsamClassA",nargout,nargin,1); + Shared *self = *reinterpret_cast(mxGetData(in[0])); + Collector_gtsamClassA::iterator item; + item = collector_gtsamClassA.find(self); + if(item != collector_gtsamClassA.end()) { + delete self; + collector_gtsamClassA.erase(item); + } +} + + +void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mstream mout; + std::streambuf *outbuf = std::cout.rdbuf(&mout); + + _multiple_files_RTTIRegister(); + + int id = unwrap(in[0]); + + try { + switch(id) { + case 0: + gtsamClass1_collectorInsertAndMakeBase_0(nargout, out, nargin-1, in+1); + break; + case 1: + gtsamClass1_constructor_1(nargout, out, nargin-1, in+1); + break; + case 2: + gtsamClass1_deconstructor_2(nargout, out, nargin-1, in+1); + break; + case 3: + gtsamClass2_collectorInsertAndMakeBase_3(nargout, out, nargin-1, in+1); + break; + case 4: + gtsamClass2_constructor_4(nargout, out, nargin-1, in+1); + break; + case 5: + gtsamClass2_deconstructor_5(nargout, out, nargin-1, in+1); + break; + case 6: + gtsamClassA_collectorInsertAndMakeBase_6(nargout, out, nargin-1, in+1); + break; + case 7: + gtsamClassA_constructor_7(nargout, out, nargin-1, in+1); + break; + case 8: + gtsamClassA_deconstructor_8(nargout, out, nargin-1, in+1); + break; + } + } catch(const std::exception& e) { + mexErrMsgTxt(("Exception from gtsam:\n" + std::string(e.what()) + "\n").c_str()); + } + + std::cout.rdbuf(outbuf); +} diff --git a/wrap/tests/expected/matlab/namespaces_wrapper.cpp b/wrap/tests/expected/matlab/namespaces_wrapper.cpp index 8f6e415e2d..604ede5da5 100644 --- a/wrap/tests/expected/matlab/namespaces_wrapper.cpp +++ b/wrap/tests/expected/matlab/namespaces_wrapper.cpp @@ -5,9 +5,6 @@ #include #include -#include -#include -#include #include #include #include @@ -15,51 +12,8 @@ #include #include -typedef Fun FunDouble; -typedef PrimitiveRef PrimitiveRefDouble; -typedef MyVector<3> MyVector3; -typedef MyVector<12> MyVector12; -typedef MultipleTemplates MultipleTemplatesIntDouble; -typedef MultipleTemplates MultipleTemplatesIntFloat; -typedef MyFactor MyFactorPosePoint2; -typedef MyTemplate MyTemplatePoint2; -typedef MyTemplate MyTemplateMatrix; - -BOOST_CLASS_EXPORT_GUID(gtsam::Point2, "gtsamPoint2"); -BOOST_CLASS_EXPORT_GUID(gtsam::Point3, "gtsamPoint3"); - -typedef std::set*> Collector_FunRange; -static Collector_FunRange collector_FunRange; -typedef std::set*> Collector_FunDouble; -static Collector_FunDouble collector_FunDouble; -typedef std::set*> Collector_Test; -static Collector_Test collector_Test; -typedef std::set*> Collector_PrimitiveRefDouble; -static Collector_PrimitiveRefDouble collector_PrimitiveRefDouble; -typedef std::set*> Collector_MyVector3; -static Collector_MyVector3 collector_MyVector3; -typedef std::set*> Collector_MyVector12; -static Collector_MyVector12 collector_MyVector12; -typedef std::set*> Collector_MultipleTemplatesIntDouble; -static Collector_MultipleTemplatesIntDouble collector_MultipleTemplatesIntDouble; -typedef std::set*> Collector_MultipleTemplatesIntFloat; -static Collector_MultipleTemplatesIntFloat collector_MultipleTemplatesIntFloat; -typedef std::set*> Collector_ForwardKinematics; -static Collector_ForwardKinematics collector_ForwardKinematics; -typedef std::set*> Collector_MyFactorPosePoint2; -static Collector_MyFactorPosePoint2 collector_MyFactorPosePoint2; -typedef std::set*> Collector_gtsamPoint2; -static Collector_gtsamPoint2 collector_gtsamPoint2; -typedef std::set*> Collector_gtsamPoint3; -static Collector_gtsamPoint3 collector_gtsamPoint3; -typedef std::set*> Collector_MyBase; -static Collector_MyBase collector_MyBase; -typedef std::set*> Collector_MyTemplatePoint2; -static Collector_MyTemplatePoint2 collector_MyTemplatePoint2; -typedef std::set*> Collector_MyTemplateMatrix; -static Collector_MyTemplateMatrix collector_MyTemplateMatrix; -typedef std::set*> Collector_ForwardKinematicsFactor; -static Collector_ForwardKinematicsFactor collector_ForwardKinematicsFactor; + + typedef std::set*> Collector_ns1ClassA; static Collector_ns1ClassA collector_ns1ClassA; typedef std::set*> Collector_ns1ClassB; @@ -75,108 +29,13 @@ static Collector_ClassD collector_ClassD; typedef std::set*> Collector_gtsamValues; static Collector_gtsamValues collector_gtsamValues; + void _deleteAllObjects() { mstream mout; std::streambuf *outbuf = std::cout.rdbuf(&mout); bool anyDeleted = false; - { for(Collector_FunRange::iterator iter = collector_FunRange.begin(); - iter != collector_FunRange.end(); ) { - delete *iter; - collector_FunRange.erase(iter++); - anyDeleted = true; - } } - { for(Collector_FunDouble::iterator iter = collector_FunDouble.begin(); - iter != collector_FunDouble.end(); ) { - delete *iter; - collector_FunDouble.erase(iter++); - anyDeleted = true; - } } - { for(Collector_Test::iterator iter = collector_Test.begin(); - iter != collector_Test.end(); ) { - delete *iter; - collector_Test.erase(iter++); - anyDeleted = true; - } } - { for(Collector_PrimitiveRefDouble::iterator iter = collector_PrimitiveRefDouble.begin(); - iter != collector_PrimitiveRefDouble.end(); ) { - delete *iter; - collector_PrimitiveRefDouble.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyVector3::iterator iter = collector_MyVector3.begin(); - iter != collector_MyVector3.end(); ) { - delete *iter; - collector_MyVector3.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyVector12::iterator iter = collector_MyVector12.begin(); - iter != collector_MyVector12.end(); ) { - delete *iter; - collector_MyVector12.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MultipleTemplatesIntDouble::iterator iter = collector_MultipleTemplatesIntDouble.begin(); - iter != collector_MultipleTemplatesIntDouble.end(); ) { - delete *iter; - collector_MultipleTemplatesIntDouble.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MultipleTemplatesIntFloat::iterator iter = collector_MultipleTemplatesIntFloat.begin(); - iter != collector_MultipleTemplatesIntFloat.end(); ) { - delete *iter; - collector_MultipleTemplatesIntFloat.erase(iter++); - anyDeleted = true; - } } - { for(Collector_ForwardKinematics::iterator iter = collector_ForwardKinematics.begin(); - iter != collector_ForwardKinematics.end(); ) { - delete *iter; - collector_ForwardKinematics.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyFactorPosePoint2::iterator iter = collector_MyFactorPosePoint2.begin(); - iter != collector_MyFactorPosePoint2.end(); ) { - delete *iter; - collector_MyFactorPosePoint2.erase(iter++); - anyDeleted = true; - } } - { for(Collector_gtsamPoint2::iterator iter = collector_gtsamPoint2.begin(); - iter != collector_gtsamPoint2.end(); ) { - delete *iter; - collector_gtsamPoint2.erase(iter++); - anyDeleted = true; - } } - { for(Collector_gtsamPoint3::iterator iter = collector_gtsamPoint3.begin(); - iter != collector_gtsamPoint3.end(); ) { - delete *iter; - collector_gtsamPoint3.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyBase::iterator iter = collector_MyBase.begin(); - iter != collector_MyBase.end(); ) { - delete *iter; - collector_MyBase.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyTemplatePoint2::iterator iter = collector_MyTemplatePoint2.begin(); - iter != collector_MyTemplatePoint2.end(); ) { - delete *iter; - collector_MyTemplatePoint2.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyTemplateMatrix::iterator iter = collector_MyTemplateMatrix.begin(); - iter != collector_MyTemplateMatrix.end(); ) { - delete *iter; - collector_MyTemplateMatrix.erase(iter++); - anyDeleted = true; - } } - { for(Collector_ForwardKinematicsFactor::iterator iter = collector_ForwardKinematicsFactor.begin(); - iter != collector_ForwardKinematicsFactor.end(); ) { - delete *iter; - collector_ForwardKinematicsFactor.erase(iter++); - anyDeleted = true; - } } { for(Collector_ns1ClassA::iterator iter = collector_ns1ClassA.begin(); iter != collector_ns1ClassA.end(); ) { delete *iter; @@ -219,6 +78,7 @@ void _deleteAllObjects() collector_gtsamValues.erase(iter++); anyDeleted = true; } } + if(anyDeleted) cout << "WARNING: Wrap modules with variables in the workspace have been reloaded due to\n" @@ -231,10 +91,8 @@ void _namespaces_RTTIRegister() { const mxArray *alreadyCreated = mexGetVariablePtr("global", "gtsam_namespaces_rttiRegistry_created"); if(!alreadyCreated) { std::map types; - types.insert(std::make_pair(typeid(MyBase).name(), "MyBase")); - types.insert(std::make_pair(typeid(MyTemplatePoint2).name(), "MyTemplatePoint2")); - types.insert(std::make_pair(typeid(MyTemplateMatrix).name(), "MyTemplateMatrix")); - types.insert(std::make_pair(typeid(ForwardKinematicsFactor).name(), "ForwardKinematicsFactor")); + + mxArray *registry = mexGetVariable("global", "gtsamwrap_rttiRegistry"); if(!registry) @@ -242,18 +100,21 @@ void _namespaces_RTTIRegister() { typedef std::pair StringPair; for(const StringPair& rtti_matlab: types) { int fieldId = mxAddField(registry, rtti_matlab.first.c_str()); - if(fieldId < 0) + if(fieldId < 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } mxArray *matlabName = mxCreateString(rtti_matlab.second.c_str()); mxSetFieldByNumber(registry, 0, fieldId, matlabName); } - if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) + if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } mxDestroyArray(registry); - + mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) + if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } mxDestroyArray(newAlreadyCreated); } } diff --git a/wrap/tests/expected/matlab/special_cases_wrapper.cpp b/wrap/tests/expected/matlab/special_cases_wrapper.cpp index 056ce80973..69abbf73be 100644 --- a/wrap/tests/expected/matlab/special_cases_wrapper.cpp +++ b/wrap/tests/expected/matlab/special_cases_wrapper.cpp @@ -5,78 +5,11 @@ #include #include -#include #include -#include -#include -#include -#include -#include -#include -#include -#include - -typedef Fun FunDouble; -typedef PrimitiveRef PrimitiveRefDouble; -typedef MyVector<3> MyVector3; -typedef MyVector<12> MyVector12; -typedef MultipleTemplates MultipleTemplatesIntDouble; -typedef MultipleTemplates MultipleTemplatesIntFloat; -typedef MyFactor MyFactorPosePoint2; -typedef MyTemplate MyTemplatePoint2; -typedef MyTemplate MyTemplateMatrix; + typedef gtsam::PinholeCamera PinholeCameraCal3Bundler; typedef gtsam::GeneralSFMFactor, gtsam::Point3> GeneralSFMFactorCal3Bundler; -BOOST_CLASS_EXPORT_GUID(gtsam::Point2, "gtsamPoint2"); -BOOST_CLASS_EXPORT_GUID(gtsam::Point3, "gtsamPoint3"); - -typedef std::set*> Collector_FunRange; -static Collector_FunRange collector_FunRange; -typedef std::set*> Collector_FunDouble; -static Collector_FunDouble collector_FunDouble; -typedef std::set*> Collector_Test; -static Collector_Test collector_Test; -typedef std::set*> Collector_PrimitiveRefDouble; -static Collector_PrimitiveRefDouble collector_PrimitiveRefDouble; -typedef std::set*> Collector_MyVector3; -static Collector_MyVector3 collector_MyVector3; -typedef std::set*> Collector_MyVector12; -static Collector_MyVector12 collector_MyVector12; -typedef std::set*> Collector_MultipleTemplatesIntDouble; -static Collector_MultipleTemplatesIntDouble collector_MultipleTemplatesIntDouble; -typedef std::set*> Collector_MultipleTemplatesIntFloat; -static Collector_MultipleTemplatesIntFloat collector_MultipleTemplatesIntFloat; -typedef std::set*> Collector_ForwardKinematics; -static Collector_ForwardKinematics collector_ForwardKinematics; -typedef std::set*> Collector_MyFactorPosePoint2; -static Collector_MyFactorPosePoint2 collector_MyFactorPosePoint2; -typedef std::set*> Collector_gtsamPoint2; -static Collector_gtsamPoint2 collector_gtsamPoint2; -typedef std::set*> Collector_gtsamPoint3; -static Collector_gtsamPoint3 collector_gtsamPoint3; -typedef std::set*> Collector_MyBase; -static Collector_MyBase collector_MyBase; -typedef std::set*> Collector_MyTemplatePoint2; -static Collector_MyTemplatePoint2 collector_MyTemplatePoint2; -typedef std::set*> Collector_MyTemplateMatrix; -static Collector_MyTemplateMatrix collector_MyTemplateMatrix; -typedef std::set*> Collector_ForwardKinematicsFactor; -static Collector_ForwardKinematicsFactor collector_ForwardKinematicsFactor; -typedef std::set*> Collector_ns1ClassA; -static Collector_ns1ClassA collector_ns1ClassA; -typedef std::set*> Collector_ns1ClassB; -static Collector_ns1ClassB collector_ns1ClassB; -typedef std::set*> Collector_ns2ClassA; -static Collector_ns2ClassA collector_ns2ClassA; -typedef std::set*> Collector_ns2ns3ClassB; -static Collector_ns2ns3ClassB collector_ns2ns3ClassB; -typedef std::set*> Collector_ns2ClassC; -static Collector_ns2ClassC collector_ns2ClassC; -typedef std::set*> Collector_ClassD; -static Collector_ClassD collector_ClassD; -typedef std::set*> Collector_gtsamValues; -static Collector_gtsamValues collector_gtsamValues; typedef std::set*> Collector_gtsamNonlinearFactorGraph; static Collector_gtsamNonlinearFactorGraph collector_gtsamNonlinearFactorGraph; typedef std::set*> Collector_gtsamSfmTrack; @@ -86,150 +19,13 @@ static Collector_gtsamPinholeCameraCal3Bundler collector_gtsamPinholeCameraCal3B typedef std::set*> Collector_gtsamGeneralSFMFactorCal3Bundler; static Collector_gtsamGeneralSFMFactorCal3Bundler collector_gtsamGeneralSFMFactorCal3Bundler; + void _deleteAllObjects() { mstream mout; std::streambuf *outbuf = std::cout.rdbuf(&mout); bool anyDeleted = false; - { for(Collector_FunRange::iterator iter = collector_FunRange.begin(); - iter != collector_FunRange.end(); ) { - delete *iter; - collector_FunRange.erase(iter++); - anyDeleted = true; - } } - { for(Collector_FunDouble::iterator iter = collector_FunDouble.begin(); - iter != collector_FunDouble.end(); ) { - delete *iter; - collector_FunDouble.erase(iter++); - anyDeleted = true; - } } - { for(Collector_Test::iterator iter = collector_Test.begin(); - iter != collector_Test.end(); ) { - delete *iter; - collector_Test.erase(iter++); - anyDeleted = true; - } } - { for(Collector_PrimitiveRefDouble::iterator iter = collector_PrimitiveRefDouble.begin(); - iter != collector_PrimitiveRefDouble.end(); ) { - delete *iter; - collector_PrimitiveRefDouble.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyVector3::iterator iter = collector_MyVector3.begin(); - iter != collector_MyVector3.end(); ) { - delete *iter; - collector_MyVector3.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyVector12::iterator iter = collector_MyVector12.begin(); - iter != collector_MyVector12.end(); ) { - delete *iter; - collector_MyVector12.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MultipleTemplatesIntDouble::iterator iter = collector_MultipleTemplatesIntDouble.begin(); - iter != collector_MultipleTemplatesIntDouble.end(); ) { - delete *iter; - collector_MultipleTemplatesIntDouble.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MultipleTemplatesIntFloat::iterator iter = collector_MultipleTemplatesIntFloat.begin(); - iter != collector_MultipleTemplatesIntFloat.end(); ) { - delete *iter; - collector_MultipleTemplatesIntFloat.erase(iter++); - anyDeleted = true; - } } - { for(Collector_ForwardKinematics::iterator iter = collector_ForwardKinematics.begin(); - iter != collector_ForwardKinematics.end(); ) { - delete *iter; - collector_ForwardKinematics.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyFactorPosePoint2::iterator iter = collector_MyFactorPosePoint2.begin(); - iter != collector_MyFactorPosePoint2.end(); ) { - delete *iter; - collector_MyFactorPosePoint2.erase(iter++); - anyDeleted = true; - } } - { for(Collector_gtsamPoint2::iterator iter = collector_gtsamPoint2.begin(); - iter != collector_gtsamPoint2.end(); ) { - delete *iter; - collector_gtsamPoint2.erase(iter++); - anyDeleted = true; - } } - { for(Collector_gtsamPoint3::iterator iter = collector_gtsamPoint3.begin(); - iter != collector_gtsamPoint3.end(); ) { - delete *iter; - collector_gtsamPoint3.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyBase::iterator iter = collector_MyBase.begin(); - iter != collector_MyBase.end(); ) { - delete *iter; - collector_MyBase.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyTemplatePoint2::iterator iter = collector_MyTemplatePoint2.begin(); - iter != collector_MyTemplatePoint2.end(); ) { - delete *iter; - collector_MyTemplatePoint2.erase(iter++); - anyDeleted = true; - } } - { for(Collector_MyTemplateMatrix::iterator iter = collector_MyTemplateMatrix.begin(); - iter != collector_MyTemplateMatrix.end(); ) { - delete *iter; - collector_MyTemplateMatrix.erase(iter++); - anyDeleted = true; - } } - { for(Collector_ForwardKinematicsFactor::iterator iter = collector_ForwardKinematicsFactor.begin(); - iter != collector_ForwardKinematicsFactor.end(); ) { - delete *iter; - collector_ForwardKinematicsFactor.erase(iter++); - anyDeleted = true; - } } - { for(Collector_ns1ClassA::iterator iter = collector_ns1ClassA.begin(); - iter != collector_ns1ClassA.end(); ) { - delete *iter; - collector_ns1ClassA.erase(iter++); - anyDeleted = true; - } } - { for(Collector_ns1ClassB::iterator iter = collector_ns1ClassB.begin(); - iter != collector_ns1ClassB.end(); ) { - delete *iter; - collector_ns1ClassB.erase(iter++); - anyDeleted = true; - } } - { for(Collector_ns2ClassA::iterator iter = collector_ns2ClassA.begin(); - iter != collector_ns2ClassA.end(); ) { - delete *iter; - collector_ns2ClassA.erase(iter++); - anyDeleted = true; - } } - { for(Collector_ns2ns3ClassB::iterator iter = collector_ns2ns3ClassB.begin(); - iter != collector_ns2ns3ClassB.end(); ) { - delete *iter; - collector_ns2ns3ClassB.erase(iter++); - anyDeleted = true; - } } - { for(Collector_ns2ClassC::iterator iter = collector_ns2ClassC.begin(); - iter != collector_ns2ClassC.end(); ) { - delete *iter; - collector_ns2ClassC.erase(iter++); - anyDeleted = true; - } } - { for(Collector_ClassD::iterator iter = collector_ClassD.begin(); - iter != collector_ClassD.end(); ) { - delete *iter; - collector_ClassD.erase(iter++); - anyDeleted = true; - } } - { for(Collector_gtsamValues::iterator iter = collector_gtsamValues.begin(); - iter != collector_gtsamValues.end(); ) { - delete *iter; - collector_gtsamValues.erase(iter++); - anyDeleted = true; - } } { for(Collector_gtsamNonlinearFactorGraph::iterator iter = collector_gtsamNonlinearFactorGraph.begin(); iter != collector_gtsamNonlinearFactorGraph.end(); ) { delete *iter; @@ -254,6 +50,7 @@ void _deleteAllObjects() collector_gtsamGeneralSFMFactorCal3Bundler.erase(iter++); anyDeleted = true; } } + if(anyDeleted) cout << "WARNING: Wrap modules with variables in the workspace have been reloaded due to\n" @@ -266,10 +63,8 @@ void _special_cases_RTTIRegister() { const mxArray *alreadyCreated = mexGetVariablePtr("global", "gtsam_special_cases_rttiRegistry_created"); if(!alreadyCreated) { std::map types; - types.insert(std::make_pair(typeid(MyBase).name(), "MyBase")); - types.insert(std::make_pair(typeid(MyTemplatePoint2).name(), "MyTemplatePoint2")); - types.insert(std::make_pair(typeid(MyTemplateMatrix).name(), "MyTemplateMatrix")); - types.insert(std::make_pair(typeid(ForwardKinematicsFactor).name(), "ForwardKinematicsFactor")); + + mxArray *registry = mexGetVariable("global", "gtsamwrap_rttiRegistry"); if(!registry) @@ -277,18 +72,21 @@ void _special_cases_RTTIRegister() { typedef std::pair StringPair; for(const StringPair& rtti_matlab: types) { int fieldId = mxAddField(registry, rtti_matlab.first.c_str()); - if(fieldId < 0) + if(fieldId < 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } mxArray *matlabName = mxCreateString(rtti_matlab.second.c_str()); mxSetFieldByNumber(registry, 0, fieldId, matlabName); } - if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) + if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } mxDestroyArray(registry); - + mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) + if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } mxDestroyArray(newAlreadyCreated); } } diff --git a/wrap/tests/fixtures/part1.i b/wrap/tests/fixtures/part1.i new file mode 100644 index 0000000000..b69850baff --- /dev/null +++ b/wrap/tests/fixtures/part1.i @@ -0,0 +1,11 @@ +// First file to test for multi-file support. + +namespace gtsam { +class Class1 { + Class1(); +}; + +class Class2 { + Class2(); +}; +} // namespace gtsam \ No newline at end of file diff --git a/wrap/tests/fixtures/part2.i b/wrap/tests/fixtures/part2.i new file mode 100644 index 0000000000..29ad86a7f8 --- /dev/null +++ b/wrap/tests/fixtures/part2.i @@ -0,0 +1,7 @@ +// Second file to test for multi-file support. + +namespace gtsam { +class ClassA { + ClassA(); +}; +} // namespace gtsam \ No newline at end of file diff --git a/wrap/tests/test_matlab_wrapper.py b/wrap/tests/test_matlab_wrapper.py index b321c4e151..fad4de16a6 100644 --- a/wrap/tests/test_matlab_wrapper.py +++ b/wrap/tests/test_matlab_wrapper.py @@ -22,73 +22,31 @@ class TestWrap(unittest.TestCase): """ Test the Matlab wrapper """ - TEST_DIR = osp.dirname(osp.realpath(__file__)) - INTERFACE_DIR = osp.join(TEST_DIR, "fixtures") - MATLAB_TEST_DIR = osp.join(TEST_DIR, "expected", "matlab") - MATLAB_ACTUAL_DIR = osp.join(TEST_DIR, "actual", "matlab") - - # Create the `actual/matlab` directory - os.makedirs(MATLAB_ACTUAL_DIR, exist_ok=True) - - # set the log level to INFO by default - logger.remove() # remove the default sink - logger.add(sys.stderr, format="{time} {level} {message}", level="INFO") - - def generate_content(self, cc_content, path=MATLAB_ACTUAL_DIR): - """Generate files and folders from matlab wrapper content. - - Keyword arguments: - cc_content -- the content to generate formatted as - (file_name, file_content) or - (folder_name, [(file_name, file_content)]) - path -- the path to the files parent folder within the main folder - """ - for c in cc_content: - if isinstance(c, list): - if len(c) == 0: - continue - logger.debug("c object: {}".format(c[0][0])) - path_to_folder = osp.join(path, c[0][0]) - - if not osp.isdir(path_to_folder): - try: - os.makedirs(path_to_folder, exist_ok=True) - except OSError: - pass - - for sub_content in c: - logger.debug("sub object: {}".format(sub_content[1][0][0])) - self.generate_content(sub_content[1], path_to_folder) - - elif isinstance(c[1], list): - path_to_folder = osp.join(path, c[0]) - - logger.debug( - "[generate_content_global]: {}".format(path_to_folder)) - if not osp.isdir(path_to_folder): - try: - os.makedirs(path_to_folder, exist_ok=True) - except OSError: - pass - for sub_content in c[1]: - path_to_file = osp.join(path_to_folder, sub_content[0]) - logger.debug( - "[generate_global_method]: {}".format(path_to_file)) - with open(path_to_file, 'w') as f: - f.write(sub_content[1]) - - else: - path_to_file = osp.join(path, c[0]) - - logger.debug("[generate_content]: {}".format(path_to_file)) - if not osp.isdir(path_to_file): - try: - os.mkdir(path) - except OSError: - pass - - with open(path_to_file, 'w') as f: - f.write(c[1]) + def setUp(self) -> None: + super().setUp() + + # Set up all the directories + self.TEST_DIR = osp.dirname(osp.realpath(__file__)) + self.INTERFACE_DIR = osp.join(self.TEST_DIR, "fixtures") + self.MATLAB_TEST_DIR = osp.join(self.TEST_DIR, "expected", "matlab") + self.MATLAB_ACTUAL_DIR = osp.join(self.TEST_DIR, "actual", "matlab") + + if not osp.exists(self.MATLAB_ACTUAL_DIR): + os.mkdir(self.MATLAB_ACTUAL_DIR) + + # Generate the matlab.h file if it does not exist + template_file = osp.join(self.TEST_DIR, "..", "gtwrap", + "matlab_wrapper", "matlab_wrapper.tpl") + if not osp.exists(template_file): + with open(template_file, 'w') as tpl: + tpl.write("#include \n#include \n") + + # Create the `actual/matlab` directory + os.makedirs(self.MATLAB_ACTUAL_DIR, exist_ok=True) + + # set the log level to INFO by default + logger.remove() # remove the default sink + logger.add(sys.stderr, format="{time} {level} {message}", level="INFO") def compare_and_diff(self, file): """ @@ -109,11 +67,7 @@ def test_geometry(self): python3 wrap/matlab_wrapper.py --src wrap/tests/geometry.h --module_name geometry --out wrap/tests/actual-matlab """ - with open(osp.join(self.INTERFACE_DIR, 'geometry.i'), 'r') as f: - content = f.read() - - if not osp.exists(self.MATLAB_ACTUAL_DIR): - os.mkdir(self.MATLAB_ACTUAL_DIR) + file = osp.join(self.INTERFACE_DIR, 'geometry.i') # Create MATLAB wrapper instance wrapper = MatlabWrapper( @@ -122,24 +76,18 @@ def test_geometry(self): ignore_classes=[''], ) - cc_content = wrapper.wrap(content) + wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR) - self.generate_content(cc_content) + files = ['+gtsam/Point2.m', '+gtsam/Point3.m', 'geometry_wrapper.cpp'] self.assertTrue(osp.isdir(osp.join(self.MATLAB_ACTUAL_DIR, '+gtsam'))) - files = ['+gtsam/Point2.m', '+gtsam/Point3.m', 'geometry_wrapper.cpp'] - for file in files: self.compare_and_diff(file) def test_functions(self): """Test interface file with function info.""" - with open(osp.join(self.INTERFACE_DIR, 'functions.i'), 'r') as f: - content = f.read() - - if not osp.exists(self.MATLAB_ACTUAL_DIR): - os.mkdir(self.MATLAB_ACTUAL_DIR) + file = osp.join(self.INTERFACE_DIR, 'functions.i') wrapper = MatlabWrapper( module_name='functions', @@ -147,9 +95,7 @@ def test_functions(self): ignore_classes=[''], ) - cc_content = wrapper.wrap(content) - - self.generate_content(cc_content) + wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR) files = [ 'functions_wrapper.cpp', 'aGlobalFunction.m', 'load2D.m', @@ -163,11 +109,7 @@ def test_functions(self): def test_class(self): """Test interface file with only class info.""" - with open(osp.join(self.INTERFACE_DIR, 'class.i'), 'r') as f: - content = f.read() - - if not osp.exists(self.MATLAB_ACTUAL_DIR): - os.mkdir(self.MATLAB_ACTUAL_DIR) + file = osp.join(self.INTERFACE_DIR, 'class.i') wrapper = MatlabWrapper( module_name='class', @@ -175,9 +117,7 @@ def test_class(self): ignore_classes=[''], ) - cc_content = wrapper.wrap(content) - - self.generate_content(cc_content) + wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR) files = [ 'class_wrapper.cpp', 'FunDouble.m', 'FunRange.m', @@ -191,21 +131,14 @@ def test_class(self): def test_inheritance(self): """Test interface file with class inheritance definitions.""" - with open(osp.join(self.INTERFACE_DIR, 'inheritance.i'), 'r') as f: - content = f.read() - - if not osp.exists(self.MATLAB_ACTUAL_DIR): - os.mkdir(self.MATLAB_ACTUAL_DIR) + file = osp.join(self.INTERFACE_DIR, 'inheritance.i') wrapper = MatlabWrapper( module_name='inheritance', top_module_namespace=['gtsam'], ignore_classes=[''], ) - - cc_content = wrapper.wrap(content) - - self.generate_content(cc_content) + wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR) files = [ 'inheritance_wrapper.cpp', 'MyBase.m', 'MyTemplateMatrix.m', @@ -219,11 +152,7 @@ def test_namespaces(self): """ Test interface file with full namespace definition. """ - with open(osp.join(self.INTERFACE_DIR, 'namespaces.i'), 'r') as f: - content = f.read() - - if not osp.exists(self.MATLAB_ACTUAL_DIR): - os.mkdir(self.MATLAB_ACTUAL_DIR) + file = osp.join(self.INTERFACE_DIR, 'namespaces.i') wrapper = MatlabWrapper( module_name='namespaces', @@ -231,9 +160,7 @@ def test_namespaces(self): ignore_classes=[''], ) - cc_content = wrapper.wrap(content) - - self.generate_content(cc_content) + wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR) files = [ 'namespaces_wrapper.cpp', '+ns1/aGlobalFunction.m', @@ -249,21 +176,14 @@ def test_special_cases(self): """ Tests for some unique, non-trivial features. """ - with open(osp.join(self.INTERFACE_DIR, 'special_cases.i'), 'r') as f: - content = f.read() - - if not osp.exists(self.MATLAB_ACTUAL_DIR): - os.mkdir(self.MATLAB_ACTUAL_DIR) + file = osp.join(self.INTERFACE_DIR, 'special_cases.i') wrapper = MatlabWrapper( module_name='special_cases', top_module_namespace=['gtsam'], ignore_classes=[''], ) - - cc_content = wrapper.wrap(content) - - self.generate_content(cc_content) + wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR) files = [ 'special_cases_wrapper.cpp', @@ -274,6 +194,31 @@ def test_special_cases(self): for file in files: self.compare_and_diff(file) + def test_multiple_files(self): + """ + Test for when multiple interface files are specified. + """ + file1 = osp.join(self.INTERFACE_DIR, 'part1.i') + file2 = osp.join(self.INTERFACE_DIR, 'part2.i') + + wrapper = MatlabWrapper( + module_name='multiple_files', + top_module_namespace=['gtsam'], + ignore_classes=[''], + ) + + wrapper.wrap([file1, file2], path=self.MATLAB_ACTUAL_DIR) + + files = [ + 'multiple_files_wrapper.cpp', + '+gtsam/Class1.m', + '+gtsam/Class2.m', + '+gtsam/ClassA.m', + ] + + for file in files: + self.compare_and_diff(file) + if __name__ == '__main__': unittest.main() diff --git a/wrap/tests/test_pybind_wrapper.py b/wrap/tests/test_pybind_wrapper.py index 77c884b622..67c637d146 100644 --- a/wrap/tests/test_pybind_wrapper.py +++ b/wrap/tests/test_pybind_wrapper.py @@ -31,9 +31,9 @@ class TestWrap(unittest.TestCase): # Create the `actual/python` directory os.makedirs(PYTHON_ACTUAL_DIR, exist_ok=True) - def wrap_content(self, content, module_name, output_dir): + def wrap_content(self, sources, module_name, output_dir): """ - Common function to wrap content. + Common function to wrap content in `sources`. """ with open(osp.join(self.TEST_DIR, "pybind_wrapper.tpl")) as template_file: @@ -46,15 +46,12 @@ def wrap_content(self, content, module_name, output_dir): ignore_classes=[''], module_template=module_template) - cc_content = wrapper.wrap(content) - output = osp.join(self.TEST_DIR, output_dir, module_name + ".cpp") if not osp.exists(osp.join(self.TEST_DIR, output_dir)): os.mkdir(osp.join(self.TEST_DIR, output_dir)) - with open(output, 'w') as f: - f.write(cc_content) + wrapper.wrap(sources, output) return output @@ -76,39 +73,32 @@ def test_geometry(self): python3 ../pybind_wrapper.py --src geometry.h --module_name geometry_py --out output/geometry_py.cc """ - with open(osp.join(self.INTERFACE_DIR, 'geometry.i'), 'r') as f: - content = f.read() - - output = self.wrap_content(content, 'geometry_py', + source = osp.join(self.INTERFACE_DIR, 'geometry.i') + output = self.wrap_content([source], 'geometry_py', self.PYTHON_ACTUAL_DIR) self.compare_and_diff('geometry_pybind.cpp', output) def test_functions(self): """Test interface file with function info.""" - with open(osp.join(self.INTERFACE_DIR, 'functions.i'), 'r') as f: - content = f.read() - - output = self.wrap_content(content, 'functions_py', + source = osp.join(self.INTERFACE_DIR, 'functions.i') + output = self.wrap_content([source], 'functions_py', self.PYTHON_ACTUAL_DIR) self.compare_and_diff('functions_pybind.cpp', output) def test_class(self): """Test interface file with only class info.""" - with open(osp.join(self.INTERFACE_DIR, 'class.i'), 'r') as f: - content = f.read() - - output = self.wrap_content(content, 'class_py', self.PYTHON_ACTUAL_DIR) + source = osp.join(self.INTERFACE_DIR, 'class.i') + output = self.wrap_content([source], 'class_py', + self.PYTHON_ACTUAL_DIR) self.compare_and_diff('class_pybind.cpp', output) def test_inheritance(self): """Test interface file with class inheritance definitions.""" - with open(osp.join(self.INTERFACE_DIR, 'inheritance.i'), 'r') as f: - content = f.read() - - output = self.wrap_content(content, 'inheritance_py', + source = osp.join(self.INTERFACE_DIR, 'inheritance.i') + output = self.wrap_content([source], 'inheritance_py', self.PYTHON_ACTUAL_DIR) self.compare_and_diff('inheritance_pybind.cpp', output) @@ -119,10 +109,8 @@ def test_namespaces(self): python3 ../pybind_wrapper.py --src namespaces.i --module_name namespaces_py --out output/namespaces_py.cpp """ - with open(osp.join(self.INTERFACE_DIR, 'namespaces.i'), 'r') as f: - content = f.read() - - output = self.wrap_content(content, 'namespaces_py', + source = osp.join(self.INTERFACE_DIR, 'namespaces.i') + output = self.wrap_content([source], 'namespaces_py', self.PYTHON_ACTUAL_DIR) self.compare_and_diff('namespaces_pybind.cpp', output) @@ -131,10 +119,8 @@ def test_operator_overload(self): """ Tests for operator overloading. """ - with open(osp.join(self.INTERFACE_DIR, 'operator.i'), 'r') as f: - content = f.read() - - output = self.wrap_content(content, 'operator_py', + source = osp.join(self.INTERFACE_DIR, 'operator.i') + output = self.wrap_content([source], 'operator_py', self.PYTHON_ACTUAL_DIR) self.compare_and_diff('operator_pybind.cpp', output) @@ -143,10 +129,8 @@ def test_special_cases(self): """ Tests for some unique, non-trivial features. """ - with open(osp.join(self.INTERFACE_DIR, 'special_cases.i'), 'r') as f: - content = f.read() - - output = self.wrap_content(content, 'special_cases_py', + source = osp.join(self.INTERFACE_DIR, 'special_cases.i') + output = self.wrap_content([source], 'special_cases_py', self.PYTHON_ACTUAL_DIR) self.compare_and_diff('special_cases_pybind.cpp', output) @@ -155,10 +139,8 @@ def test_enum(self): """ Test if enum generation is correct. """ - with open(osp.join(self.INTERFACE_DIR, 'enum.i'), 'r') as f: - content = f.read() - - output = self.wrap_content(content, 'enum_py', self.PYTHON_ACTUAL_DIR) + source = osp.join(self.INTERFACE_DIR, 'enum.i') + output = self.wrap_content([source], 'enum_py', self.PYTHON_ACTUAL_DIR) self.compare_and_diff('enum_pybind.cpp', output)