[pytorch][codegen] migrate gen_variable_type to new data model (pytorch#49735)

Summary:
Pull Request resolved: pytorch#49735

This is the final wave of the autograd codegen data model migration.

After this PR:
- autograd codegen no longer depends on Declarations.yaml;
- autograd codegen sources are fully type annotated and pass the mypy-strict check.

To avoid potential merge conflicts with other pending PRs, some structural changes
are intentionally deferred: for example, inner methods were not moved out of their
enclosing functions, and inner methods were not rewritten to stop reading the outer
function's variables.

Confirmed byte-for-byte compatible with the old codegen:
```
Run the codegen test script before and after applying this PR:
  .jenkins/pytorch/codegen-test.sh <baseline_output_dir>
  .jenkins/pytorch/codegen-test.sh <test_output_dir>

Then run diff to compare the generated files:
  diff -Naur <baseline_output_dir> <test_output_dir>
```
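
For reference, the same tree comparison can be scripted in Python. This is only an illustrative sketch (not part of the test plan above); `baseline` and `test` stand in for the two output directories:
```
import filecmp
import os

def trees_identical(baseline: str, test: str) -> bool:
    cmp = filecmp.dircmp(baseline, test)
    # Any file present on only one side means the trees differ.
    if cmp.left_only or cmp.right_only:
        return False
    # Compare common files byte-for-byte (shallow=False forces content reads).
    _, mismatch, errors = filecmp.cmpfiles(baseline, test, cmp.common_files, shallow=False)
    if mismatch or errors:
        return False
    # Recurse into common subdirectories.
    return all(trees_identical(os.path.join(baseline, d), os.path.join(test, d))
               for d in cmp.common_dirs)
```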

Confirmed clean mypy-strict run:
```
mypy --config mypy-strict.ini
```
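
The same check can also be driven from Python through mypy's programmatic API; a minimal sketch, illustrative only and not part of this PR:
```
from mypy import api

def run_mypy_strict() -> None:
    # Equivalent to `mypy --config-file mypy-strict.ini`; the files= list in the
    # config decides which sources (now including gen_variable_type.py) are
    # checked in strict mode.
    stdout, stderr, exit_status = api.run(['--config-file', 'mypy-strict.ini'])
    if exit_status != 0:
        raise RuntimeError(f'mypy-strict check failed:\n{stdout}{stderr}')
```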

Test Plan: Imported from OSS

Reviewed By: ezyang, bhosmer

Differential Revision: D25678879

Pulled By: ljk53

fbshipit-source-id: ba6e2eb6b9fb744208f7f79a922d933fcc3bde9f
ljk53 authored and facebook-github-bot committed Jan 5, 2021
1 parent a272a7e commit e71a13e
Showing 11 changed files with 507 additions and 464 deletions.
2 changes: 2 additions & 0 deletions mypy-strict.ini
@@ -31,9 +31,11 @@ strict_equality = True

files = tools/codegen/gen.py,
    tools/autograd/gen_annotated_fn_args.py,
    tools/autograd/gen_autograd.py,
    tools/autograd/gen_python_functions.py,
    tools/autograd/gen_trace_type.py,
    tools/autograd/gen_variable_factories.py,
    tools/autograd/gen_variable_type.py,
    tools/autograd/load_derivatives.py,
    torch/utils/benchmark/utils/common.py,
    torch/utils/benchmark/utils/timer.py,
106 changes: 19 additions & 87 deletions tools/autograd/gen_autograd.py
@@ -23,9 +23,6 @@

import argparse
import os
import yaml
import re
from .utils import YamlLoader, op_name_with_overload
from tools.codegen.selective_build.selector import SelectiveBuilder

# See NOTE [ Autograd View Variables ] in variable.h for details.
@@ -89,84 +86,14 @@
    'tensor_split', 'swapdims', 'swapaxes'
})

def format_return_type(returns):
    if len(returns) == 0:
        return 'void'
    elif len(returns) == 1:
        return returns[0]['type']
    else:
        return_types = [r['type'] for r in returns]
        return 'std::tuple<{}>'.format(','.join(return_types))


def get_simple_type(arg):
    simple_type = arg['type']
    simple_type = simple_type.replace(' &', '').replace('const ', '')
    simple_type = simple_type.replace('Generator *', 'Generator')

    opt_match = re.match(r'c10::optional<(.+)>', simple_type)
    if opt_match:
        simple_type = '{}?'.format(opt_match.group(1))
    return simple_type

def has_tensoroptions_argument(declaration):
    for argument in declaration['arguments']:
        if 'TensorOptions' == argument['dynamic_type']:
            return True
    return False


def load_aten_declarations(path):
    with open(path, 'r') as f:
        declarations = yaml.load(f, Loader=YamlLoader)

    # enrich declarations with additional information
    selected_declarations = []
    for declaration in declarations:
        if declaration.get('deprecated'):
            continue

        for arg in declaration['arguments']:
            arg['simple_type'] = get_simple_type(arg)
        for arg in declaration['schema_order_arguments']:
            arg['simple_type'] = get_simple_type(arg)
        for ret in declaration['returns']:
            ret['simple_type'] = get_simple_type(ret)

        declaration['formals'] = [arg['type'] + ' ' + arg['name']
                                  for arg in declaration['arguments']]
        declaration['schema_order_formals'] = [arg['type'] + ' ' + arg['name']
                                               for arg in declaration['schema_order_arguments']]
        declaration['args'] = [arg['name'] for arg in declaration['arguments']]
        declaration['schema_order_args'] = [arg['name'] for arg in declaration['schema_order_arguments']]
        declaration['api_name'] = declaration['name']
        if declaration.get('overload_name'):
            declaration['type_wrapper_name'] = "{}_{}".format(
                declaration['name'], declaration['overload_name'])
        else:
            declaration['type_wrapper_name'] = declaration['name']
        declaration['operator_name_with_overload'] = declaration['schema_string'].split('(')[0]
        declaration['unqual_operator_name_with_overload'] = declaration['operator_name_with_overload'].split('::')[1]
        declaration['return_type'] = format_return_type(declaration['returns'])

        declaration['base_name'] = declaration['name']
        selected_declarations.append(declaration)

    return selected_declarations


def gen_autograd(aten_path, native_functions_path, out, autograd_dir, operator_selector: SelectiveBuilder, disable_autograd=False):
    full_aten_decls = load_aten_declarations(aten_path)

    def filter_decls(aten_decls, operator_selector):
        def is_operator_selected_for_training(decl):
            op_name = op_name_with_overload(decl)
            return operator_selector.is_operator_selected_for_training(op_name)

        return [decl for decl in aten_decls if is_operator_selected_for_training(decl)]

    aten_decls = filter_decls(full_aten_decls, operator_selector)

def gen_autograd(
    aten_path: str,
    native_functions_path: str,
    out: str,
    autograd_dir: str,
    operator_selector: SelectiveBuilder,
    disable_autograd: bool = False,
) -> None:
    # Parse and load derivatives.yaml
    from .load_derivatives import load_derivatives
    differentiability_infos = load_derivatives(
@@ -175,13 +102,13 @@ def is_operator_selected_for_training(decl):
    template_path = os.path.join(autograd_dir, 'templates')

    # Generate VariableType.h/cpp
    from .gen_trace_type import gen_trace_type
    from .gen_variable_type import gen_variable_type
    if not disable_autograd:
        from .gen_variable_type import gen_variable_type
        gen_variable_type(out, aten_decls, differentiability_infos, template_path)
        gen_variable_type(out, native_functions_path, differentiability_infos, template_path, operator_selector)

        from . import gen_trace_type
        # operator filter not applied as tracing sources are excluded in selective build
        gen_trace_type.gen_trace_type(out, native_functions_path, template_path)
        gen_trace_type(out, native_functions_path, template_path)

    # Generate Functions.h/cpp
    from .gen_autograd_functions import gen_autograd_functions_lib
@@ -193,7 +120,12 @@ def is_operator_selected_for_training(decl):
    gen_variable_factories(out, native_functions_path, template_path)


def gen_autograd_python(aten_path, native_functions_path, out, autograd_dir):
def gen_autograd_python(
    aten_path: str,
    native_functions_path: str,
    out: str,
    autograd_dir: str,
) -> None:
    from .load_derivatives import load_derivatives
    differentiability_infos = load_derivatives(
        os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path)
@@ -212,7 +144,7 @@ def gen_autograd_python(aten_path, native_functions_path, out, autograd_dir):
        out, native_functions_path, deprecated_path, template_path)


def main():
def main() -> None:
    parser = argparse.ArgumentParser(
        description='Generate autograd C++ files script')
    parser.add_argument('declarations', metavar='DECL',
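For context, a rough sketch of how the newly typed entry point is invoked; every path below is a placeholder standing in for the values that main() normally collects from its command-line arguments:
```
from tools.autograd.gen_autograd import gen_autograd
from tools.codegen.selective_build.selector import SelectiveBuilder

gen_autograd(
    aten_path='<path/to/Declarations.yaml>',                  # placeholder
    native_functions_path='<path/to/native_functions.yaml>',  # placeholder
    out='<output/dir>',                                       # placeholder
    autograd_dir='tools/autograd',
    operator_selector=SelectiveBuilder.get_nop_selector(),    # keep all operators
    disable_autograd=False,
)
```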
2 changes: 1 addition & 1 deletion tools/autograd/gen_trace_type.py
@@ -422,7 +422,7 @@ def gen_trace_type_shard(
    fm: FileManager, native_functions: Sequence[NativeFunction], suffix: str
) -> None:
    fm.write_with_template('TraceType%s.cpp' % suffix, 'TraceType.cpp', lambda: {
        'generated_comment': f'@generated from {fm.template_dir}/TraceType.cpp',
        'generated_comment': '@' + f'generated from {fm.template_dir}/TraceType.cpp',
        'trace_method_definitions': list(mapMaybe(method_definition, native_functions)),
        'trace_wrapper_registrations': list(mapMaybe(method_registration, native_functions)),
    })
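The one-line change above splits the literal `@generated` token out of the f-string, presumably so that gen_trace_type.py itself no longer contains the marker that tooling uses to detect generated files, while the emitted comment stays identical. A quick illustrative check (the template_dir value is a made-up placeholder):
```
template_dir = 'tools/autograd/templates'  # placeholder for fm.template_dir
old_style = f'@generated from {template_dir}/TraceType.cpp'
new_style = '@' + f'generated from {template_dir}/TraceType.cpp'
assert old_style == new_style  # same output; only the generator source changes
```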
(diffs for the remaining changed files are not shown)
