From e708de37ccbd15f248d26991443abbd40c8cde48 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Sat, 15 Jul 2017 11:25:56 -0700 Subject: [PATCH] Allow keyword args in long_arg options --- test/test_torch.py | 2 +- tools/cwrap/plugins/KwargsPlugin.py | 10 ++++------ tools/cwrap/plugins/THPPlugin.py | 9 +++++++-- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/test/test_torch.py b/test/test_torch.py index 966de659e52fad..4d35d7209a5bf1 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -2296,7 +2296,7 @@ def test_RNGStateAliasing(self): target_value = torch.rand(1000) # Dramatically alter the internal state of the main generator _ = torch.rand(100000) - forked_value = torch.rand(gen, 1000) + forked_value = torch.rand(1000, generator=gen) self.assertEqual(target_value, forked_value, 0, "RNG has not forked correctly.") def test_boxMullerState(self): diff --git a/tools/cwrap/plugins/KwargsPlugin.py b/tools/cwrap/plugins/KwargsPlugin.py index 9542f5ca6fab39..59d6e6a0451985 100644 --- a/tools/cwrap/plugins/KwargsPlugin.py +++ b/tools/cwrap/plugins/KwargsPlugin.py @@ -30,10 +30,8 @@ def process_declarations(self, declarations): for option in declaration['options']: offset = 0 for arg in option['arguments']: - if arg.get('kwarg_only') and not arg.get('ignore_check', False): - offset += 1 - else: - arg['kwarg_offset'] = offset + if arg.get('kwarg_only'): + arg['no_idx'] = True return declarations def get_arg_accessor(self, arg, option): @@ -41,14 +39,14 @@ def get_arg_accessor(self, arg, option): return if arg.get('kwarg_only'): return self.KWARG_ONLY_ACCESSOR_TEMPLATE.substitute(name=arg['name']) - return self.ACCESSOR_TEMPLATE.substitute(idx=arg['idx'] - arg['kwarg_offset'], name=arg['name']) + return self.ACCESSOR_TEMPLATE.substitute(idx=arg['idx'], name=arg['name']) def process_single_check(self, code, arg, arg_accessor): if arg.get('no_kwargs'): return code if arg.get('kwarg_only'): return self.KWARG_ONLY_CHECK_TEMPLATE.substitute(name=arg['name'], code=code) - return self.CHECK_TEMPLATE.substitute(idx=arg['idx'] - arg['kwarg_offset'], name=arg['name'], code=code) + return self.CHECK_TEMPLATE.substitute(idx=arg['idx'], name=arg['name'], code=code) def process_wrapper(self, code, declaration): if declaration.get('no_kwargs'): diff --git a/tools/cwrap/plugins/THPPlugin.py b/tools/cwrap/plugins/THPPlugin.py index a528de8eee9d45..09fa328845a104 100644 --- a/tools/cwrap/plugins/THPPlugin.py +++ b/tools/cwrap/plugins/THPPlugin.py @@ -439,7 +439,10 @@ def expand_composite_type(p, t): declaration['variables'] += ['PyObject *__out;'] self.generate_out_options(declaration) if has_long_args(declaration): - declaration['no_kwargs'] = True + for option in declaration['options']: + for arg in option['arguments']: + if arg.get('long_args', False): + arg['no_kwargs'] = True for option in declaration['options']: option['cname'] = 'TH{}Tensor_({})'.format( 'S' if option.get('sparse', False) else '', option['cname']) @@ -554,7 +557,9 @@ def process_all_checks(self, code, option): if any(arg.get('long_args', False) for arg in option['arguments']): code = code.replace('__argcount ==', '__argcount >=') - expected = str(int(option.get('output_provided', False))) + expected = str(int(option.get('output_provided', False)) + + sum(not arg.get('no_kwargs', False) and not arg.get('ignore_check', False) + for arg in option['arguments'])) code = '__dictcount == ' + expected + ' &&\n ' + code return code