Skip to content

Commit

Permalink
Allow keyword args in long_arg options
Browse files Browse the repository at this point in the history
  • Loading branch information
apaszke authored and soumith committed Jul 20, 2017
1 parent 4af40e3 commit e708de3
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
2 changes: 1 addition & 1 deletion test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2296,7 +2296,7 @@ def test_RNGStateAliasing(self):
target_value = torch.rand(1000)
# Dramatically alter the internal state of the main generator
_ = torch.rand(100000)
forked_value = torch.rand(gen, 1000)
forked_value = torch.rand(1000, generator=gen)
self.assertEqual(target_value, forked_value, 0, "RNG has not forked correctly.")

def test_boxMullerState(self):
Expand Down
10 changes: 4 additions & 6 deletions tools/cwrap/plugins/KwargsPlugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,23 @@ def process_declarations(self, declarations):
for option in declaration['options']:
offset = 0
for arg in option['arguments']:
if arg.get('kwarg_only') and not arg.get('ignore_check', False):
offset += 1
else:
arg['kwarg_offset'] = offset
if arg.get('kwarg_only'):
arg['no_idx'] = True
return declarations

def get_arg_accessor(self, arg, option):
if arg.get('no_kwargs'):
return
if arg.get('kwarg_only'):
return self.KWARG_ONLY_ACCESSOR_TEMPLATE.substitute(name=arg['name'])
return self.ACCESSOR_TEMPLATE.substitute(idx=arg['idx'] - arg['kwarg_offset'], name=arg['name'])
return self.ACCESSOR_TEMPLATE.substitute(idx=arg['idx'], name=arg['name'])

def process_single_check(self, code, arg, arg_accessor):
if arg.get('no_kwargs'):
return code
if arg.get('kwarg_only'):
return self.KWARG_ONLY_CHECK_TEMPLATE.substitute(name=arg['name'], code=code)
return self.CHECK_TEMPLATE.substitute(idx=arg['idx'] - arg['kwarg_offset'], name=arg['name'], code=code)
return self.CHECK_TEMPLATE.substitute(idx=arg['idx'], name=arg['name'], code=code)

def process_wrapper(self, code, declaration):
if declaration.get('no_kwargs'):
Expand Down
9 changes: 7 additions & 2 deletions tools/cwrap/plugins/THPPlugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,10 @@ def expand_composite_type(p, t):
declaration['variables'] += ['PyObject *__out;']
self.generate_out_options(declaration)
if has_long_args(declaration):
declaration['no_kwargs'] = True
for option in declaration['options']:
for arg in option['arguments']:
if arg.get('long_args', False):
arg['no_kwargs'] = True
for option in declaration['options']:
option['cname'] = 'TH{}Tensor_({})'.format(
'S' if option.get('sparse', False) else '', option['cname'])
Expand Down Expand Up @@ -554,7 +557,9 @@ def process_all_checks(self, code, option):

if any(arg.get('long_args', False) for arg in option['arguments']):
code = code.replace('__argcount ==', '__argcount >=')
expected = str(int(option.get('output_provided', False)))
expected = str(int(option.get('output_provided', False)) +
sum(not arg.get('no_kwargs', False) and not arg.get('ignore_check', False)
for arg in option['arguments']))
code = '__dictcount == ' + expected + ' &&\n ' + code

return code
Expand Down

0 comments on commit e708de3

Please sign in to comment.