diff --git a/mypy-strict.ini b/mypy-strict.ini
index 42fc73abf1ccc7..7cc6fff835773f 100644
--- a/mypy-strict.ini
+++ b/mypy-strict.ini
@@ -31,9 +31,11 @@ strict_equality = True
 files = tools/codegen/gen.py,
     tools/autograd/gen_annotated_fn_args.py,
+    tools/autograd/gen_autograd.py,
     tools/autograd/gen_python_functions.py,
     tools/autograd/gen_trace_type.py,
     tools/autograd/gen_variable_factories.py,
+    tools/autograd/gen_variable_type.py,
     tools/autograd/load_derivatives.py,
     torch/utils/benchmark/utils/common.py,
     torch/utils/benchmark/utils/timer.py,
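A quick way to exercise the two newly strict-typed files (gen_autograd.py and gen_variable_type.py, added to `files` above) is to run mypy against this config. A minimal sketch, assuming mypy is installed and the command runs from the repo root:

    # mypy.api.run is mypy's documented programmatic entry point; the config's
    # `files =` list supplies the targets, so no extra arguments are needed.
    from mypy import api

    stdout, stderr, exit_code = api.run(['--config-file', 'mypy-strict.ini'])
    assert exit_code == 0, stdout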
diff --git a/tools/autograd/gen_autograd.py b/tools/autograd/gen_autograd.py
index 88c00e0ba71af2..b930aca504df6f 100644
--- a/tools/autograd/gen_autograd.py
+++ b/tools/autograd/gen_autograd.py
@@ -23,9 +23,6 @@
 import argparse
 import os
-import yaml
-import re
-from .utils import YamlLoader, op_name_with_overload
 from tools.codegen.selective_build.selector import SelectiveBuilder
 
 # See NOTE [ Autograd View Variables ] in variable.h for details.
@@ -89,84 +86,14 @@
     'tensor_split', 'swapdims', 'swapaxes'
 })
 
-def format_return_type(returns):
-    if len(returns) == 0:
-        return 'void'
-    elif len(returns) == 1:
-        return returns[0]['type']
-    else:
-        return_types = [r['type'] for r in returns]
-        return 'std::tuple<{}>'.format(','.join(return_types))
-
-
-def get_simple_type(arg):
-    simple_type = arg['type']
-    simple_type = simple_type.replace(' &', '').replace('const ', '')
-    simple_type = simple_type.replace('Generator *', 'Generator')
-
-    opt_match = re.match(r'c10::optional<(.+)>', simple_type)
-    if opt_match:
-        simple_type = '{}?'.format(opt_match.group(1))
-    return simple_type
-
-def has_tensoroptions_argument(declaration):
-    for argument in declaration['arguments']:
-        if 'TensorOptions' == argument['dynamic_type']:
-            return True
-    return False
-
-
-def load_aten_declarations(path):
-    with open(path, 'r') as f:
-        declarations = yaml.load(f, Loader=YamlLoader)
-
-    # enrich declarations with additional information
-    selected_declarations = []
-    for declaration in declarations:
-        if declaration.get('deprecated'):
-            continue
-
-        for arg in declaration['arguments']:
-            arg['simple_type'] = get_simple_type(arg)
-        for arg in declaration['schema_order_arguments']:
-            arg['simple_type'] = get_simple_type(arg)
-        for ret in declaration['returns']:
-            ret['simple_type'] = get_simple_type(ret)
-
-        declaration['formals'] = [arg['type'] + ' ' + arg['name']
-                                  for arg in declaration['arguments']]
-        declaration['schema_order_formals'] = [arg['type'] + ' ' + arg['name']
-                                               for arg in declaration['schema_order_arguments']]
-        declaration['args'] = [arg['name'] for arg in declaration['arguments']]
-        declaration['schema_order_args'] = [arg['name'] for arg in declaration['schema_order_arguments']]
-        declaration['api_name'] = declaration['name']
-        if declaration.get('overload_name'):
-            declaration['type_wrapper_name'] = "{}_{}".format(
-                declaration['name'], declaration['overload_name'])
-        else:
-            declaration['type_wrapper_name'] = declaration['name']
-        declaration['operator_name_with_overload'] = declaration['schema_string'].split('(')[0]
-        declaration['unqual_operator_name_with_overload'] = declaration['operator_name_with_overload'].split('::')[1]
-        declaration['return_type'] = format_return_type(declaration['returns'])
-
-        declaration['base_name'] = declaration['name']
-        selected_declarations.append(declaration)
-
-    return selected_declarations
-
-
-def gen_autograd(aten_path, native_functions_path, out, autograd_dir, operator_selector: SelectiveBuilder, disable_autograd=False):
-    full_aten_decls = load_aten_declarations(aten_path)
-
-    def filter_decls(aten_decls, operator_selector):
-        def is_operator_selected_for_training(decl):
-            op_name = op_name_with_overload(decl)
-            return operator_selector.is_operator_selected_for_training(op_name)
-
-        return [decl for decl in aten_decls if is_operator_selected_for_training(decl)]
-
-    aten_decls = filter_decls(full_aten_decls, operator_selector)
-
+def gen_autograd(
+    aten_path: str,
+    native_functions_path: str,
+    out: str,
+    autograd_dir: str,
+    operator_selector: SelectiveBuilder,
+    disable_autograd: bool = False,
+) -> None:
     # Parse and load derivatives.yaml
     from .load_derivatives import load_derivatives
     differentiability_infos = load_derivatives(
@@ -175,13 +102,13 @@ def is_operator_selected_for_training(decl):
     template_path = os.path.join(autograd_dir, 'templates')
 
     # Generate VariableType.h/cpp
+    from .gen_trace_type import gen_trace_type
+    from .gen_variable_type import gen_variable_type
     if not disable_autograd:
-        from .gen_variable_type import gen_variable_type
-        gen_variable_type(out, aten_decls, differentiability_infos, template_path)
+        gen_variable_type(out, native_functions_path, differentiability_infos, template_path, operator_selector)
 
-        from . import gen_trace_type
         # operator filter not applied as tracing sources are excluded in selective build
-        gen_trace_type.gen_trace_type(out, native_functions_path, template_path)
+        gen_trace_type(out, native_functions_path, template_path)
 
     # Generate Functions.h/cpp
     from .gen_autograd_functions import gen_autograd_functions_lib
@@ -193,7 +120,12 @@ def is_operator_selected_for_training(decl):
     gen_variable_factories(out, native_functions_path, template_path)
 
 
-def gen_autograd_python(aten_path, native_functions_path, out, autograd_dir):
+def gen_autograd_python(
+    aten_path: str,
+    native_functions_path: str,
+    out: str,
+    autograd_dir: str,
+) -> None:
     from .load_derivatives import load_derivatives
     differentiability_infos = load_derivatives(
         os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path)
@@ -212,7 +144,7 @@ def gen_autograd_python(aten_path, native_functions_path, out, autograd_dir):
         out, native_functions_path, deprecated_path, template_path)
 
 
-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(
         description='Generate autograd C++ files script')
     parser.add_argument('declarations', metavar='DECL',
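With the Declarations.yaml loading gone, `gen_autograd` is now a fully typed entry point driven by native_functions.yaml and a `SelectiveBuilder`. A minimal usage sketch (the paths are illustrative, not fixed by this patch):

    from tools.autograd.gen_autograd import gen_autograd
    from tools.codegen.selective_build.selector import SelectiveBuilder

    gen_autograd(
        'torch/share/ATen/Declarations.yaml',           # aten_path (still accepted)
        'aten/src/ATen/native/native_functions.yaml',   # native_functions_path
        'torch/csrc/autograd/generated',                # out
        'tools/autograd',                               # autograd_dir
        SelectiveBuilder.get_nop_selector(),            # select every operator
    )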
diff --git a/tools/autograd/gen_trace_type.py b/tools/autograd/gen_trace_type.py
index 31eb8aacf296dd..d8d42762e4fb0a 100644
--- a/tools/autograd/gen_trace_type.py
+++ b/tools/autograd/gen_trace_type.py
@@ -422,7 +422,7 @@ def gen_trace_type_shard(
     fm: FileManager, native_functions: Sequence[NativeFunction], suffix: str
 ) -> None:
     fm.write_with_template('TraceType%s.cpp' % suffix, 'TraceType.cpp', lambda: {
-        'generated_comment': f'@generated from {fm.template_dir}/TraceType.cpp',
+        'generated_comment': '@' + f'generated from {fm.template_dir}/TraceType.cpp',
         'trace_method_definitions': list(mapMaybe(method_definition, native_functions)),
         'trace_wrapper_registrations': list(mapMaybe(method_registration, native_functions)),
     })
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 72be5b993f44bf..f49f5e15845bc7 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -22,20 +22,24 @@
 # which will in turn dispatch back to VariableType for its
 # differentiable subcomponents.
 #
+from dataclasses import dataclass
 
-from .utils import CodeTemplate, nested_dict, write, make_out_api_name_faithful
 from .gen_autograd import VIEW_FUNCTIONS, VIEW_FUNCTIONS_WITH_METADATA_CHANGE, \
     MULTI_OUTPUT_SAFE_FUNCTIONS, RETURNS_VIEWS_OF_INPUT
 from .gen_autograd_functions import uses_single_grad
-from .gen_trace_type import MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER, MANUAL_AUTOGRAD
+from .gen_trace_type import (
+    MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER, MANUAL_AUTOGRAD,
+    declare_returned_variables, tie_return_values, get_return_value, type_wrapper_name,
+)
 
 from tools.codegen.api.types import *
 from tools.codegen.api.autograd import *
 import tools.codegen.api.cpp as cpp
-import tools.codegen.api.python as python
-from tools.codegen.gen import with_native_function
+from tools.codegen.code_template import CodeTemplate
+from tools.codegen.gen import with_native_function, parse_native_yaml, FileManager, mapMaybe
 from tools.codegen.model import *
-from typing import Dict, Optional, List, Sequence, Any, Callable
+from tools.codegen.selective_build.selector import SelectiveBuilder
+from typing import Callable, List, Optional, Sequence, Tuple, Union
 
 # We don't set or modify grad_fn on these methods. Generally, they return
 # tensors that have requires_grad=False. In-place functions listed here will
@@ -209,9 +213,6 @@
 UNPACK_TENSOR = CodeTemplate("""\
 auto${ref} ${arg_name}_ = unpack${suffix}(${arg_name}, "${arg_name}", ${arg_pos});""")
 
-LEGACY_WRAP_OPTIONS = CodeTemplate("""\
-auto ${arg_name}_ = TensorOptions(${arg_name});""")
-
 DECLARE_GRAD_FN = CodeTemplate("""\
 std::shared_ptr<${op}> grad_fn;
 """)
@@ -304,49 +305,18 @@
 #endif
 """)
 
-# Methods shared by TraceType and VariableType to handle return variable declaration, tie and tuple.
-def format_return_variables(declaration):
-    name = declaration['name']
-    arguments = declaration['arguments']
-    inplace = declaration['inplace']
-    is_out_fn = name.endswith('_out')
-    modifies_arguments = inplace or is_out_fn
-
-    def declare_returned_variables():
-        if modifies_arguments:
-            return ''
-        if len(declaration['returns']) == 1:
-            return ''
-        # TODO: this will be ugly
-        names = [ret['type'] + ' ' + ret['name'] + ';' for ret in declaration['returns']]
-        return '\n'.join(names)
-
-    def tie_return_values():
-        if len(declaration['returns']) == 1:
-            return 'auto {}'.format(declaration['returns'][0]['name'])
-        names = [ret['name'] for ret in declaration['returns']]
-        return 'std::tie({})'.format(', '.join(names))
-
-    def get_return_value():
-        if inplace:
-            return 'self'
-        if is_out_fn:
-            return_names = [arg['name'] for arg in arguments
-                            if arg.get('output', False)]
-            if len(return_names) == 1:
-                return return_names[0]
-            return 'std::forward_as_tuple({})'.format(', '.join(return_names))
-
-        returns = declaration['returns']
-        if len(returns) == 1:
-            return returns[0]['name']
-        moved = ['std::move({})'.format(r['name']) for r in returns]
-        return 'std::make_tuple({})'.format(', '.join(moved))
-
-    return (declare_returned_variables(), tie_return_values(), get_return_value())
+@dataclass(frozen=True)
+class NativeFunctionWithDifferentiabilityInfo:
+    func: NativeFunction
+    info: Optional[DifferentiabilityInfo]
 
-
-def gen_variable_type(out, aten_declarations, differentiability_infos, template_path):
+def gen_variable_type(
+    out: str,
+    native_yaml_path: str,
+    differentiability_infos: Sequence[DifferentiabilityInfo],
+    template_path: str,
+    operator_selector: SelectiveBuilder,
+) -> None:
     """VariableType.h and VariableType.cpp body
 
@@ -354,154 +324,202 @@ def gen_variable_type(out, aten_declarations, differentiability_infos, template_
     implementation of each function dispatches to the base tensor type
     to compute the output. The grad_fn is attached to differentiable functions.
     """
+    fns = list(sorted(filter(
+        operator_selector.is_native_function_selected_for_training,
+        parse_native_yaml(native_yaml_path)), key=lambda f: cpp.name(f.func)))
+    fns_with_infos = match_differentiability_info(fns, differentiability_infos)
 
-    aten_declarations = list(sorted(aten_declarations, key=lambda decl: decl['name']))
-    match_declarations_with_differentiability_info(aten_declarations, differentiability_infos)
-
-    gen_variable_type_shard(out, aten_declarations, template_path, None, True)
+    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
+    gen_variable_type_shard(fm, fns_with_infos, 'VariableType.h', 'VariableType.h')
 
     # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
     # template regarding sharding of the generated files.
     num_shards = 5
-    shards = [[] for _ in range(num_shards)]
+    shards: List[List[NativeFunctionWithDifferentiabilityInfo]] = [[] for _ in range(num_shards)]
 
     # functions are assigned arbitrarily but stably to a file based on hash
-    for decl in aten_declarations:
-        x = sum(ord(c) for c in decl['name']) % num_shards
-        shards[x].append(decl)
+    for fn in fns_with_infos:
+        x = sum(ord(c) for c in cpp.name(fn.func.func)) % num_shards
+        shards[x].append(fn)
 
     for i, shard in enumerate(shards):
-        gen_variable_type_shard(out, shard, template_path, '_%d' % i, False)
-    gen_variable_type_shard(out, aten_declarations, template_path, 'Everything', False)
+        gen_variable_type_shard(fm, shard, 'VariableType.cpp', f'VariableType_{i}.cpp')
+    gen_variable_type_shard(fm, fns_with_infos, 'VariableType.cpp', 'VariableTypeEverything.cpp')
 
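The hash in the loop above assigns each operator to a shard by summing the character codes of its C++ name, so the split is arbitrary but stable across codegen runs and machines. A self-contained sketch of the same scheme:

    num_shards = 5

    def shard_of(op_cpp_name: str) -> int:
        # Stable, content-based assignment: no dependence on dict/set ordering.
        return sum(ord(c) for c in op_cpp_name) % num_shards

    assert shard_of('add') == (ord('a') + ord('d') + ord('d')) % num_shards == 2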
-def gen_variable_type_shard(out, aten_declarations, template_path, suffix, header):
-    VARIABLE_TYPE_H = CodeTemplate.from_file(template_path + '/VariableType.h')
-    VARIABLE_TYPE_CPP = CodeTemplate.from_file(template_path + '/VariableType.cpp')
-
-    type_declarations = []
-    type_definitions = []
-    wrapper_registrations = []
-
-    for declaration in aten_declarations:
-        if declaration['use_c10_dispatcher'] in ['full', 'hacky_wrapper_for_legacy_signatures']:
-            formals = declaration['schema_order_formals']
-        else:
-            assert declaration['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper'
-            formals = declaration['formals']
-        type_declarations.append(METHOD_DECLARATION.substitute(declaration, formals=formals))
-        strategy = dispatch_strategy(declaration)
-        if declaration['name'] not in MANUAL_AUTOGRAD and strategy == 'use_derived':
-            body = emit_body(declaration)
+@with_native_function
+def gen_formals(f: NativeFunction) -> str:
+    if f.use_c10_dispatcher.dispatcher_uses_new_style():
+        formals = ', '.join(
+            f'{cpp.argument_type(a, binds="__placeholder__").cpp_type()} {a.name}'
+            for a in f.func.schema_order_arguments()
+        )
+    else:
+        sig_group = CppSignatureGroup.from_native_function(f, method=False)
+        formals = ', '.join(f'{a.type} {a.name}' for a in sig_group.signature.arguments())
+    return formals
 
+@with_native_function
+def gen_wrapper_registration(f: NativeFunction) -> str:
+    if f.use_c10_dispatcher.dispatcher_uses_new_style():
+        return WRAPPER_REGISTRATION.substitute(
+            unqual_operator_name_with_overload=f.func.name,
+            type_wrapper_name=type_wrapper_name(f),
+            class_type='VariableType',
+        )
+    else:
+        return UNBOXEDONLY_WRAPPER_REGISTRATION.substitute(
+            unqual_operator_name_with_overload=f.func.name,
+            type_wrapper_name=type_wrapper_name(f),
+            class_type='VariableType',
+        )
+
+def gen_variable_type_shard(
+    fm: FileManager,
+    fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
+    template_name: str,
+    output_name: str,
+) -> None:
+    type_declarations: List[str] = []
+    type_definitions: List[str] = []
+    wrapper_registrations: List[str] = []
+
+    for fn in fns_with_infos:
+        f = fn.func
+        name = cpp.name(f.func)
+        formals = gen_formals(f)
+
+        type_declarations.append(METHOD_DECLARATION.substitute(
+            return_type=cpp.returns_type(f.func.returns),
+            type_wrapper_name=type_wrapper_name(f),
+            formals=formals,
+        ))
+
+        if name not in MANUAL_AUTOGRAD and dispatch_strategy(fn) == 'use_derived':
             type_definitions.append(METHOD_DEFINITION.substitute(
-                declaration, type_definition_body=body, formals=formals))
-            if declaration['use_c10_dispatcher'] in ['full', 'hacky_wrapper_for_legacy_signatures']:
-                wrapper_registrations.append(WRAPPER_REGISTRATION.substitute(
-                    declaration, class_type='VariableType'))
-            else:
-                assert declaration['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper'
-                wrapper_registrations.append(UNBOXEDONLY_WRAPPER_REGISTRATION.substitute(
-                    declaration, class_type='VariableType'))
+                return_type=cpp.returns_type(f.func.returns),
+                type_wrapper_name=type_wrapper_name(f),
+                type_definition_body=emit_body(fn),
+                formals=formals,
+            ))
+            wrapper_registrations.append(gen_wrapper_registration(f))
 
         # See Note [Manual Backend kernels]
-        assert (declaration['name'] in MANUAL_BACKEND) == declaration['manual_kernel_registration']
+        assert (name in MANUAL_BACKEND) == f.manual_kernel_registration
        # If you want to register a kernel to Autograd, you must make the op abstract.
        # In other words, this op must have dispatch section in native_functions.yaml.
-        if declaration['name'] in MANUAL_AUTOGRAD_AND_TRACER or declaration['derivative']:
-            msg = (f'There\'s a formula for {declaration["name"]}(or its functional variant) in derivatives.yaml. '
+        if name in MANUAL_AUTOGRAD_AND_TRACER or (fn.info and fn.info.has_derivatives):
+            msg = (f'There\'s a formula for {name}(or its functional variant) in derivatives.yaml. '
                    f'It\'s required to add a dispatch section for it with explicit supported backends e.g CPU/CUDA '
                    f'or DefaultBackend in native_functions.yaml. Please see '
                    f'https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword '
                    f'for instructions to choose the right dispatch keyword.')
-            assert declaration['abstract'], msg
+            assert f.is_abstract, msg
 
-    env = {
+    fm.write_with_template(output_name, template_name, lambda: {
+        'generated_comment': '@' + f'generated from {fm.template_dir}/{template_name}',
         'type_derived_method_declarations': type_declarations,
         'type_derived_method_definitions': type_definitions,
         'wrapper_registrations': wrapper_registrations,
-    }
-    if header:
-        write(out, 'VariableType.h', VARIABLE_TYPE_H, env)
-    else:
-        write(out, 'VariableType%s.cpp' % suffix, VARIABLE_TYPE_CPP, env)
-
-
-def emit_body(declaration):
-    assert dispatch_strategy(declaration) == 'use_derived'
-
-    arguments = declaration['arguments']
-    returns = declaration['returns']
-    func = declaration['derivative']
-    name = declaration['name']
-    inplace = declaration['inplace']
-    is_out_fn = name.endswith('_out')
-    modifies_arguments = inplace or is_out_fn
-    returns_void = len(returns) == 0
-
-    base_name = name[:-1] if inplace else name[:-4] if is_out_fn else name
+    })
+
+def emit_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
+    assert dispatch_strategy(fn) == 'use_derived'
+    f = fn.func
+    info = fn.info
+
+    name = cpp.name(f.func)
+    inplace = f.func.kind() == SchemaKind.inplace
+    is_out_fn = f.func.kind() == SchemaKind.out
+    returns_void = len(f.func.returns) == 0
+    base_name = f.func.name.name.base  # TODO: should be str(f.func.name.name)?
     view_info = VIEW_FUNCTIONS.get(base_name, None)
     if view_info is None and base_name in RETURNS_VIEWS_OF_INPUT:
         view_info = "self"
 
-    def is_differentiable(arg):
-        if 'TensorOptions' in arg['type']:
-            return False
-        if 'Tensor' not in arg['type']:
-            return False
-        if arg['name'] in declaration.get('non_differentiable_arg_names', []):
-            return False
-        return True
-
-    def find_args_with_derivatives(differentiable_inputs):
+    def is_differentiable(name: str, type: Type) -> bool:
+        return type.is_tensor_like() and (info is None or name not in info.non_differentiable_arg_names)
+
+    def gen_differentiable_input(
+        arg: Union[Argument, SelfArgument, TensorOptionsArguments]
+    ) -> Optional[DifferentiableInput]:
+        if isinstance(arg, TensorOptionsArguments):
+            return None
+        a: Argument = arg.argument if isinstance(arg, SelfArgument) else arg
+
+        # TODO: `cpp_type` is only to keep it byte-for-byte compatible with the old codegen, should remove.
+        # NB: This is not a clone of cpp.argument() - TensorOptionsArguments / faithful / binds are
+        # not handled properly as they are irrelevant for this codegen.
+        cpp_type = cpp.argument_type(a, binds=a.name).cpp_type()
+
+        if not is_differentiable(a.name, a.type):
+            return None
+        return DifferentiableInput(
+            name=a.name,
+            type=a.type,
+            cpp_type=cpp_type,
+        )
+
+    @with_native_function
+    def gen_differentiable_inputs(f: NativeFunction) -> List[DifferentiableInput]:
+        return list(mapMaybe(gen_differentiable_input, f.func.arguments.non_out))
+
+    def find_args_with_derivatives(differentiable_inputs: List[DifferentiableInput]) -> List[DifferentiableInput]:
         """Find arguments that have derivative definitions"""
-        if func is None:
+        if info is None or not info.has_derivatives:
             return differentiable_inputs
-        names = set(name for d in func.derivatives for name in d.var_names)
-        differentiable = [arg for arg in differentiable_inputs if arg['name'] in names]
+        names = set(name for d in info.derivatives for name in d.var_names)
+        differentiable = [arg for arg in differentiable_inputs if arg.name in names]
         if len(differentiable) != len(names):
-            missing = names - set(arg['name'] for arg in differentiable)
-            raise RuntimeError(f'Missing arguments for derivatives: {missing} in {func.name}')
+            missing = names - set(arg.name for arg in differentiable)
+            raise RuntimeError(f'Missing arguments for derivatives: {missing} in {info.name}')
         return differentiable
 
-    inputs = [arg for arg in arguments if not arg.get('output', False)]
-    differentiable_inputs = list(filter(is_differentiable, inputs))
+    def gen_differentiable_outputs(f: NativeFunction) -> List[DifferentiableOutput]:
+        outputs: List[DifferentiableOutput] = [
+            DifferentiableOutput(name=name, type=ret.type, cpp_type=cpp.return_type(ret))
+            for name, ret in zip(cpp.return_names(f), f.func.returns)]
+
+        output_differentiability = info.output_differentiability if info else None
+        if output_differentiability is not None:
+            differentiable_outputs: List[DifferentiableOutput] = []
+            if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
+                raise RuntimeError("output_differentiability=False for inplace operation (version_counter won't get updated)")
+            for differentiable, output in zip(output_differentiability, outputs):
+                if differentiable:
+                    differentiable_outputs.append(output)
+            return differentiable_outputs
+
+        candidate_differentiable_outputs = list(filter(lambda r: is_differentiable(r.name, r.type), outputs))
+
+        if uses_single_grad(info):
+            return candidate_differentiable_outputs[:1]
+        else:
+            return candidate_differentiable_outputs
+
+    differentiable_inputs = gen_differentiable_inputs(f)
     args_with_derivatives = find_args_with_derivatives(differentiable_inputs)
-    non_differentiable_arg_names = declaration.get('non_differentiable_arg_names', [])
-    candidate_differentiable_outputs = list(filter(is_differentiable, returns))
-
-    if declaration['output_differentiability'] is not None:
-        differentiable_outputs = []
-        output_differentiability = declaration['output_differentiability']
-        if False in output_differentiability and inplace:
-            raise RuntimeError("output_differentiability=False for inplace operation (version_counter won't get updated)")
-        for differentiable, output in zip(output_differentiability, returns):
-            if differentiable:
-                differentiable_outputs.append(output)
-    elif uses_single_grad(func):
-        differentiable_outputs = candidate_differentiable_outputs[:1]
-    else:
-        differentiable_outputs = candidate_differentiable_outputs
+    differentiable_outputs = gen_differentiable_outputs(f)
 
     requires_derivative = (
         base_name not in DONT_REQUIRE_DERIVATIVE and name not in DONT_REQUIRE_DERIVATIVE
         and len(differentiable_inputs) > 0 and len(differentiable_outputs) > 0)
 
-    if func is not None and not requires_derivative:
-        raise RuntimeError('ERROR: derivative ignored for {} -- specified an autograd function without derivative'
-                           .format(name))
+    if info is not None and info.has_derivatives and not requires_derivative:
+        raise RuntimeError(f'ERROR: derivative ignored for {name} -- specified an autograd function without derivative')
 
-    def emit_save_inputs():
-        setup = []
-        if func is None:
+    def emit_save_inputs() -> List[str]:
+        setup: List[str] = []
+        if info is None or not info.has_derivatives:
             return setup
 
-        has_tensorlist_arg = \
-            any(arg.type in ['TensorList', 'const c10::List<c10::optional<Tensor>> &'] for arg in func.args_with_derivatives)
+        has_tensorlist_arg = any(is_tensor_list_type(arg.type) for arg in args_with_derivatives)
 
         # We don't want to save tensors if we know that they will never be used
         # when computing the derivative, so we add guards to those statements
         def guard_for(arg: SavedAttribute) -> Optional[str]:
+            assert info is not None
+
             # It's hard to determine the edge offset if we have TensorLists
             if has_tensorlist_arg:
                 return None
@@ -512,12 +530,12 @@ def guard_for(arg: SavedAttribute) -> Optional[str]:
             # require_grad if the backward function even gets executed. I don't
             # have any good ideas for detecting those cases, so I simply disabled the
             # checks.
-            if 'backward' in func.name:
+            if 'backward' in info.name:
                 return None
 
             # If there's a single derivative we could compute, we already have
             # a requires_grad check that is sufficient
-            if len(func.args_with_derivatives) <= 1:
+            if len(args_with_derivatives) <= 1:
                 return None
 
             # We really only care about trimming down the amount of tensors we save
@@ -526,7 +544,7 @@ def guard_for(arg: SavedAttribute) -> Optional[str]:
 
             # We want to emit simple guards, so we only allow that if checking one
             # input is enough to determine whether we need that value
-            used_in = [d for d in func.derivatives if arg in d.saved_inputs]
+            used_in = [d for d in info.derivatives if arg in d.saved_inputs]
             assert len(used_in) > 0
             if len(used_in) != 1:
                 return None
@@ -536,75 +554,76 @@ def guard_for(arg: SavedAttribute) -> Optional[str]:
             derivative_var_name = derivative.var_names[0]
 
             # Figure out the offset of the edge that uses this variable
-            for edge_off, arg in enumerate(func.args_with_derivatives):
-                if arg.name == derivative_var_name:
+            for edge_off, a in enumerate(args_with_derivatives):
+                if a.name == derivative_var_name:
                     break
             else:
                 raise AssertionError()
 
             return f'grad_fn->should_compute_output({edge_off})'
 
-        setup.extend(save_variables(func.all_saved_inputs, False, guard_for))
-        for arg in func.args_with_derivatives:
-            if arg.type in ['TensorList', 'const c10::List<c10::optional<Tensor>> &']:
+        setup.extend(save_variables(info.all_saved_inputs, False, guard_for))
+        for arg in args_with_derivatives:
+            if is_tensor_list_type(arg.type):
                 setup.append(f'grad_fn->{arg.name}_size_ = {arg.name}.size();')
 
         return setup
 
-    def setup_derivative(differentiable_inputs):
-        env = {}
-        env['args_with_derivatives'] = [arg['name'] for arg in args_with_derivatives]
-        env['op'] = func.op if func is not None else 'NotImplemented'
-        env['op_ctor'] = '' if func is not None else '"{}"'.format(declaration['api_name'])
-
+    def setup_derivative(differentiable_inputs: List[DifferentiableInput]) -> List[str]:
+        body: List[str] = []
         if is_out_fn:
             # For out functions, ensure that no input or output requires grad
-            body = []
             body.append(DECLARE_GRAD_FN.substitute(op='Node'))
             body.append(SETUP_NONE_REQUIRES_GRAD.substitute(
                 base_name=base_name,
-                args_to_check=[arg['name'] for arg in differentiable_inputs]))
+                args_to_check=[arg.name for arg in differentiable_inputs]))
             body.append(SETUP_NONE_REQUIRES_GRAD.substitute(
                 base_name=base_name,
-                args_to_check=[arg['name'] for arg in differentiable_outputs]))
+                args_to_check=[arg.name for arg in differentiable_outputs]))
             return body
 
+        op = info.op if info is not None and info.has_derivatives else 'NotImplemented'
         setup = []
-        setup.extend(ASSIGN_GRAD_FN.substitute(env).split('\n'))
+        setup.extend(ASSIGN_GRAD_FN.substitute(
+            op=op,
+            op_ctor='' if info is not None and info.has_derivatives else f'"{cpp.name(f.func)}"',
+            args_with_derivatives=[arg.name for arg in args_with_derivatives],
+        ).split('\n'))
         setup.extend(emit_save_inputs())
 
-        body = []
         body.extend(emit_check_no_requires_grad(differentiable_inputs, args_with_derivatives))
-        body.append(DECLARE_GRAD_FN.substitute(env))
+        body.append(DECLARE_GRAD_FN.substitute(op=op))
         body.append(SETUP_DERIVATIVE.substitute(setup=setup))
         return body
 
-    def emit_check_if_in_complex_autograd_allowlist():
-        body = []
+    def emit_check_if_in_complex_autograd_allowlist() -> List[str]:
+        body: List[str] = []
         if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX:
             return body
         for arg in differentiable_outputs:
-            name = arg['name']
-            if arg['type'] in ['Tensor', 'TensorList', 'const c10::List<c10::optional<Tensor>> &']:
-                body.append('throw_error_for_complex_autograd({}, "{}");'.format(name, base_name))
+            name = arg.name
+            # TODO: should be `arg.type.is_tensor_like()`?
+            if arg.cpp_type in ['Tensor', 'TensorList', 'const c10::List<c10::optional<Tensor>> &']:
+                body.append(f'throw_error_for_complex_autograd({name}, "{base_name}");')
         return body
 
-    def emit_check_no_requires_grad(tensor_args, args_with_derivatives):
+    def emit_check_no_requires_grad(
+        tensor_args: List[DifferentiableInput],
+        args_with_derivatives: List[DifferentiableInput],
+    ) -> List[str]:
         """Checks that arguments without derivatives don't require grad"""
-        body = []
+        body: List[str] = []
         for arg in tensor_args:
             if arg in args_with_derivatives:
                 continue
-            name = arg['name']
-            if name in non_differentiable_arg_names:
+            name = arg.name
+            if info and name in info.non_differentiable_arg_names:
                 continue
             if name == 'output':
                 # Double-backwards definitions sometimes take in 'input' and
                 # 'output', but only define the derivative for input.
                 continue
-            if arg['dynamic_type'] in {'IndexTensor', 'ByteTensor', 'BoolTensor'}:
-                continue
-            body.append('check_no_requires_grad({}, "{}");'.format(name, name))
+            body.append(f'check_no_requires_grad({name}, "{name}");')
         return body
 
     def save_variables(
@@ -644,42 +663,40 @@ def save_variables(
             stmts.append('}')
         return stmts
 
-    def emit_dispatch_call(api_name, input_base, unpacked_args):
+    def emit_dispatch_call(f: NativeFunction, input_base: str, unpacked_args: Sequence[str]) -> str:
         """ Dispatch call via function in a namespace or method on Tensor."""
-        if 'namespace' in declaration['method_of']:
-            if declaration['use_c10_dispatcher'] in ['hacky_wrapper_for_legacy_signatures', 'full']:
-                dispatcher_api_name = make_out_api_name_faithful(api_name)
-            else:
-                assert declaration['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper'
-                dispatcher_api_name = api_name
+        if Variant.function in f.variants:
             call = CALL_DISPATCH_VIA_NAMESPACE.substitute(
-                api_name=dispatcher_api_name,
+                api_name=cpp.name(
+                    f.func,
+                    faithful_name_for_out_overloads=f.use_c10_dispatcher.dispatcher_uses_new_style(),
+                ),
                 unpacked_args=unpacked_args)
         else:
             call = CALL_DISPATCH_VIA_METHOD.substitute(
-                api_name=api_name,
+                api_name=cpp.name(f.func),
                 var=input_base,
                 unpacked_method_args=unpacked_args[1:])
         return call
 
-    def emit_view_lambda():
+    def emit_view_lambda(unpacked_bindings: List[Binding]) -> str:
         """ Generate an additional lambda function to recover views in backward when as_strided
         is not supported. See Note [View + Inplace update for base tensor]
         and [View + Inplace update for view tensor] for more details."""
         input_base = 'input_base'
         replay_view_func = ''
-        updated_unpacked_args = []
-        combined = nested_dict(env, declaration)
-        known_view_arg_simple_types = ['int64_t', 'int64_t?', 'bool', 'IntArrayRef']
-        for arg in combined['unpacked_args']:
+        updated_unpacked_args: List[str] = []
+        known_view_arg_simple_types: List[str] = ['int64_t', 'c10::optional<int64_t>', 'bool', 'IntArrayRef']
+        for unpacked_binding in unpacked_bindings:
+            arg, arg_type = unpacked_binding.name, unpacked_binding.type
             if arg == 'self_':
                 updated_unpacked_args.append(input_base)
                 continue
-            arg_type = combined['unpacked_args_simple_type'][arg]
             if arg_type not in known_view_arg_simple_types:
-                raise TypeError('You are adding an {} {} argument to op {} in addition to known types: {}. '
-                                'Please update the list or materialize it so that it can be closed over by value, '
-                                'also add a test in pytorch/xla/test/test_operations.py where this code is exercised.'
-                                .format(arg_type, arg, declaration['name'], ', '.join(known_view_arg_simple_types)))
+                known_types_str = ', '.join(known_view_arg_simple_types)
+                raise TypeError(f'You are adding an {arg_type} {arg} argument to op {cpp.name(f.func)} in addition to known types: '
+                                f'{known_types_str}. Please update the list or materialize it so that it can be closed '
+                                'over by value, also add a test in pytorch/xla/test/test_operations.py where this code '
+                                'is exercised.')
 
             if arg_type == 'IntArrayRef':
                 # It's not safe to close over IntArrayRef by value, since this is a
@@ -687,7 +704,7 @@ def emit_view_lambda():
                 arg_vec = arg + '_vec'
                 replay_view_func += ARRAYREF_TO_VEC.substitute(arg=arg, vec=arg_vec)
                 updated_unpacked_args.append(arg_vec)
-            elif arg_type == 'int64_t?':
+            elif arg_type == 'c10::optional<int64_t>':
                 # Materialize int64_t? to int64_t
                 arg_value = arg + '_val'
                 replay_view_func += OPTIONAL_TO_VAL.substitute(arg=arg, val=arg_value, default='0')
@@ -695,7 +712,7 @@
             else:
                 updated_unpacked_args.append(arg)
 
-        replay_view_call = emit_dispatch_call(combined['api_name'], input_base, updated_unpacked_args)
+        replay_view_call = emit_dispatch_call(f, input_base, updated_unpacked_args)
         replay_view_func += REPLAY_VIEW_LAMBDA_FUNC.substitute(
             input_base=input_base,
             replay_view_call=replay_view_call)
@@ -706,17 +723,17 @@
             is_view_with_metadata_change=is_view_with_metadata_change,
             replay_view_func=replay_view_func)
 
-    def wrap_output(return_values, var):
+    def wrap_output(f: NativeFunction, unpacked_bindings: List[Binding], var: str) -> str:
         call = ''
-        rhs_value = None
-        if 'Tensor' not in declaration['return_type']:
+        rhs_value: Optional[str] = None
+        if not any(r.type.is_tensor_like() for r in f.func.returns):
             rhs_value = var
         elif view_info is not None:
             # See NOTE [ Autograd View Variables ] in variable.h for details.
-            differentiable_output_vars = {r['name'] for r in differentiable_outputs}
+            differentiable_output_vars = {r.name for r in differentiable_outputs}
 
             if not isinstance(view_info, str):
-                raise TypeError("The view info should be a string for {}, but it is: {}".format(base_name, view_info))
+                raise TypeError(f'The view info should be a string for {base_name}, but it is: {view_info}')
 
             if len(differentiable_output_vars) == 0:
                 # no output is differentiable (.indices() for SparseTensors for example)
@@ -725,54 +742,55 @@ def wrap_output(return_values, var):
                 # Single differentiable output (Tensor or Tensor[])
                 return_info = differentiable_outputs[0]
                 # We only support simple Tensor or a TensorList for functions that return views
-                if not return_info['dynamic_type'] in ['Tensor', 'TensorList']:
-                    raise RuntimeError("{} that return differentiable views can only return Tensor or Tensor[]".format(base_name))
+                if not is_tensor_type(return_info.type) and not is_tensor_list_type(return_info.type):
+                    raise RuntimeError(f'{base_name} that return differentiable views can only return Tensor or Tensor[]')
 
                 # Only allow rebasing of the history if we return a single Tensor
                 # If we are in a no grad block, raise a warning
                 # See NOTE [ View + Inplace detection ] for more details about this logic
-                if return_info['dynamic_type'] in ['TensorList', 'const c10::List<c10::optional<Tensor>> &']:
+                if is_tensor_list_type(return_info.type):
                     if base_name in MULTI_OUTPUT_SAFE_FUNCTIONS:
-                        creation_meta = "CreationMeta::MULTI_OUTPUT_SAFE"
+                        creation_meta = 'CreationMeta::MULTI_OUTPUT_SAFE'
                     else:
-                        creation_meta = "CreationMeta::MULTI_OUTPUT_NODE"
-                    call += ("as_view(/* base */ {}, /* output */ {}, /* is_bw_differentiable */ true, "
-                             "/* is_fw_differentiable */ true, "
-                             "/* creation_meta */ {});").format(view_info, var, creation_meta)
-                    rhs_value = 'std::move({})'.format(var)
+                        creation_meta = 'CreationMeta::MULTI_OUTPUT_NODE'
+                    call += (f'as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, '
+                             '/* is_fw_differentiable */ true, '
+                             f'/* creation_meta */ {creation_meta});')
+                    rhs_value = f'std::move({var})'
                 else:
-                    call += emit_view_lambda()
-                    creation_meta = "GradMode::is_enabled() ? CreationMeta::DEFAULT: CreationMeta::NO_GRAD_MODE"
-                    rhs_value = ("as_view(/* base */ {}, /* output */ {}, /* is_bw_differentiable */ true, "
-                                 "/* is_fw_differentiable */ true, "
-                                 "/* view_func */ func, /* creation_meta */ {})").format(view_info, var, creation_meta)
+                    call += emit_view_lambda(unpacked_bindings)
+                    creation_meta = 'GradMode::is_enabled() ? CreationMeta::DEFAULT: CreationMeta::NO_GRAD_MODE'
+                    rhs_value = (f'as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, '
+                                 '/* is_fw_differentiable */ true, '
+                                 f'/* view_func */ func, /* creation_meta */ {creation_meta})')
             else:
                 # This could be supported but we don't need it at the moment, so keeping things simple.
-                raise RuntimeError("Function that return multiple differentiable output "
-                                   "when at least one of them is view is not supported.")
+                raise RuntimeError('Function that return multiple differentiable output '
+                                   'when at least one of them is view is not supported.')
         else:
-            rhs_value = 'std::move({})'.format(var)
+            rhs_value = f'std::move({var})'
 
         assert rhs_value is not None
-        call += ASSIGN_RETURN_VALUE.substitute(return_values=return_values,
+        call += ASSIGN_RETURN_VALUE.substitute(return_values=tie_return_values(f),
                                                rhs_value=rhs_value)
         return call
 
-    def enforce_same_tensorimpl_and_storage(env, call):
-        save_ptrs_stmts = []
-        enforce_same_ptrs_stmts = []
-        if declaration['name'] not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
-            for arg in env.get('unpacked_args', []):
-                simple_type = env['unpacked_args_simple_type'][arg]
-                if simple_type == 'TensorList':
+    def enforce_same_tensorimpl_and_storage(call: str, unpacked_bindings: List[Binding]) -> str:
+        save_ptrs_stmts: List[str] = []
+        enforce_same_ptrs_stmts: List[str] = []
+        if cpp.name(f.func) not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
+            for unpacked_binding in unpacked_bindings:
+                arg = unpacked_binding.name
+                noref_cpp_type = unpacked_binding.ctype.cpp_type(strip_ref=True)
+                if noref_cpp_type == 'TensorList':
                     save_ptrs_stmts += [SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                                         SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
                     enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                                                 ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
-                elif simple_type == 'c10::List<c10::optional<Tensor>>':
+                elif noref_cpp_type == 'c10::List<c10::optional<Tensor>>':
                     save_ptrs_stmts += [SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                                         SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)]
                     enforce_same_ptrs_stmts += [ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                                                 ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)]
-                elif simple_type == 'Tensor':
+                elif noref_cpp_type == 'Tensor':
                     save_ptrs_stmts += [SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
                                         SAVE_TENSOR_IMPL.substitute(tensor_name=arg)]
                     enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSOR_STORAGE.substitute(tensor_name=arg),
@@ -784,74 +802,69 @@ def enforce_same_tensorimpl_and_storage(env, call):
                 RUN_ONLY_IN_DEBUG_MODE.substitute(statements=enforce_same_ptrs_stmts)
         return call
 
-    def emit_call(env, tie_return_values):
-        combined = nested_dict(env, declaration)
+    def emit_call(f: NativeFunction, unpacked_bindings: List[Binding]) -> str:
        # We only care about adding `at::AutoNonVariableTypeMode` guard for non-variable dispatch
        # (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure
        # the baseType operations still dispatch to non-Variable type, even if the arguments passed
        # in are now Variables.
        # See NOTE [ Treating Variables as non-Variables in type dispatch ] for details.
-        base_type_call = emit_dispatch_call(combined['api_name'], 'self_', combined['unpacked_args'])
-        if not modifies_arguments and not returns_void:
+        unpacked_args = [b.name for b in unpacked_bindings]
+        base_type_call = emit_dispatch_call(f, 'self_', unpacked_args)
+        if not modifies_arguments(f) and not returns_void:
             call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute(
                 base_type_call=base_type_call)
-            call += wrap_output(tie_return_values, 'tmp')
+            call += wrap_output(f, unpacked_bindings, 'tmp')
         else:
             call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute(
                 base_type_call=base_type_call)
-        call = enforce_same_tensorimpl_and_storage(env, call)
+        call = enforce_same_tensorimpl_and_storage(call, unpacked_bindings)
         return call
 
-    def emit_history():
-        fn = 'rebase' if modifies_arguments and view_info is None else 'set'
-        output_names = [r['name'] for r in differentiable_outputs]
+    def emit_history() -> str:
+        fn = 'rebase' if modifies_arguments(f) and view_info is None else 'set'
+        output_names = [r.name for r in differentiable_outputs]
         # TODO: flatten allocates a std::vector, which could be expensive
         outs = CodeTemplate("flatten_tensor_args( ${outs} )").substitute(outs=output_names)
         return SET_HISTORY.substitute(fn=fn, differentiable_outputs=outs)
 
-    def emit_save_outputs():
+    def emit_save_outputs() -> str:
         if is_out_fn:
             # out functions don't currently support differentiation
             return ''
-        func = declaration['derivative']
-        if func is not None:
-            stmts = save_variables(func.all_saved_outputs, True)
+        if info is not None and info.has_derivatives:
+            stmts = save_variables(info.all_saved_outputs, True)
             if len(stmts) == 0:
                 return ''
             return CONDITIONAL.substitute(cond='grad_fn', statements=stmts)
         return ''
 
-    def emit_any_requires_grad():
+    def emit_any_requires_grad() -> List[str]:
         return [SETUP_ANY_REQUIRES_GRAD.substitute(
-            args_with_derivatives=[arg['name'] for arg in args_with_derivatives]), ]
+            args_with_derivatives=[arg.name for arg in args_with_derivatives]), ]
 
-    def emit_check_inplace():
+    def emit_check_inplace() -> List[str]:
         if not inplace:
             return []
-        return ['check_inplace({}, _any_requires_grad);'.format(arg['name']) for arg in differentiable_outputs]
+        return [f'check_inplace({arg.name}, _any_requires_grad);' for arg in differentiable_outputs]
 
-    def emit_increment_version():
-        if not modifies_arguments:
+    def emit_increment_version(f: NativeFunction) -> List[str]:
+        if not modifies_arguments(f):
             return []
-        return ['increment_version({});'.format(arg['name']) for arg in returns]
+        return [f'increment_version({r});' for r in cpp.return_names(f)]
 
-    env = {}
-    combined = nested_dict(env, declaration)
+    body: List[str] = []
+    unpack_args_stats, unpacked_bindings = unpack_args(f)
 
-    body = []
-
-    declare_returned_variables, tie_return_values, get_return_value = format_return_variables(declaration)
-
-    body.extend(unpack_args(env, declaration))
+    body.extend(unpack_args_stats)
     if requires_derivative:
         body.extend(emit_any_requires_grad())
         body.extend(emit_check_inplace())
         body.extend(setup_derivative(differentiable_inputs))
-    body.append(declare_returned_variables)
+    body.append(declare_returned_variables(f))
 
-    body.append(emit_call(env, tie_return_values))
-    body.extend(emit_increment_version())
+    body.append(emit_call(f, unpacked_bindings))
+    body.extend(emit_increment_version(f))
     if requires_derivative:
         # set_flags has to appear after version_counter, because rebase_history
         # requires that the counter is incremented before it is called
@@ -866,56 +879,54 @@ def emit_increment_version():
         assert inplace
         body.append('reset_grad_accumulator(self);')
     if not returns_void:
-        body.append('return {};'.format(get_return_value))
+        body.append(f'return {get_return_value(f)};')
     return body
 
-
-def unpack_args(env, declaration):
-    def requires_unpack(arg):
-        return 'Tensor' in arg['dynamic_type'] and 'c10::optional' not in arg['type']
-
-    body = []
-    unpacked_args = []
-    unpacked_args_simple_type = {}
-    if declaration['use_c10_dispatcher'] in ['full', 'hacky_wrapper_for_legacy_signatures']:
-        arguments = declaration['schema_order_arguments']
+@with_native_function
+def unpack_args(f: NativeFunction) -> Tuple[List[str], List[Binding]]:
+    body: List[str] = []
+    unpacked_bindings: List[Binding] = []
+
+    if f.use_c10_dispatcher.dispatcher_uses_new_style():
+        bindings = [r for a in f.func.schema_order_arguments()
+                    for r in cpp.argument(a,
+                                          method=False,
+                                          cpp_no_default_args=set(),
+                                          faithful=False,
+                                          has_tensor_options=False)]
     else:
-        assert declaration['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper'
-        arguments = declaration['arguments']
-    for i, arg in enumerate(arguments):
-        if not requires_unpack(arg):
-            unpacked_args.append(arg['name'])
-            unpacked_args_simple_type[arg['name']] = arg['simple_type']
-            continue
+        sig_group = CppSignatureGroup.from_native_function(f, method=False)
+        bindings = list(sig_group.signature.arguments())
 
-        dynamic_type = arg['dynamic_type']
-        if 'TensorOptions' not in dynamic_type:
-            is_nullable = arg.get('is_nullable', False)
-            ref = (not is_nullable) and dynamic_type != 'TensorList'
-            suffix = '_opt' if is_nullable and dynamic_type != 'TensorList' else ''
-            body.append(UNPACK_TENSOR.substitute(
-                arg_name=arg['name'],
-                arg_pos=i,
-                suffix=suffix,
-                ref='&' if ref else '',
-            ))
-        else:
-            # Okay, we are abusing the definition of 'unpack' here a bit,
-            # although it's still getting the non-variable from the variable
-            # (in this case via TensorOptions rather than Variable/Tensor).
-            assert declaration['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper', \
-                "VariableKernel shouldn't take TensorOptions if the op is c10-full"
-            body.append(LEGACY_WRAP_OPTIONS.substitute(arg_name=arg['name']))
-
-        unpacked_args.append(arg['name'] + '_')
-        unpacked_args_simple_type[arg['name'] + '_'] = arg['simple_type']
-
-    env['unpacked_args'] = unpacked_args
-    env['unpacked_args_simple_type'] = unpacked_args_simple_type
-    return body
+    for i, binding in enumerate(bindings):
+        assert not isinstance(binding.argument, SelfArgument)
+        if isinstance(binding.argument, TensorOptionsArguments):
+            raise RuntimeError("VariableKernel shouldn't take TensorOptions")
 
+        is_nullable = binding.argument.type.is_nullable()
+        if not binding.argument.type.is_tensor_like() or is_nullable:
+            unpacked_bindings.append(binding)
+            continue
 
-def dispatch_strategy(declaration):
+        is_tensor_list = is_tensor_list_type(binding.argument.type)
+        ref = (not is_nullable) and not is_tensor_list
+        suffix = '_opt' if is_nullable and not is_tensor_list else ''
+        body.append(UNPACK_TENSOR.substitute(
+            arg_name=binding.name,
+            arg_pos=i,
+            suffix=suffix,
+            ref='&' if ref else '',
+        ))
+        unpacked_bindings.append(Binding(
+            name=binding.name + '_',
+            ctype=binding.ctype,
+            argument=binding.argument,
+            default=binding.default,
+        ))
+
+    return body, unpacked_bindings
+
+def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
     """How are we going to call the underlying implementation of a declaration?
 
     There are two strategies:
@@ -935,7 +946,7 @@ def dispatch_strategy(declaration):
         get dispatched back to VariableType (which will ensure that they
         are differentiable.)
     """
-    if declaration['abstract'] or declaration['derivative'] is not None:
+    if fn.func.is_abstract or (fn.info is not None and fn.info.has_derivatives):
         # If the function is abstract (not implemented on at::Type), we must
         # call the implementation on the derived type with unpacked tensors.
 
@@ -959,62 +970,47 @@ def dispatch_strategy(declaration):
     # assumption might not hold, but then you'll see gradcheck fail.)
     return 'use_type'
 
-def get_decl_signature(declaration: Dict[Any, Any], use_base_variant: bool = False) -> str:
-    name = declaration['name']
-    arguments = declaration['arguments']
-    if use_base_variant:
-        if declaration['inplace']:
-            assert name.endswith('_')
-            name = name[:-1]
-        elif name.endswith('_out'):
-            name = name[:-4]
-            arguments = [arg for arg in arguments if not arg.get('output', False)]
-    simple_types = ', '.join(arg['simple_type'] for arg in arguments)
-    return f'{name}({simple_types})'
+def is_tensor_type(t: Type) -> bool:
+    # TODO: Should handle optional here?
+    return t.is_tensor_like() and t.is_list_like() is None
 
-@with_native_function
-def get_func_signature(f: NativeFunction) -> str:
-    args = CppSignatureGroup.from_native_function(f, method=False).signature.arguments()
-    types = ', '.join(python.argument_type_str(a.argument.type, simple_type=True)
-                      if isinstance(a.argument, Argument) else 'TensorOptions'
-                      for a in args)
-    return f'{cpp.name(f.func)}({types})'
-
-def match_declarations_with_differentiability_info(
-    declarations: Dict[Any, Any],
+def is_tensor_list_type(t: Type) -> bool:
+    # TODO: Should handle optional here?
+    return t.is_tensor_like() and t.is_list_like() is not None
+
+def modifies_arguments(f: NativeFunction) -> bool:
+    return f.func.kind() in [SchemaKind.inplace, SchemaKind.out]
+
+def match_differentiability_info(
+    native_functions: List[NativeFunction],
     differentiability_infos: Sequence[DifferentiabilityInfo],
-) -> None:
+) -> List[NativeFunctionWithDifferentiabilityInfo]:
     """Sets the "derivative" key on declarations to matching autograd function
 
     In-place functions will use the out-of-place derivative definition if there
     is no in-place specific derivative.
     """
-    info_by_signature = {get_func_signature(info.func): info for info in differentiability_infos}
+    info_by_schema = {info.func.func: info for info in differentiability_infos}
+    functional_info_by_signature = {
+        info.func.func.signature(strip_default=True): info
+        for info in differentiability_infos
+        if info.func.func.kind() == SchemaKind.functional}
 
-    def find_info(declaration: Dict[Any, Any]) -> Optional[DifferentiabilityInfo]:
-        signature = get_decl_signature(declaration)
-        if signature in info_by_signature:
-            return info_by_signature[signature]
+    def find_info(f: NativeFunction) -> Tuple[Optional[DifferentiabilityInfo], bool]:
+        if f.func in info_by_schema:
+            return info_by_schema[f.func], True
 
         # if there is no exact match look for the out-of-place signature.
         # i.e mul() for mul_() or mul_out()
-        signature = get_decl_signature(declaration, use_base_variant=True)
-        return info_by_signature.get(signature)
-
-    for declaration in declarations:
-        info = find_info(declaration)
-        declaration['derivative'] = info if info and info.args_with_derivatives else None
-
-        # Currently, the '.strides()' to 'strides_or_error' replacement does not support
-        # 'self' derivatives of an inplace function, so we must check for this case.
-        if declaration['inplace'] and (info is not None):
-            for derivative in info.derivatives:
-                if 'self' in derivative.var_names:
-                    for saved_input in derivative.saved_inputs:
-                        assert 'strides_or_error' not in saved_input.expr, (
-                            "Calling '.strides()' in the 'self' derivative formula of an "
-                            f"in-place function is not supported: {declaration['name']}")
-
-        declaration['non_differentiable_arg_names'] = info.non_differentiable_arg_names if info else []
-        declaration['output_differentiability'] = info.output_differentiability if info else None
+        return functional_info_by_signature.get(f.func.signature(strip_default=True)), False
+
+    result: List[NativeFunctionWithDifferentiabilityInfo] = []
+    for f in native_functions:
+        info, is_exact_match = find_info(f)
+        result.append(NativeFunctionWithDifferentiabilityInfo(
+            func=f,
+            info=info,
+        ))
+
+    return result
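`match_differentiability_info` now matches on typed schemas instead of string signatures: an exact `FunctionSchema` match is tried first, and in-place/out variants fall back to the functional schema with defaults stripped (via the `signature(strip_default=True)` helper added to model.py later in this diff). A hedged sketch of that fallback, assuming `FunctionSchema.parse` accepts schema strings as written in native_functions.yaml:

    from tools.codegen.model import FunctionSchema

    inplace = FunctionSchema.parse('mul_(Tensor(a!) self, Tensor other) -> Tensor(a!)')
    functional = FunctionSchema.parse('mul.Tensor(Tensor self, Tensor other) -> Tensor')

    # Mutability annotations, overload names, return names, and defaults are all
    # stripped, so mul_ reduces to the same core signature as mul and can reuse
    # its derivatives.yaml entry.
    assert inplace.signature(strip_default=True) == functional.signature(strip_default=True)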
diff --git a/tools/codegen/api/autograd.py b/tools/codegen/api/autograd.py
index 58fb75bb7c0708..6f58eea6d1eac3 100644
--- a/tools/codegen/api/autograd.py
+++ b/tools/codegen/api/autograd.py
@@ -87,3 +87,36 @@ class DifferentiabilityInfo:
 
     # Raw data read from derivatives.yaml.
     output_differentiability: Optional[List[bool]]
+
+    @property
+    def has_derivatives(self) -> bool:
+        return len(self.args_with_derivatives) > 0
+
+# Represents a differentiable `Argument`.
+# How is it different from the `Argument` type?
+# - It's processed Arguments which are differentiable and only used in the
+#   context of the autograd codegen;
+# - It can represent SelfArgument or regular Argument but not TensorOptionsArgument;
+@dataclass(frozen=True)
+class DifferentiableInput:
+    name: str
+    type: Type
+
+    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
+    cpp_type: str
+
+# Represents a differentiable `Return`.
+# How is it different from the `Return` type?
+# - The name in `Return` is optional. Here it is always populated using the same
+#   `cpp.return_names()` method.
+#   TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant?
+# - It's processed Returns which are differentiable, in compliance with the
+#   `output_differentiability` field defined in derivatives.yaml (if specified),
+#   and are only used in the context of the autograd codegen;
+@dataclass(frozen=True)
+class DifferentiableOutput:
+    name: str
+    type: Type
+
+    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
+    cpp_type: str
diff --git a/tools/codegen/api/cpp.py b/tools/codegen/api/cpp.py
index 29a29e215f4f8f..8a1d2a5272f534 100644
--- a/tools/codegen/api/cpp.py
+++ b/tools/codegen/api/cpp.py
@@ -106,7 +106,7 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
             return BaseCType("DimnameList", binds)
         elif str(t.elem) == 'Tensor?':
             if local.use_c10_dispatcher().dispatcher_uses_new_style():
-                return BaseCType("const c10::List<c10::optional<Tensor>> &", binds)
+                return ConstRefCType(BaseCType("c10::List<c10::optional<Tensor>>", binds))
             else:
                 return BaseCType("TensorList", binds)
         elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
diff --git a/tools/codegen/api/types.py b/tools/codegen/api/types.py
index ea03a1799cfb2b..39fb8bef384649 100644
--- a/tools/codegen/api/types.py
+++ b/tools/codegen/api/types.py
@@ -31,14 +31,16 @@ class BaseCType:
     type: str
     name: ArgName
 
-    def cpp_type(self) -> str:
+    def cpp_type(self, *, strip_ref: bool = False) -> str:
         return self.type
 
 @dataclass(frozen=True)
 class ConstRefCType:
     elem: 'CType'
 
-    def cpp_type(self) -> str:
+    def cpp_type(self, *, strip_ref: bool = False) -> str:
+        if strip_ref:
+            return self.elem.cpp_type(strip_ref=strip_ref)
         return f'const {self.elem.cpp_type()} &'
 
     @property
@@ -49,7 +51,9 @@ def name(self) -> ArgName:
 class MutRefCType:
     elem: 'CType'
 
-    def cpp_type(self) -> str:
+    def cpp_type(self, *, strip_ref: bool = False) -> str:
+        if strip_ref:
+            return self.elem.cpp_type(strip_ref=strip_ref)
         return f'{self.elem.cpp_type()} &'
 
     @property
@@ -60,7 +64,8 @@ def name(self) -> ArgName:
 class OptionalCType:
     elem: 'CType'
 
-    def cpp_type(self) -> str:
+    def cpp_type(self, *, strip_ref: bool = False) -> str:
+        # Do not pass `strip_ref` recursively.
         return f'c10::optional<{self.elem.cpp_type()}>'
 
     @property
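The new `strip_ref` flag peels references off the outermost CType only; `OptionalCType` deliberately does not forward it, so a reference nested inside `c10::optional<...>` is preserved. A small sketch built directly on the classes patched above:

    from tools.codegen.api.types import BaseCType, ConstRefCType, OptionalCType

    t = ConstRefCType(BaseCType('c10::List<c10::optional<Tensor>>', 'self'))
    assert t.cpp_type() == 'const c10::List<c10::optional<Tensor>> &'
    assert t.cpp_type(strip_ref=True) == 'c10::List<c10::optional<Tensor>>'

    # strip_ref is not passed recursively: the inner reference survives.
    opt = OptionalCType(ConstRefCType(BaseCType('Tensor', 'weight')))
    assert opt.cpp_type(strip_ref=True) == 'c10::optional<const Tensor &>'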
diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py
index 782d8b919e7e5a..8f521e6651bcc8 100644
--- a/tools/codegen/gen.py
+++ b/tools/codegen/gen.py
@@ -203,8 +203,7 @@ class RegisterSchema:
 
     @method_with_native_function
     def __call__(self, f: NativeFunction) -> Optional[str]:
-        op_name = f"aten::{f.func.name}"
-        if not self.selector.is_operator_selected(op_name):
+        if not self.selector.is_native_function_selected(f):
             return None
         return f'm.def({cpp_string(str(f.func))});\n'
 
@@ -399,8 +398,7 @@ def gen_one(f: NativeFunction) -> Optional[str]:
                 e.expr for e in translate(functional_sig.arguments(), dispatcher.arguments(functional_func), method=False)
             )
 
-            op_name = f"aten::{f.func.name}"
-            if self.target is Target.REGISTRATION and not self.selector.is_operator_selected(op_name):
+            if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
                 return None
 
             k = f.func.kind()
@@ -480,8 +478,7 @@ def gen_unstructured(self, f: NativeFunction) -> Optional[str]:
         if f.manual_kernel_registration:
             return None
 
-        op_name = f"aten::{f.func.name}"
-        if self.target is Target.REGISTRATION and not self.selector.is_operator_selected(op_name):
+        if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
             return None
 
         name = native.name(f.func)
diff --git a/tools/codegen/model.py b/tools/codegen/model.py
index ea667a0922cf64..9c8a0d73e81543 100644
--- a/tools/codegen/model.py
+++ b/tools/codegen/model.py
@@ -567,7 +567,7 @@ def kind(self) -> SchemaKind:
         else:
             return SchemaKind.functional
 
-    def signature(self) -> 'FunctionSchema':
+    def signature(self, *, strip_default: bool = False) -> 'FunctionSchema':
         """
         Certain schemas are 'related', in that they are simply
         inplace/out/functional versions of the same function.  This method
@@ -582,11 +582,13 @@ def signature(self) -> 'FunctionSchema':
 
         - Out arguments are stripped
         - Mutability annotations are stripped (this is sound
          because you cannot overload on mutability annotation)
+        - Return names are stripped since they are not overloadable and
+          some variants have return names but some not
 
         """
 
         def strip_ret_annotation(r: Return) -> Return:
             return Return(
-                name=r.name,
+                name=None,
                 type=r.type,
                 annotation=None,
             )
 
@@ -600,7 +602,7 @@ def strip_ret_annotation(r: Return) -> Return:
                 ),
                 overload_name="",  # stripped
             ),
-            arguments=self.arguments.signature(),
+            arguments=self.arguments.signature(strip_default=strip_default),
             returns=tuple(map(strip_ret_annotation, self.returns)),
         )
 
@@ -983,14 +985,14 @@ def kwarg_only(self) -> Sequence[Union[Argument, TensorOptionsArguments]]:
             ret.extend(self.post_tensor_options_kwarg_only)
         return ret
 
-    def signature(self) -> 'Arguments':
+    def signature(self, *, strip_default: bool = False) -> 'Arguments':
         # dataclasses.replace could be used here, but it is less
         # type safe so for now I've opted to type everything out
         def strip_arg_annotation(a: Argument) -> Argument:
             return Argument(
                 name=a.name,
                 type=a.type,
-                default=a.default,  # hmmm
+                default=a.default if not strip_default else None,
                 annotation=None,
             )
 
diff --git a/tools/codegen/selective_build/selector.py b/tools/codegen/selective_build/selector.py
index 24e387128b6ca9..3e80e168d31c04 100644
--- a/tools/codegen/selective_build/selector.py
+++ b/tools/codegen/selective_build/selector.py
@@ -3,6 +3,7 @@
 
 from dataclasses import dataclass
 
+from tools.codegen.model import NativeFunction
 from tools.codegen.selective_build.operator import *
 
 # A SelectiveBuilder holds information extracted from the selective build
@@ -96,6 +97,10 @@ def is_operator_selected(self, name: str) -> bool:
             name = strip_operator_overload_name(name)
         return name in self.operators and self.operators[name].include_all_overloads
 
+    def is_native_function_selected(self, func: NativeFunction) -> bool:
+        op_name = op_name_from_native_function(func)
+        return self.is_operator_selected(op_name)
+
     def is_operator_selected_for_training(self, name: str) -> bool:
         if not self.is_operator_selected(name):
             return False
@@ -123,6 +128,10 @@ def is_operator_selected_for_training(self, name: str) -> bool:
             (base_op.include_all_overloads and base_op.is_used_for_training)
         )
 
+    def is_native_function_selected_for_training(self, func: NativeFunction) -> bool:
+        op_name = op_name_from_native_function(func)
+        return self.is_operator_selected_for_training(op_name)
+
     def is_root_operator(self, name: str) -> bool:
         if not self.is_operator_selected(name):
             return False
@@ -158,3 +167,9 @@ def combine_selective_builders(lhs: SelectiveBuilder, rhs: SelectiveBuilder) ->
     debug_info = merge_debug_info(lhs._debug_info, rhs._debug_info)
     operators = merge_operator_dicts(lhs.operators, rhs.operators)
     return SelectiveBuilder(include_all_operators, debug_info, operators)
+
+
+def op_name_from_native_function(f: NativeFunction) -> str:
+    # This was originally read from the 'operator_name_with_overload' field in the
+    # declaration dict, which was the part before the first '(' in 'schema_string'.
+    return f'aten::{f.func.name}'
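The two new `is_native_function_selected*` wrappers let callers pass a `NativeFunction` directly instead of re-deriving the `aten::name.overload` string (previously read off the declaration dict). A hedged usage sketch (the yaml path is illustrative; `get_nop_selector` is the existing select-everything builder):

    from tools.codegen.gen import parse_native_yaml
    from tools.codegen.selective_build.selector import SelectiveBuilder

    selector = SelectiveBuilder.get_nop_selector()
    fns = [f for f in parse_native_yaml('aten/src/ATen/native/native_functions.yaml')
           if selector.is_native_function_selected_for_training(f)]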
diff --git a/tools/jit/gen_unboxing_wrappers.py b/tools/jit/gen_unboxing_wrappers.py
index 267b5a3b221a0e..a52c109c603f4f 100644
--- a/tools/jit/gen_unboxing_wrappers.py
+++ b/tools/jit/gen_unboxing_wrappers.py
@@ -22,9 +22,10 @@
 import re
 from itertools import groupby
 from functools import reduce
-from ..autograd.gen_autograd import load_aten_declarations
+import yaml
+
 from ..autograd.gen_autograd import RETURNS_VIEWS_OF_INPUT
-from ..autograd.utils import CodeTemplate, write, is_out_variant, op_name_with_overload
+from ..autograd.utils import CodeTemplate, YamlLoader, write, is_out_variant, op_name_with_overload
 from tools.codegen.selective_build.selector import SelectiveBuilder
 
 # JIT has a type system of
@@ -279,6 +280,66 @@ def argument_order(decl):
     return decl.get('jit_argument_order') or list(range(len(decl['arguments'])))
 
 
+def format_return_type(returns):
+    if len(returns) == 0:
+        return 'void'
+    elif len(returns) == 1:
+        return returns[0]['type']
+    else:
+        return_types = [r['type'] for r in returns]
+        return 'std::tuple<{}>'.format(','.join(return_types))
+
+
+def get_simple_type(arg):
+    simple_type = arg['type']
+    simple_type = simple_type.replace(' &', '').replace('const ', '')
+    simple_type = simple_type.replace('Generator *', 'Generator')
+
+    opt_match = re.match(r'c10::optional<(.+)>', simple_type)
+    if opt_match:
+        simple_type = '{}?'.format(opt_match.group(1))
+    return simple_type
+
+
+def load_aten_declarations(path):
+    with open(path, 'r') as f:
+        declarations = yaml.load(f, Loader=YamlLoader)
+
+    # enrich declarations with additional information
+    selected_declarations = []
+    for declaration in declarations:
+        if declaration.get('deprecated'):
+            continue
+
+        for arg in declaration['arguments']:
+            arg['simple_type'] = get_simple_type(arg)
+        for arg in declaration['schema_order_arguments']:
+            arg['simple_type'] = get_simple_type(arg)
+        for ret in declaration['returns']:
+            ret['simple_type'] = get_simple_type(ret)
+
+        declaration['formals'] = [arg['type'] + ' ' + arg['name']
+                                  for arg in declaration['arguments']]
+        declaration['schema_order_formals'] = [arg['type'] + ' ' + arg['name']
+                                               for arg in declaration['schema_order_arguments']]
+        declaration['args'] = [arg['name'] for arg in declaration['arguments']]
+        declaration['schema_order_args'] = [arg['name'] for arg in declaration['schema_order_arguments']]
+        declaration['api_name'] = declaration['name']
+        if declaration.get('overload_name'):
+            declaration['type_wrapper_name'] = "{}_{}".format(
+                declaration['name'], declaration['overload_name'])
+        else:
+            declaration['type_wrapper_name'] = declaration['name']
+        declaration['operator_name_with_overload'] = declaration['schema_string'].split('(')[0]
+        declaration['unqual_operator_name_with_overload'] = declaration['operator_name_with_overload'].split('::')[1]
+        declaration['return_type'] = format_return_type(declaration['returns'])
+
+        declaration['base_name'] = declaration['name']
+        selected_declarations.append(declaration)
+
+    return selected_declarations
+
+
 def gen_unboxing_wrappers(
         declarations,
         out,