Skip to content

Commit

Permalink
Squashed 'wrap/' changes from 0ab10c359..248971868
Browse files Browse the repository at this point in the history
248971868 Merge pull request #132 from borglab/fix/matlab-wrapper
157fad9e5 fix where generation of wrapper files takes place
f2ad4e475 update tests and fixtures
65e230b0d fixes to get the matlab wrapper working

git-subtree-dir: wrap
git-subtree-split: 24897186873c92a32707ca8718f7e7b7dbffc589
  • Loading branch information
varunagrawal committed Dec 6, 2021
1 parent c195bb5 commit aa693b2
Show file tree
Hide file tree
Showing 36 changed files with 617 additions and 239 deletions.
25 changes: 12 additions & 13 deletions cmake/MatlabWrap.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ macro(find_and_configure_matlab)
endmacro()

# Consistent and user-friendly wrap function
function(matlab_wrap interfaceHeader linkLibraries
function(matlab_wrap interfaceHeader moduleName linkLibraries
extraIncludeDirs extraMexFlags ignore_classes)
find_and_configure_matlab()
wrap_and_install_library("${interfaceHeader}" "${linkLibraries}"
wrap_and_install_library("${interfaceHeader}" "${moduleName}" "${linkLibraries}"
"${extraIncludeDirs}" "${extraMexFlags}"
"${ignore_classes}")
endfunction()
Expand All @@ -77,6 +77,7 @@ endfunction()
# Arguments:
#
# interfaceHeader: The relative path to the wrapper interface definition file.
# moduleName: The name of the wrapped module, e.g. gtsam
# linkLibraries: Any *additional* libraries to link. Your project library
# (e.g. `lba`), libraries it depends on, and any necessary MATLAB libraries will
# be linked automatically. So normally, leave this empty.
Expand All @@ -85,15 +86,15 @@ endfunction()
# extraMexFlags: Any *additional* flags to pass to the compiler when building
# the wrap code. Normally, leave this empty.
# ignore_classes: List of classes to ignore in the wrapping.
function(wrap_and_install_library interfaceHeader linkLibraries
function(wrap_and_install_library interfaceHeader moduleName linkLibraries
extraIncludeDirs extraMexFlags ignore_classes)
wrap_library_internal("${interfaceHeader}" "${linkLibraries}"
wrap_library_internal("${interfaceHeader}" "${moduleName}" "${linkLibraries}"
"${extraIncludeDirs}" "${mexFlags}")
install_wrapped_library_internal("${interfaceHeader}")
install_wrapped_library_internal("${moduleName}")
endfunction()

# Internal function that wraps a library and compiles the wrapper
function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs
function(wrap_library_internal interfaceHeader moduleName linkLibraries extraIncludeDirs
extraMexFlags)
if(UNIX AND NOT APPLE)
if(CMAKE_SIZEOF_VOID_P EQUAL 8)
Expand All @@ -120,7 +121,6 @@ function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs
# Extract module name from interface header file name
get_filename_component(interfaceHeader "${interfaceHeader}" ABSOLUTE)
get_filename_component(modulePath "${interfaceHeader}" PATH)
get_filename_component(moduleName "${interfaceHeader}" NAME_WE)

# Paths for generated files
set(generated_files_path "${PROJECT_BINARY_DIR}/wrap/${moduleName}")
Expand All @@ -136,8 +136,7 @@ function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs
# explicit link libraries list so that the next block of code can unpack any
# static libraries
set(automaticDependencies "")
foreach(lib ${moduleName} ${linkLibraries})
# message("MODULE NAME: ${moduleName}")
foreach(lib ${module} ${linkLibraries})
if(TARGET "${lib}")
get_target_property(dependentLibraries ${lib} INTERFACE_LINK_LIBRARIES)
# message("DEPENDENT LIBRARIES: ${dependentLibraries}")
Expand Down Expand Up @@ -176,7 +175,7 @@ function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs
set(otherLibraryTargets "")
set(otherLibraryNontargets "")
set(otherSourcesAndObjects "")
foreach(lib ${moduleName} ${linkLibraries} ${automaticDependencies})
foreach(lib ${module} ${linkLibraries} ${automaticDependencies})
if(TARGET "${lib}")
if(WRAP_MEX_BUILD_STATIC_MODULE)
get_target_property(target_sources ${lib} SOURCES)
Expand Down Expand Up @@ -250,7 +249,7 @@ function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs
COMMAND
${CMAKE_COMMAND} -E env
"PYTHONPATH=${GTWRAP_PACKAGE_DIR}${GTWRAP_PATH_SEPARATOR}$ENV{PYTHONPATH}"
${PYTHON_EXECUTABLE} ${MATLAB_WRAP_SCRIPT} --src ${interfaceHeader}
${PYTHON_EXECUTABLE} ${MATLAB_WRAP_SCRIPT} --src "${interfaceHeader}"
--module_name ${moduleName} --out ${generated_files_path}
--top_module_namespaces ${moduleName} --ignore ${ignore_classes}
VERBATIM
Expand Down Expand Up @@ -324,8 +323,8 @@ endfunction()

# Internal function that installs a wrap toolbox
function(install_wrapped_library_internal interfaceHeader)
get_filename_component(moduleName "${interfaceHeader}" NAME_WE)
set(generated_files_path "${PROJECT_BINARY_DIR}/wrap/${moduleName}")
get_filename_component(module "${interfaceHeader}" NAME_WE)
set(generated_files_path "${PROJECT_BINARY_DIR}/wrap/${module}")

# NOTE: only installs .m and mex binary files (not .cpp) - the trailing slash
# on the directory name here prevents creating the top-level module name
Expand Down
8 changes: 6 additions & 2 deletions gtwrap/interface_parser/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def __init__(self,
self.name = t[-1] # the name is the last element in this list
self.namespaces = t[:-1]

# If the first namespace is empty string, just get rid of it.
if self.namespaces and self.namespaces[0] == '':
self.namespaces.pop(0)

if instantiations:
if isinstance(instantiations, Sequence):
self.instantiations = instantiations # type: ignore
Expand Down Expand Up @@ -92,8 +96,8 @@ def to_cpp(self) -> str:
else:
cpp_name = self.name
return '{}{}{}'.format(
"::".join(self.namespaces[idx:]),
"::" if self.namespaces[idx:] else "",
"::".join(self.namespaces),
"::" if self.namespaces else "",
cpp_name,
)

Expand Down
7 changes: 3 additions & 4 deletions gtwrap/matlab_wrapper/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _format_type_name(self,
elif is_method:
formatted_type_name += self.data_type_param.get(name) or name
else:
formatted_type_name += name
formatted_type_name += str(name)

if separator == "::": # C++
templates = []
Expand Down Expand Up @@ -192,10 +192,9 @@ def _format_static_method(self,
method = ''

if isinstance(static_method, parser.StaticMethod):
method += "".join([separator + x for x in static_method.parent.namespaces()]) + \
separator + static_method.parent.name + separator
method += static_method.parent.to_cpp() + separator

return method[2 * len(separator):]
return method

def _format_global_function(self,
function: Union[parser.GlobalFunction, Any],
Expand Down
50 changes: 27 additions & 23 deletions gtwrap/matlab_wrapper/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,18 +239,18 @@ def _wrap_list_variable_arguments(self, args):

return var_list_wrap

def _wrap_method_check_statement(self, args):
def _wrap_method_check_statement(self, args: parser.ArgumentList):
"""
Wrap the given arguments into either just a varargout call or a
call in an if statement that checks if the parameters are accurate.
TODO Update this method so that default arguments are supported.
"""
check_statement = ''
arg_id = 1

if check_statement == '':
check_statement = \
'if length(varargin) == {param_count}'.format(
param_count=len(args.list()))
param_count = len(args)
check_statement = 'if length(varargin) == {param_count}'.format(
param_count=param_count)

for _, arg in enumerate(args.list()):
name = arg.ctype.typename.name
Expand Down Expand Up @@ -809,7 +809,7 @@ def wrap_static_methods(self, namespace_name, instantiated_class,

for static_method in static_methods:
format_name = list(static_method[0].name)
format_name[0] = format_name[0].upper()
format_name[0] = format_name[0]

if static_method[0].name in self.ignore_methods:
continue
Expand Down Expand Up @@ -850,12 +850,13 @@ def wrap_static_methods(self, namespace_name, instantiated_class,
wrapper=self._wrapper_name(),
id=self._update_wrapper_id(
(namespace_name, instantiated_class,
static_overload.name, static_overload)),
static_overload.name, static_overload)),
class_name=instantiated_class.name,
end_statement=end_statement),
prefix=' ')
prefix=' ')

#TODO Figure out what is static_overload doing here.
# If the arguments don't match any of the checks above,
# throw an error with the class and method name.
method_text += textwrap.indent(textwrap.dedent("""\
error('Arguments do not match any overload of function {class_name}.{method_name}');
""".format(class_name=class_name,
Expand Down Expand Up @@ -1081,7 +1082,6 @@ def wrap_collector_function_return(self, method):
obj_start = ''

if isinstance(method, instantiator.InstantiatedMethod):
# method_name = method.original.name
method_name = method.to_cpp()
obj_start = 'obj->'

Expand All @@ -1090,6 +1090,10 @@ def wrap_collector_function_return(self, method):
# self._format_type_name(method.instantiations))
method = method.to_cpp()

elif isinstance(method, instantiator.InstantiatedStaticMethod):
method_name = self._format_static_method(method, '::')
method_name += method.original.name

elif isinstance(method, parser.GlobalFunction):
method_name = self._format_global_function(method, '::')
method_name += method.name
Expand Down Expand Up @@ -1250,7 +1254,7 @@ def generate_collector_function(self, func_id):
method_name = ''

if is_static_method:
method_name = self._format_static_method(extra) + '.'
method_name = self._format_static_method(extra, '.')

method_name += extra.name

Expand Down Expand Up @@ -1567,23 +1571,23 @@ def generate_content(self, cc_content, path):

def wrap(self, files, path):
"""High level function to wrap the project."""
content = ""
modules = {}
for file in files:
with open(file, 'r') as f:
content = f.read()
content += f.read()

# Parse the contents of the interface file
parsed_result = parser.Module.parseString(content)
# print(parsed_result)
# Parse the contents of the interface file
parsed_result = parser.Module.parseString(content)

# Instantiate the module
module = instantiator.instantiate_namespace(parsed_result)
# Instantiate the module
module = instantiator.instantiate_namespace(parsed_result)

if module.name in modules:
modules[module.
name].content[0].content += module.content[0].content
else:
modules[module.name] = module
if module.name in modules:
modules[
module.name].content[0].content += module.content[0].content
else:
modules[module.name] = module

for module in modules.values():
# Wrap the full namespace
Expand Down
23 changes: 15 additions & 8 deletions gtwrap/template_instantiator/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,14 @@ def instantiate_type(
# make a deep copy so that there is no overwriting of original template params
ctype = deepcopy(ctype)

# Check if the return type has template parameters
# Check if the return type has template parameters as the typename's name
if ctype.typename.instantiations:
for idx, instantiation in enumerate(ctype.typename.instantiations):
if instantiation.name in template_typenames:
template_idx = template_typenames.index(instantiation.name)
ctype.typename.instantiations[
idx] = instantiations[ # type: ignore
template_idx]
ctype.typename.instantiations[idx].name =\
instantiations[template_idx]

return ctype

str_arg_typename = str(ctype.typename)

Expand Down Expand Up @@ -125,9 +123,18 @@ def instantiate_type(

# Case when 'This' is present in the type namespace, e.g `This::Subclass`.
elif 'This' in str_arg_typename:
# Simply get the index of `This` in the namespace and replace it with the instantiated name.
namespace_idx = ctype.typename.namespaces.index('This')
ctype.typename.namespaces[namespace_idx] = cpp_typename.name
# Check if `This` is in the namespaces
if 'This' in ctype.typename.namespaces:
# Simply get the index of `This` in the namespace and
# replace it with the instantiated name.
namespace_idx = ctype.typename.namespaces.index('This')
ctype.typename.namespaces[namespace_idx] = cpp_typename.name
# Else check if it is in the template namespace, e.g vector<This::Value>
else:
for idx, instantiation in enumerate(ctype.typename.instantiations):
if 'This' in instantiation.namespaces:
ctype.typename.instantiations[idx].namespaces = \
cpp_typename.namespaces + [cpp_typename.name]
return ctype

else:
Expand Down
31 changes: 31 additions & 0 deletions tests/expected/matlab/+gtsam/GeneralSFMFactorCal3Bundler.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
%class GeneralSFMFactorCal3Bundler, see Doxygen page for details
%at https://gtsam.org/doxygen/
%
classdef GeneralSFMFactorCal3Bundler < handle
properties
ptr_gtsamGeneralSFMFactorCal3Bundler = 0
end
methods
function obj = GeneralSFMFactorCal3Bundler(varargin)
if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682)
my_ptr = varargin{2};
special_cases_wrapper(7, my_ptr);
else
error('Arguments do not match any overload of gtsam.GeneralSFMFactorCal3Bundler constructor');
end
obj.ptr_gtsamGeneralSFMFactorCal3Bundler = my_ptr;
end

function delete(obj)
special_cases_wrapper(8, obj.ptr_gtsamGeneralSFMFactorCal3Bundler);
end

function display(obj), obj.print(''); end
%DISPLAY Calls print on the object
function disp(obj), obj.display; end
%DISP Calls print on the object
end

methods(Static = true)
end
end
2 changes: 1 addition & 1 deletion tests/expected/matlab/+gtsam/Point3.m
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ function delete(obj)
error('Arguments do not match any overload of function Point3.StaticFunctionRet');
end

function varargout = StaticFunction(varargin)
function varargout = staticFunction(varargin)
% STATICFUNCTION usage: staticFunction() : returns double
% Doxygen can be found at https://gtsam.org/doxygen/
if length(varargin) == 0
Expand Down
31 changes: 31 additions & 0 deletions tests/expected/matlab/+gtsam/SfmTrack.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
%class SfmTrack, see Doxygen page for details
%at https://gtsam.org/doxygen/
%
classdef SfmTrack < handle
properties
ptr_gtsamSfmTrack = 0
end
methods
function obj = SfmTrack(varargin)
if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682)
my_ptr = varargin{2};
special_cases_wrapper(3, my_ptr);
else
error('Arguments do not match any overload of gtsam.SfmTrack constructor');
end
obj.ptr_gtsamSfmTrack = my_ptr;
end

function delete(obj)
special_cases_wrapper(4, obj.ptr_gtsamSfmTrack);
end

function display(obj), obj.print(''); end
%DISPLAY Calls print on the object
function disp(obj), obj.display; end
%DISP Calls print on the object
end

methods(Static = true)
end
end
Loading

0 comments on commit aa693b2

Please sign in to comment.