Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Matlab Wrapper #953

Merged
merged 5 commits into from
Dec 7, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Squashed 'wrap/' changes from 0ab10c359..248971868
248971868 Merge pull request #132 from borglab/fix/matlab-wrapper
157fad9e5 fix where generation of wrapper files takes place
f2ad4e475 update tests and fixtures
65e230b0d fixes to get the matlab wrapper working

git-subtree-dir: wrap
git-subtree-split: 24897186873c92a32707ca8718f7e7b7dbffc589
  • Loading branch information
varunagrawal committed Dec 6, 2021
commit aa693b2e8f88d54e2dab1b40ef557525a155bb1c
25 changes: 12 additions & 13 deletions cmake/MatlabWrap.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ macro(find_and_configure_matlab)
endmacro()

# Consistent and user-friendly wrap function
function(matlab_wrap interfaceHeader linkLibraries
function(matlab_wrap interfaceHeader moduleName linkLibraries
extraIncludeDirs extraMexFlags ignore_classes)
find_and_configure_matlab()
wrap_and_install_library("${interfaceHeader}" "${linkLibraries}"
wrap_and_install_library("${interfaceHeader}" "${moduleName}" "${linkLibraries}"
"${extraIncludeDirs}" "${extraMexFlags}"
"${ignore_classes}")
endfunction()
Expand All @@ -77,6 +77,7 @@ endfunction()
# Arguments:
#
# interfaceHeader: The relative path to the wrapper interface definition file.
# moduleName: The name of the wrapped module, e.g. gtsam
# linkLibraries: Any *additional* libraries to link. Your project library
# (e.g. `lba`), libraries it depends on, and any necessary MATLAB libraries will
# be linked automatically. So normally, leave this empty.
Expand All @@ -85,15 +86,15 @@ endfunction()
# extraMexFlags: Any *additional* flags to pass to the compiler when building
# the wrap code. Normally, leave this empty.
# ignore_classes: List of classes to ignore in the wrapping.
function(wrap_and_install_library interfaceHeader linkLibraries
function(wrap_and_install_library interfaceHeader moduleName linkLibraries
extraIncludeDirs extraMexFlags ignore_classes)
wrap_library_internal("${interfaceHeader}" "${linkLibraries}"
wrap_library_internal("${interfaceHeader}" "${moduleName}" "${linkLibraries}"
"${extraIncludeDirs}" "${mexFlags}")
install_wrapped_library_internal("${interfaceHeader}")
install_wrapped_library_internal("${moduleName}")
endfunction()

# Internal function that wraps a library and compiles the wrapper
function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs
function(wrap_library_internal interfaceHeader moduleName linkLibraries extraIncludeDirs
extraMexFlags)
if(UNIX AND NOT APPLE)
if(CMAKE_SIZEOF_VOID_P EQUAL 8)
Expand All @@ -120,7 +121,6 @@ function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs
# Extract module name from interface header file name
get_filename_component(interfaceHeader "${interfaceHeader}" ABSOLUTE)
get_filename_component(modulePath "${interfaceHeader}" PATH)
get_filename_component(moduleName "${interfaceHeader}" NAME_WE)

# Paths for generated files
set(generated_files_path "${PROJECT_BINARY_DIR}/wrap/${moduleName}")
Expand All @@ -136,8 +136,7 @@ function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs
# explicit link libraries list so that the next block of code can unpack any
# static libraries
set(automaticDependencies "")
foreach(lib ${moduleName} ${linkLibraries})
# message("MODULE NAME: ${moduleName}")
foreach(lib ${module} ${linkLibraries})
if(TARGET "${lib}")
get_target_property(dependentLibraries ${lib} INTERFACE_LINK_LIBRARIES)
# message("DEPENDENT LIBRARIES: ${dependentLibraries}")
Expand Down Expand Up @@ -176,7 +175,7 @@ function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs
set(otherLibraryTargets "")
set(otherLibraryNontargets "")
set(otherSourcesAndObjects "")
foreach(lib ${moduleName} ${linkLibraries} ${automaticDependencies})
foreach(lib ${module} ${linkLibraries} ${automaticDependencies})
if(TARGET "${lib}")
if(WRAP_MEX_BUILD_STATIC_MODULE)
get_target_property(target_sources ${lib} SOURCES)
Expand Down Expand Up @@ -250,7 +249,7 @@ function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs
COMMAND
${CMAKE_COMMAND} -E env
"PYTHONPATH=${GTWRAP_PACKAGE_DIR}${GTWRAP_PATH_SEPARATOR}$ENV{PYTHONPATH}"
${PYTHON_EXECUTABLE} ${MATLAB_WRAP_SCRIPT} --src ${interfaceHeader}
${PYTHON_EXECUTABLE} ${MATLAB_WRAP_SCRIPT} --src "${interfaceHeader}"
--module_name ${moduleName} --out ${generated_files_path}
--top_module_namespaces ${moduleName} --ignore ${ignore_classes}
VERBATIM
Expand Down Expand Up @@ -324,8 +323,8 @@ endfunction()

# Internal function that installs a wrap toolbox
function(install_wrapped_library_internal interfaceHeader)
get_filename_component(moduleName "${interfaceHeader}" NAME_WE)
set(generated_files_path "${PROJECT_BINARY_DIR}/wrap/${moduleName}")
get_filename_component(module "${interfaceHeader}" NAME_WE)
set(generated_files_path "${PROJECT_BINARY_DIR}/wrap/${module}")

# NOTE: only installs .m and mex binary files (not .cpp) - the trailing slash
# on the directory name here prevents creating the top-level module name
Expand Down
8 changes: 6 additions & 2 deletions gtwrap/interface_parser/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def __init__(self,
self.name = t[-1] # the name is the last element in this list
self.namespaces = t[:-1]

# If the first namespace is empty string, just get rid of it.
if self.namespaces and self.namespaces[0] == '':
self.namespaces.pop(0)

if instantiations:
if isinstance(instantiations, Sequence):
self.instantiations = instantiations # type: ignore
Expand Down Expand Up @@ -92,8 +96,8 @@ def to_cpp(self) -> str:
else:
cpp_name = self.name
return '{}{}{}'.format(
"::".join(self.namespaces[idx:]),
"::" if self.namespaces[idx:] else "",
"::".join(self.namespaces),
"::" if self.namespaces else "",
cpp_name,
)

Expand Down
7 changes: 3 additions & 4 deletions gtwrap/matlab_wrapper/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _format_type_name(self,
elif is_method:
formatted_type_name += self.data_type_param.get(name) or name
else:
formatted_type_name += name
formatted_type_name += str(name)

if separator == "::": # C++
templates = []
Expand Down Expand Up @@ -192,10 +192,9 @@ def _format_static_method(self,
method = ''

if isinstance(static_method, parser.StaticMethod):
method += "".join([separator + x for x in static_method.parent.namespaces()]) + \
separator + static_method.parent.name + separator
method += static_method.parent.to_cpp() + separator

return method[2 * len(separator):]
return method

def _format_global_function(self,
function: Union[parser.GlobalFunction, Any],
Expand Down
50 changes: 27 additions & 23 deletions gtwrap/matlab_wrapper/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,18 +239,18 @@ def _wrap_list_variable_arguments(self, args):

return var_list_wrap

def _wrap_method_check_statement(self, args):
def _wrap_method_check_statement(self, args: parser.ArgumentList):
"""
Wrap the given arguments into either just a varargout call or a
call in an if statement that checks if the parameters are accurate.

TODO Update this method so that default arguments are supported.
"""
check_statement = ''
arg_id = 1

if check_statement == '':
check_statement = \
'if length(varargin) == {param_count}'.format(
param_count=len(args.list()))
param_count = len(args)
check_statement = 'if length(varargin) == {param_count}'.format(
param_count=param_count)

for _, arg in enumerate(args.list()):
name = arg.ctype.typename.name
Expand Down Expand Up @@ -809,7 +809,7 @@ def wrap_static_methods(self, namespace_name, instantiated_class,

for static_method in static_methods:
format_name = list(static_method[0].name)
format_name[0] = format_name[0].upper()
format_name[0] = format_name[0]

if static_method[0].name in self.ignore_methods:
continue
Expand Down Expand Up @@ -850,12 +850,13 @@ def wrap_static_methods(self, namespace_name, instantiated_class,
wrapper=self._wrapper_name(),
id=self._update_wrapper_id(
(namespace_name, instantiated_class,
static_overload.name, static_overload)),
static_overload.name, static_overload)),
class_name=instantiated_class.name,
end_statement=end_statement),
prefix=' ')
prefix=' ')

#TODO Figure out what is static_overload doing here.
# If the arguments don't match any of the checks above,
# throw an error with the class and method name.
method_text += textwrap.indent(textwrap.dedent("""\
error('Arguments do not match any overload of function {class_name}.{method_name}');
""".format(class_name=class_name,
Expand Down Expand Up @@ -1081,7 +1082,6 @@ def wrap_collector_function_return(self, method):
obj_start = ''

if isinstance(method, instantiator.InstantiatedMethod):
# method_name = method.original.name
method_name = method.to_cpp()
obj_start = 'obj->'

Expand All @@ -1090,6 +1090,10 @@ def wrap_collector_function_return(self, method):
# self._format_type_name(method.instantiations))
method = method.to_cpp()

elif isinstance(method, instantiator.InstantiatedStaticMethod):
method_name = self._format_static_method(method, '::')
method_name += method.original.name

elif isinstance(method, parser.GlobalFunction):
method_name = self._format_global_function(method, '::')
method_name += method.name
Expand Down Expand Up @@ -1250,7 +1254,7 @@ def generate_collector_function(self, func_id):
method_name = ''

if is_static_method:
method_name = self._format_static_method(extra) + '.'
method_name = self._format_static_method(extra, '.')

method_name += extra.name

Expand Down Expand Up @@ -1567,23 +1571,23 @@ def generate_content(self, cc_content, path):

def wrap(self, files, path):
"""High level function to wrap the project."""
content = ""
modules = {}
for file in files:
with open(file, 'r') as f:
content = f.read()
content += f.read()

# Parse the contents of the interface file
parsed_result = parser.Module.parseString(content)
# print(parsed_result)
# Parse the contents of the interface file
parsed_result = parser.Module.parseString(content)

# Instantiate the module
module = instantiator.instantiate_namespace(parsed_result)
# Instantiate the module
module = instantiator.instantiate_namespace(parsed_result)

if module.name in modules:
modules[module.
name].content[0].content += module.content[0].content
else:
modules[module.name] = module
if module.name in modules:
modules[
module.name].content[0].content += module.content[0].content
else:
modules[module.name] = module

for module in modules.values():
# Wrap the full namespace
Expand Down
23 changes: 15 additions & 8 deletions gtwrap/template_instantiator/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,14 @@ def instantiate_type(
# make a deep copy so that there is no overwriting of original template params
ctype = deepcopy(ctype)

# Check if the return type has template parameters
# Check if the return type has template parameters as the typename's name
if ctype.typename.instantiations:
for idx, instantiation in enumerate(ctype.typename.instantiations):
if instantiation.name in template_typenames:
template_idx = template_typenames.index(instantiation.name)
ctype.typename.instantiations[
idx] = instantiations[ # type: ignore
template_idx]
ctype.typename.instantiations[idx].name =\
instantiations[template_idx]

return ctype

str_arg_typename = str(ctype.typename)

Expand Down Expand Up @@ -125,9 +123,18 @@ def instantiate_type(

# Case when 'This' is present in the type namespace, e.g `This::Subclass`.
elif 'This' in str_arg_typename:
# Simply get the index of `This` in the namespace and replace it with the instantiated name.
namespace_idx = ctype.typename.namespaces.index('This')
ctype.typename.namespaces[namespace_idx] = cpp_typename.name
# Check if `This` is in the namespaces
if 'This' in ctype.typename.namespaces:
# Simply get the index of `This` in the namespace and
# replace it with the instantiated name.
namespace_idx = ctype.typename.namespaces.index('This')
ctype.typename.namespaces[namespace_idx] = cpp_typename.name
# Else check if it is in the template namespace, e.g vector<This::Value>
else:
for idx, instantiation in enumerate(ctype.typename.instantiations):
if 'This' in instantiation.namespaces:
ctype.typename.instantiations[idx].namespaces = \
cpp_typename.namespaces + [cpp_typename.name]
return ctype

else:
Expand Down
31 changes: 31 additions & 0 deletions tests/expected/matlab/+gtsam/GeneralSFMFactorCal3Bundler.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
%class GeneralSFMFactorCal3Bundler, see Doxygen page for details
%at https://gtsam.org/doxygen/
%
classdef GeneralSFMFactorCal3Bundler < handle
properties
ptr_gtsamGeneralSFMFactorCal3Bundler = 0
end
methods
function obj = GeneralSFMFactorCal3Bundler(varargin)
if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682)
my_ptr = varargin{2};
special_cases_wrapper(7, my_ptr);
else
error('Arguments do not match any overload of gtsam.GeneralSFMFactorCal3Bundler constructor');
end
obj.ptr_gtsamGeneralSFMFactorCal3Bundler = my_ptr;
end

function delete(obj)
special_cases_wrapper(8, obj.ptr_gtsamGeneralSFMFactorCal3Bundler);
end

function display(obj), obj.print(''); end
%DISPLAY Calls print on the object
function disp(obj), obj.display; end
%DISP Calls print on the object
end

methods(Static = true)
end
end
2 changes: 1 addition & 1 deletion tests/expected/matlab/+gtsam/Point3.m
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ function delete(obj)
error('Arguments do not match any overload of function Point3.StaticFunctionRet');
end

function varargout = StaticFunction(varargin)
function varargout = staticFunction(varargin)
% STATICFUNCTION usage: staticFunction() : returns double
% Doxygen can be found at https://gtsam.org/doxygen/
if length(varargin) == 0
Expand Down
31 changes: 31 additions & 0 deletions tests/expected/matlab/+gtsam/SfmTrack.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
%class SfmTrack, see Doxygen page for details
%at https://gtsam.org/doxygen/
%
classdef SfmTrack < handle
properties
ptr_gtsamSfmTrack = 0
end
methods
function obj = SfmTrack(varargin)
if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682)
my_ptr = varargin{2};
special_cases_wrapper(3, my_ptr);
else
error('Arguments do not match any overload of gtsam.SfmTrack constructor');
end
obj.ptr_gtsamSfmTrack = my_ptr;
end

function delete(obj)
special_cases_wrapper(4, obj.ptr_gtsamSfmTrack);
end

function display(obj), obj.print(''); end
%DISPLAY Calls print on the object
function disp(obj), obj.display; end
%DISP Calls print on the object
end

methods(Static = true)
end
end
Loading