Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Matlab default arguments #135

Merged
merged 7 commits into from
Dec 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 59 additions & 24 deletions gtwrap/matlab_wrapper/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import textwrap
from functools import partial, reduce
from typing import Dict, Iterable, List, Union
import copy

import gtwrap.interface_parser as parser
from gtwrap.interface_parser.function import ArgumentList
import gtwrap.template_instantiator as instantiator
from gtwrap.matlab_wrapper.mixins import CheckMixin, FormatMixin
from gtwrap.matlab_wrapper.templates import WrapperTemplate
Expand Down Expand Up @@ -137,6 +139,37 @@ def _insert_spaces(self, x, y):
"""
return x + '\n' + ('' if y == '' else ' ') + y

@staticmethod
def _expand_default_arguments(method, save_backup=True):
gchenfc marked this conversation as resolved.
Show resolved Hide resolved
"""Recursively expand all possibilities for optional default arguments.
We create "overload" functions with fewer arguments, but since we have to "remember" what
the default arguments are for later, we make a backup.
"""
def args_copy(args):
return ArgumentList([copy.copy(arg) for arg in args.list()])
def method_copy(method):
method2 = copy.copy(method)
method2.args = args_copy(method.args)
method2.args.backup = method.args.backup
return method2
if save_backup:
method.args.backup = args_copy(method.args)
method = method_copy(method)
for arg in reversed(method.args.list()):
if arg.default is not None:
arg.default = None
methodWithArg = method_copy(method)
method.args.list().remove(arg)
return [
methodWithArg,
*MatlabWrapper._expand_default_arguments(method, save_backup=False)
]
break
assert all(arg.default is None for arg in method.args.list()), \
'In parsing method {:}: Arguments with default values cannot appear before ones ' \
'without default values.'.format(method.name)
return [method]

def _group_methods(self, methods):
"""Group overloaded methods together"""
method_map = {}
Expand All @@ -147,9 +180,9 @@ def _group_methods(self, methods):

if method_index is None:
method_map[method.name] = len(method_out)
method_out.append([method])
method_out.append(MatlabWrapper._expand_default_arguments(method))
else:
method_out[method_index].append(method)
method_out[method_index] += MatlabWrapper._expand_default_arguments(method)

return method_out

Expand Down Expand Up @@ -301,13 +334,9 @@ def _wrapper_unwrap_arguments(self, args, arg_id=0, constructor=False):
((a), Test& t = *unwrap_shared_ptr< Test >(in[1], "ptr_Test");),
((a), std::shared_ptr<Test> p1 = unwrap_shared_ptr< Test >(in[1], "ptr_Test");)
"""
params = ''
body_args = ''

for arg in args.list():
if params != '':
params += ','

if self.is_ref(arg.ctype): # and not constructor:
ctype_camel = self._format_type_name(arg.ctype.typename,
separator='')
Expand Down Expand Up @@ -336,8 +365,6 @@ def _wrapper_unwrap_arguments(self, args, arg_id=0, constructor=False):
name=arg.name,
id=arg_id)),
prefix=' ')
if call_type == "":
params += "*"

else:
body_args += textwrap.indent(textwrap.dedent('''\
Expand All @@ -347,10 +374,29 @@ def _wrapper_unwrap_arguments(self, args, arg_id=0, constructor=False):
id=arg_id)),
prefix=' ')

params += arg.name

arg_id += 1

params = ''
explicit_arg_names = [arg.name for arg in args.list()]
# when returning the params list, we need to re-include the default args.
for arg in args.backup.list():
if params != '':
params += ','

if (arg.default is not None) and (arg.name not in explicit_arg_names):
params += arg.default
continue

if (not self.is_ref(arg.ctype)) and (self.is_shared_ptr(arg.ctype)) and (self.is_ptr(
arg.ctype)) and (arg.ctype.typename.name not in self.ignore_namespace):
if arg.ctype.is_shared_ptr:
call_type = arg.ctype.is_shared_ptr
else:
call_type = arg.ctype.is_ptr
if call_type == "":
params += "*"
params += arg.name

return params, body_args

@staticmethod
Expand Down Expand Up @@ -555,6 +601,8 @@ def wrap_class_constructors(self, namespace_name, inst_class, parent_name,
if not isinstance(ctors, Iterable):
ctors = [ctors]

ctors = sum((MatlabWrapper._expand_default_arguments(ctor) for ctor in ctors), [])

methods_wrap = textwrap.indent(textwrap.dedent("""\
methods
function obj = {class_name}(varargin)
Expand Down Expand Up @@ -674,20 +722,7 @@ def wrap_class_display(self):

def _group_class_methods(self, methods):
"""Group overloaded methods together"""
method_map = {}
method_out = []

for method in methods:
method_index = method_map.get(method.name)

if method_index is None:
method_map[method.name] = len(method_out)
method_out.append([method])
else:
# print("[_group_methods] Merging {} with {}".format(method_index, method.name))
method_out[method_index].append(method)

return method_out
return self._group_methods(methods)

@classmethod
def _format_varargout(cls, return_type, return_type_formatted):
Expand Down
2 changes: 2 additions & 0 deletions tests/actual/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
./*
!.gitignore
4 changes: 4 additions & 0 deletions tests/expected/matlab/DefaultFuncInt.m
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
function varargout = DefaultFuncInt(varargin)
if length(varargin) == 2 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric')
functions_wrapper(8, varargin{:});
elseif length(varargin) == 1 && isa(varargin{1},'numeric')
functions_wrapper(9, varargin{:});
elseif length(varargin) == 0
functions_wrapper(10, varargin{:});
else
error('Arguments do not match any overload of function DefaultFuncInt');
end
4 changes: 3 additions & 1 deletion tests/expected/matlab/DefaultFuncObj.m
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
function varargout = DefaultFuncObj(varargin)
if length(varargin) == 1 && isa(varargin{1},'gtsam.KeyFormatter')
functions_wrapper(10, varargin{:});
functions_wrapper(14, varargin{:});
elseif length(varargin) == 0
functions_wrapper(15, varargin{:});
else
error('Arguments do not match any overload of function DefaultFuncObj');
end
6 changes: 5 additions & 1 deletion tests/expected/matlab/DefaultFuncString.m
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
function varargout = DefaultFuncString(varargin)
if length(varargin) == 2 && isa(varargin{1},'char') && isa(varargin{2},'char')
functions_wrapper(9, varargin{:});
functions_wrapper(11, varargin{:});
elseif length(varargin) == 1 && isa(varargin{1},'char')
functions_wrapper(12, varargin{:});
elseif length(varargin) == 0
functions_wrapper(13, varargin{:});
else
error('Arguments do not match any overload of function DefaultFuncString');
end
6 changes: 5 additions & 1 deletion tests/expected/matlab/DefaultFuncVector.m
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
function varargout = DefaultFuncVector(varargin)
if length(varargin) == 2 && isa(varargin{1},'std.vectornumeric') && isa(varargin{2},'std.vectorchar')
functions_wrapper(12, varargin{:});
functions_wrapper(20, varargin{:});
elseif length(varargin) == 1 && isa(varargin{1},'std.vectornumeric')
functions_wrapper(21, varargin{:});
elseif length(varargin) == 0
functions_wrapper(22, varargin{:});
else
error('Arguments do not match any overload of function DefaultFuncVector');
end
10 changes: 8 additions & 2 deletions tests/expected/matlab/DefaultFuncZero.m
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
function varargout = DefaultFuncZero(varargin)
if length(varargin) == 5 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric') && isa(varargin{3},'double') && isa(varargin{4},'logical') && isa(varargin{5},'logical')
functions_wrapper(11, varargin{:});
if length(varargin) == 5 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric') && isa(varargin{3},'double') && isa(varargin{4},'numeric') && isa(varargin{5},'logical')
functions_wrapper(16, varargin{:});
elseif length(varargin) == 4 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric') && isa(varargin{3},'double') && isa(varargin{4},'numeric')
functions_wrapper(17, varargin{:});
elseif length(varargin) == 3 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric') && isa(varargin{3},'double')
functions_wrapper(18, varargin{:});
elseif length(varargin) == 2 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric')
functions_wrapper(19, varargin{:});
else
error('Arguments do not match any overload of function DefaultFuncZero');
end
4 changes: 3 additions & 1 deletion tests/expected/matlab/ForwardKinematics.m
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@
class_wrapper(55, my_ptr);
elseif nargin == 5 && isa(varargin{1},'gtdynamics.Robot') && isa(varargin{2},'char') && isa(varargin{3},'char') && isa(varargin{4},'gtsam.Values') && isa(varargin{5},'gtsam.Pose3')
my_ptr = class_wrapper(56, varargin{1}, varargin{2}, varargin{3}, varargin{4}, varargin{5});
elseif nargin == 4 && isa(varargin{1},'gtdynamics.Robot') && isa(varargin{2},'char') && isa(varargin{3},'char') && isa(varargin{4},'gtsam.Values')
my_ptr = class_wrapper(57, varargin{1}, varargin{2}, varargin{3}, varargin{4});
else
error('Arguments do not match any overload of ForwardKinematics constructor');
end
obj.ptr_ForwardKinematics = my_ptr;
end

function delete(obj)
class_wrapper(57, obj.ptr_ForwardKinematics);
class_wrapper(58, obj.ptr_ForwardKinematics);
end

function display(obj), obj.print(''); end
Expand Down
20 changes: 16 additions & 4 deletions tests/expected/matlab/MyFactorPosePoint2.m
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@
function obj = MyFactorPosePoint2(varargin)
if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682)
my_ptr = varargin{2};
class_wrapper(64, my_ptr);
class_wrapper(65, my_ptr);
elseif nargin == 4 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric') && isa(varargin{3},'double') && isa(varargin{4},'gtsam.noiseModel.Base')
my_ptr = class_wrapper(65, varargin{1}, varargin{2}, varargin{3}, varargin{4});
my_ptr = class_wrapper(66, varargin{1}, varargin{2}, varargin{3}, varargin{4});
else
error('Arguments do not match any overload of MyFactorPosePoint2 constructor');
end
obj.ptr_MyFactorPosePoint2 = my_ptr;
end

function delete(obj)
class_wrapper(66, obj.ptr_MyFactorPosePoint2);
class_wrapper(67, obj.ptr_MyFactorPosePoint2);
end

function display(obj), obj.print(''); end
Expand All @@ -36,7 +36,19 @@ function delete(obj)
% PRINT usage: print(string s, KeyFormatter keyFormatter) : returns void
% Doxygen can be found at https://gtsam.org/doxygen/
if length(varargin) == 2 && isa(varargin{1},'char') && isa(varargin{2},'gtsam.KeyFormatter')
class_wrapper(67, this, varargin{:});
class_wrapper(68, this, varargin{:});
return
end
% PRINT usage: print(string s) : returns void
% Doxygen can be found at https://gtsam.org/doxygen/
if length(varargin) == 1 && isa(varargin{1},'char')
class_wrapper(69, this, varargin{:});
return
end
% PRINT usage: print() : returns void
% Doxygen can be found at https://gtsam.org/doxygen/
if length(varargin) == 0
class_wrapper(70, this, varargin{:});
return
end
error('Arguments do not match any overload of function MyFactorPosePoint2.print');
Expand Down
2 changes: 1 addition & 1 deletion tests/expected/matlab/TemplatedFunctionRot3.m
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
function varargout = TemplatedFunctionRot3(varargin)
if length(varargin) == 1 && isa(varargin{1},'gtsam.Rot3')
functions_wrapper(14, varargin{:});
functions_wrapper(25, varargin{:});
else
error('Arguments do not match any overload of function TemplatedFunctionRot3');
end
Loading