diff --git a/.lintrunner.toml b/.lintrunner.toml
index 39f2c20a5b201f..ed18baffbd8d35 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -908,62 +908,6 @@ exclude_patterns = [
     'third_party/**/*.pyi',
     # These files are all grandfathered in, feel free to remove from this list
     # as necessary
-    'aten/src/ATen/function_wrapper.py',
-    'aten/src/ATen/native/quantized/cpu/qnnpack/configure.py',
-    'aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/configure.py',
-    'aten/src/ATen/native/quantized/cpu/qnnpack/generate-wrapper.py',
-    'aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py',
-    'aten/src/ATen/nnapi/codegen.py',
-    'functorch/__init__.py',
-    'functorch/_src/__init__.py',
-    'functorch/_src/aot_autograd/__init__.py',
-    'functorch/_src/eager_transforms/__init__.py',
-    'functorch/_src/make_functional/__init__.py',
-    'functorch/_src/vmap/__init__.py',
-    'functorch/benchmarks/chrome_trace_parser.py',
-    'functorch/benchmarks/cse.py',
-    'functorch/benchmarks/operator_authoring.py',
-    'functorch/benchmarks/per_sample_grads.py',
-    'functorch/benchmarks/pointwise_scorecard.py',
-    'functorch/benchmarks/process_scorecard.py',
-    'functorch/compile/__init__.py',
-    'functorch/dim/__init__.py',
-    'functorch/dim/batch_tensor.py',
-    'functorch/dim/delayed_mul_tensor.py',
-    'functorch/dim/dim.py',
-    'functorch/dim/magic_trace.py',
-    'functorch/dim/op_properties.py',
-    'functorch/dim/reference.py',
-    'functorch/dim/tree_map.py',
-    'functorch/dim/wrap_type.py',
-    'functorch/docs/source/conf.py',
-    'functorch/einops/__init__.py',
-    'functorch/einops/_parsing.py',
-    'functorch/einops/rearrange.py',
-    'functorch/examples/compilation/eager_fusion.py',
-    'functorch/examples/compilation/fuse_module.py',
-    'functorch/examples/compilation/linear_train.py',
-    'functorch/examples/compilation/simple_function.py',
-    'functorch/examples/dp_cifar10/cifar10_opacus.py',
-    'functorch/examples/dp_cifar10/cifar10_transforms.py',
-    'functorch/examples/ensembling/parallel_train.py',
-    'functorch/examples/lennard_jones/lennard_jones.py',
-    'functorch/examples/maml_omniglot/maml-omniglot-higher.py',
-    'functorch/examples/maml_omniglot/maml-omniglot-ptonly.py',
-    'functorch/examples/maml_omniglot/maml-omniglot-transforms.py',
-    'functorch/examples/maml_omniglot/support/omniglot_loaders.py',
-    'functorch/examples/maml_regression/evjang.py',
-    'functorch/examples/maml_regression/evjang_transforms.py',
-    'functorch/examples/maml_regression/evjang_transforms_module.py',
-    'functorch/experimental/__init__.py',
-    'functorch/experimental/_cond.py',
-    'functorch/experimental/_map.py',
-    'functorch/experimental/control_flow.py',
-    'functorch/experimental/ops.py',
-    'functorch/notebooks/_src/plot_ensembling.py',
-    'functorch/notebooks/_src/plot_jacobians_and_hessians.py',
-    'functorch/notebooks/_src/plot_per_sample_gradients.py',
-    'functorch/op_analysis/gen_data.py',
     'test/_nvfuser/__init__.py',
     'test/_nvfuser/test_dynamo.py',
     'test/_nvfuser/test_python_frontend.py',
diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/configure.py b/aten/src/ATen/native/quantized/cpu/qnnpack/configure.py
index 0c683f414334b9..43f7841d55c9b6 100755
--- a/aten/src/ATen/native/quantized/cpu/qnnpack/configure.py
+++ b/aten/src/ATen/native/quantized/cpu/qnnpack/configure.py
@@ -31,7 +31,6 @@ def main(args):
         ],
         extra_include_dirs="src",
     ):
-
         requantization_objects = [
             build.cc("requantization/precise-scalar.c"),
             build.cc("requantization/fp32-scalar.c"),
@@ -192,7 +191,6 @@ def main(args):
         },
         extra_include_dirs=["src", "test"],
     ):
-
        build.unittest("hgemm-test", build.cxx("hgemm.cc"))
        build.unittest("q8avgpool-test", build.cxx("q8avgpool.cc"))
        build.unittest("q8conv-test", build.cxx("q8conv.cc"))
@@ -252,7 +250,6 @@ def main(args):
         isa=benchmark_isa,
         extra_include_dirs="src",
     ):
-
         build.benchmark("add-bench", build.cxx("add.cc"))
         build.benchmark("average-pooling-bench", build.cxx("average-pooling.cc"))
         build.benchmark("channel-shuffle-bench", build.cxx("channel-shuffle.cc"))
diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/configure.py b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/configure.py
index f9bd42d3a9872d..852e0454972f89 100755
--- a/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/configure.py
+++ b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/configure.py
@@ -7,6 +7,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import confu
+
 parser = confu.standard_parser("clog configuration script")
 
 
@@ -19,13 +20,16 @@ def main(args):
     with build.options(source_dir="src", extra_include_dirs="src"):
         build.static_library("clog", build.cc("clog.c"))
 
-    with build.options(source_dir="test", deps={
-            (build, build.deps.googletest): all,
-            "log": build.target.is_android}):
+    with build.options(
+        source_dir="test",
+        deps={(build, build.deps.googletest): all, "log": build.target.is_android},
+    ):
         build.unittest("clog-test", build.cxx("clog.cc"))
 
     return build
 
+
 if __name__ == "__main__":
     import sys
+
     main(sys.argv[1:]).generate()
diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py
index 3dcbdc35b5511d..a010d1e98678d1 100644
--- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py
+++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py
@@ -8,12 +8,12 @@
 # Kernels are ordered (see `sort_index`), and when dispatching,
 # we select the first kernel in the list that supports the inputs
 
+import argparse
 import collections
 import itertools
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, TypeVar
-import argparse
 
 DTYPES = {
     "f32": "float",
@@ -303,7 +303,11 @@ def get_all(cls) -> List["BwdKernel"]:
 
 
 def write_decl_impl(
-    kernels: List[T], family_name: str, impl_file: str, autogen_dir: Path, disable_def: str = None
+    kernels: List[T],
+    family_name: str,
+    impl_file: str,
+    autogen_dir: Path,
+    disable_def: str = None,
 ) -> None:
     cpp_file_header = """/*
  * Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -382,22 +386,28 @@ def main(output_dir: Optional[str]) -> None:
         FwdKernel.get_all(),
         "cutlassF",
         impl_file="",
-        autogen_dir=output_dir
+        autogen_dir=output_dir,
     )
     write_decl_impl(
         BwdKernel.get_all(),
         "cutlassB",
         impl_file="",
-        autogen_dir=output_dir
+        autogen_dir=output_dir,
     )
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        prog='generate_kernels',
-        description='Generate the mem-eff kernels template instantiations')
+        prog="generate_kernels",
+        description="Generate the mem-eff kernels template instantiations",
+    )
     # Set an optional output directory
-    parser.add_argument('-o', '--output_dir', required=False, help="Where to generate the kernels "
-                        " will default to ")
+    parser.add_argument(
+        "-o",
+        "--output_dir",
+        required=False,
+        help="Where to generate the kernels " " will default to ",
+    )
     args = parser.parse_args()
     main(args.output_dir)
diff --git a/aten/src/ATen/nnapi/codegen.py b/aten/src/ATen/nnapi/codegen.py
index 76131ce1d70f79..3197d670092e28 100755
--- a/aten/src/ATen/nnapi/codegen.py
+++ b/aten/src/ATen/nnapi/codegen.py
@@ -7,9 +7,9 @@
 we need with dlsym. We also generate a "check" wrapper that
 checks return values and throws C++ exceptions on errors.
 """
-import sys
-import re
 import pathlib
+import re
+import sys
 import textwrap
 
 
@@ -36,39 +36,155 @@
 
 NNAPI_FUNCTIONS = [
     ("int", "ANeuralNetworks_getDeviceCount", "uint32_t* numDevices"),  # noqa: B950
-    ("int", "ANeuralNetworks_getDevice", "uint32_t devIndex, ANeuralNetworksDevice** device"),  # noqa: B950
-    ("int", "ANeuralNetworksDevice_getName", "const ANeuralNetworksDevice* device, const char** name"),  # noqa: B950
-    ("int", "ANeuralNetworksDevice_getVersion", "const ANeuralNetworksDevice* device, const char** version"),  # noqa: B950
-    ("int", "ANeuralNetworksDevice_getFeatureLevel", "const ANeuralNetworksDevice* device, int64_t* featureLevel"),  # noqa: B950
-    ("int", "ANeuralNetworksModel_getSupportedOperationsForDevices", " const ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, bool* supportedOps"),  # noqa: B950
-    ("int", "ANeuralNetworksCompilation_createForDevices", "ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, ANeuralNetworksCompilation** compilation"),  # noqa: B950
-    ("int", "ANeuralNetworksExecution_compute", "ANeuralNetworksExecution* execution"),  # noqa: B950
-    ("int", "ANeuralNetworksMemory_createFromFd", "size_t size, int protect, int fd, size_t offset, ANeuralNetworksMemory** memory"),  # noqa: B950
-    ("void", "ANeuralNetworksMemory_free", "ANeuralNetworksMemory* memory"),  # noqa: B950
-    ("int", "ANeuralNetworksModel_create", "ANeuralNetworksModel** model"),  # noqa: B950
+    (
+        "int",
+        "ANeuralNetworks_getDevice",
+        "uint32_t devIndex, ANeuralNetworksDevice** device",
+    ),  # noqa: B950
+    (
+        "int",
+        "ANeuralNetworksDevice_getName",
+        "const ANeuralNetworksDevice* device, const char** name",
+    ),  # noqa: B950
+    (
+        "int",
+        "ANeuralNetworksDevice_getVersion",
+        "const ANeuralNetworksDevice* device, const char** version",
+    ),  # noqa: B950
+    (
+        "int",
+        "ANeuralNetworksDevice_getFeatureLevel",
+        "const ANeuralNetworksDevice* device, int64_t* featureLevel",
+    ),  # noqa: B950
+    (
+        "int",
+        "ANeuralNetworksModel_getSupportedOperationsForDevices",
+        " const ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, bool* supportedOps",  # noqa: B950
+    ),
+    (
+        "int",
+        "ANeuralNetworksCompilation_createForDevices",
+        "ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, ANeuralNetworksCompilation** compilation",  # noqa: B950
+    ),
+    (
+        "int",
+        "ANeuralNetworksExecution_compute",
+        "ANeuralNetworksExecution* execution",
+    ),  # noqa: B950
+    (
+        "int",
+        "ANeuralNetworksMemory_createFromFd",
+        "size_t size, int protect, int fd, size_t offset, ANeuralNetworksMemory** memory",
+    ),  # noqa: B950
+    (
+        "void",
+        "ANeuralNetworksMemory_free",
+        "ANeuralNetworksMemory* memory",
+    ),  # noqa: B950
+    (
+        "int",
+        "ANeuralNetworksModel_create",
+        "ANeuralNetworksModel** model",
+    ),  # noqa: B950
     ("void", "ANeuralNetworksModel_free", "ANeuralNetworksModel* model"),  # noqa: B950
     ("int", "ANeuralNetworksModel_finish", "ANeuralNetworksModel* model"),  # noqa: B950
-    ("int", "ANeuralNetworksModel_addOperand", "ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type"),  # noqa: B950
-    ("int", "ANeuralNetworksModel_setOperandValue", "ANeuralNetworksModel* model, int32_t index, const void* buffer, size_t length"),  # noqa: B950
-    ("int", "ANeuralNetworksModel_setOperandValueFromMemory", "ANeuralNetworksModel* model, int32_t index, const ANeuralNetworksMemory* memory, size_t offset, size_t length"),  # noqa: B950
-    ("int", "ANeuralNetworksModel_addOperation", "ANeuralNetworksModel* model, ANeuralNetworksOperationType type, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs"),  # noqa: B950
-    ("int", "ANeuralNetworksModel_identifyInputsAndOutputs", "ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs"),  # noqa: B950
-    ("int", "ANeuralNetworksModel_relaxComputationFloat32toFloat16", "ANeuralNetworksModel* model, bool allow"),  # noqa: B950
-    ("int", "ANeuralNetworksCompilation_create", "ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation"),  # noqa: B950
-    ("void", "ANeuralNetworksCompilation_free", "ANeuralNetworksCompilation* compilation"),  # noqa: B950
-    ("int", "ANeuralNetworksCompilation_setPreference", "ANeuralNetworksCompilation* compilation, int32_t preference"),  # noqa: B950
-    ("int", "ANeuralNetworksCompilation_finish", "ANeuralNetworksCompilation* compilation"),  # noqa: B950
-    ("int", "ANeuralNetworksExecution_create", "ANeuralNetworksCompilation* compilation, ANeuralNetworksExecution** execution"),  # noqa: B950
-    ("void", "ANeuralNetworksExecution_free", "ANeuralNetworksExecution* execution"),  # noqa: B950
-    ("int", "ANeuralNetworksExecution_setInput", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const void* buffer, size_t length"),  # noqa: B950
-    ("int", "ANeuralNetworksExecution_setInputFromMemory", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length"),  # noqa: B950
-    ("int", "ANeuralNetworksExecution_setOutput", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, void* buffer, size_t length"),  # noqa: B950
-    ("int", "ANeuralNetworksExecution_setOutputFromMemory", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length"),  # noqa: B950
-    ("int", "ANeuralNetworksExecution_startCompute", "ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event"),  # noqa: B950
+    (
+        "int",
+        "ANeuralNetworksModel_addOperand",
+        "ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type",
+    ),  # noqa: B950
+    (
+        "int",
+        "ANeuralNetworksModel_setOperandValue",
+        "ANeuralNetworksModel* model, int32_t index, const void* buffer, size_t length",
+    ),  # noqa: B950
+    (
+        "int",
+        "ANeuralNetworksModel_setOperandValueFromMemory",
+        "ANeuralNetworksModel* model, int32_t index, const ANeuralNetworksMemory* memory, size_t offset, size_t length",
+    ),  # noqa: B950
+    (
+        "int",
+        "ANeuralNetworksModel_addOperation",
+        "ANeuralNetworksModel* model, ANeuralNetworksOperationType type, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs",  # noqa: B950
+    ),
+    (
+        "int",
+        "ANeuralNetworksModel_identifyInputsAndOutputs",
+        "ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs",
+    ),  # noqa: B950
+    (
+        "int",
+        "ANeuralNetworksModel_relaxComputationFloat32toFloat16",
+        "ANeuralNetworksModel* model, bool allow",
+    ),  # noqa: B950
+    (
+        "int",
+        "ANeuralNetworksCompilation_create",
+        "ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation",
+    ),  # noqa: B950
+    (
+        "void",
+        "ANeuralNetworksCompilation_free",
+        "ANeuralNetworksCompilation* compilation",
+    ),  # noqa: B950
+    (
+        "int",
+        "ANeuralNetworksCompilation_setPreference",
+        "ANeuralNetworksCompilation* compilation, int32_t preference",
+    ),  # noqa: B950
+    (
+        "int",
+        "ANeuralNetworksCompilation_finish",
+        "ANeuralNetworksCompilation* compilation",
+    ),  # noqa: B950
+    (
+        "int",
+        "ANeuralNetworksExecution_create",
+        "ANeuralNetworksCompilation* compilation, ANeuralNetworksExecution** execution",
+    ),  # noqa: B950
+    (
+        "void",
+        "ANeuralNetworksExecution_free",
+        "ANeuralNetworksExecution* execution",
+    ),  # noqa: B950
+    (
+        "int",
+        "ANeuralNetworksExecution_setInput",
+        "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const void* buffer, size_t length",  # noqa: B950
+    ),
+    (
+        "int",
+        "ANeuralNetworksExecution_setInputFromMemory",
+        "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length",  # noqa: B950
+    ),
+    (
+        "int",
+        "ANeuralNetworksExecution_setOutput",
+        "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, void* buffer, size_t length",
+    ),  # noqa: B950
+    (
+        "int",
+        "ANeuralNetworksExecution_setOutputFromMemory",
+        "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length",  # noqa: B950
+    ),
+    (
+        "int",
+        "ANeuralNetworksExecution_startCompute",
+        "ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event",
+    ),  # noqa: B950
     ("int", "ANeuralNetworksEvent_wait", "ANeuralNetworksEvent* event"),  # noqa: B950
     ("void", "ANeuralNetworksEvent_free", "ANeuralNetworksEvent* event"),  # noqa: B950
-    ("int", "ANeuralNetworksExecution_getOutputOperandRank", "ANeuralNetworksExecution* execution, int32_t index, uint32_t* rank"),  # noqa: B950
-    ("int", "ANeuralNetworksExecution_getOutputOperandDimensions", "ANeuralNetworksExecution* execution, int32_t index, uint32_t* dimensions"),  # noqa: B950
+    (
+        "int",
+        "ANeuralNetworksExecution_getOutputOperandRank",
+        "ANeuralNetworksExecution* execution, int32_t index, uint32_t* rank",
+    ),  # noqa: B950
+    (
+        "int",
+        "ANeuralNetworksExecution_getOutputOperandDimensions",
+        "ANeuralNetworksExecution* execution, int32_t index, uint32_t* dimensions",
+    ),  # noqa: B950
 ]
 
 
@@ -82,18 +198,26 @@ def main(argv):
         struct_members.append(f"  {ret}(*{short_name})({args});")
 
-        load_functions.append(f'    *(void**)&nnapi_.{short_name} = dlsym(handle, "{name}");')
-        load_functions.append(f'    check_nnapi_.{short_name} = check_{short_name};')
+        load_functions.append(
+            f'    *(void**)&nnapi_.{short_name} = dlsym(handle, "{name}");'
+        )
+        load_functions.append(f"    check_nnapi_.{short_name} = check_{short_name};")
 
         call_args = "".join(re.findall(r"\w+(?:,|$)", args))
         if ret == "void":
-            define_checks.append(textwrap.dedent(f"""\
+            define_checks.append(
+                textwrap.dedent(
+                    f"""\
                 {ret} check_{short_name}({args}) {{
                   CAFFE_ENFORCE(nnapi_.{short_name});
                   nnapi_.{short_name}({call_args});
-                }}"""))
+                }}"""
+                )
+            )
         if ret == "int":
-            define_checks.append(textwrap.dedent(f"""\
+            define_checks.append(
+                textwrap.dedent(
+                    f"""\
                 {ret} check_{short_name}({args}) {{
                   CAFFE_ENFORCE(nnapi_.{short_name});
                   int ret = nnapi_.{short_name}({call_args});
@@ -103,13 +227,16 @@ def main(argv):
                     "{short_name}", "failed with error ", ret
                   );
                   return ret;
-                }}"""))
+                }}"""
+                )
+            )
 
     out_dir = pathlib.Path(__file__).parent
 
     (out_dir / "nnapi_wrapper.h").write_text(
-        PREFIX +
-        textwrap.dedent("""\
+        PREFIX
+        + textwrap.dedent(
+            """\
             #ifndef NNAPI_WRAPPER_H_
             #define NNAPI_WRAPPER_H_
             #include 
@@ -122,13 +249,14 @@ def main(argv):
             void nnapi_wrapper_load(struct nnapi_wrapper** nnapi, struct nnapi_wrapper** check_nnapi);
             #endif
             #endif
-            """)
-        .replace("__STRUCT_MEMBERS__", "\n".join(struct_members))
+            """
+        ).replace("__STRUCT_MEMBERS__", "\n".join(struct_members))
     )
 
     (out_dir / "nnapi_wrapper.cpp").write_text(
-        PREFIX +
-        textwrap.dedent("""\
+        PREFIX
+        + textwrap.dedent(
+            """\
             #ifndef _WIN32
             #include 
             #endif
@@ -157,7 +285,8 @@ def main(argv):
               *check_nnapi = &check_nnapi_;
             #endif
             }
-            """)
+            """
+        )
         .replace("__DEFINE_CHECK_FUNCTIONS__", "\n".join(define_checks))
         .replace("__LOAD_FUNCTIONS__", "\n".join(load_functions))
     )
diff --git a/functorch/__init__.py b/functorch/__init__.py
index 3ed8af04a997bb..aff35e592d80e5 100644
--- a/functorch/__init__.py
+++ b/functorch/__init__.py
@@ -5,18 +5,19 @@
 # LICENSE file in the root directory of this source tree.
 import torch
 
-# Top-level APIs. Please think carefully before adding something to the
-# top-level namespace:
-# - private helper functions should go into torch._functorch
-# - very experimental things should go into functorch.experimental
-# - compilation related things should go into functorch.compile
-
-# Was never documented
-from torch._functorch.python_key import make_fx
-
 from torch._functorch.deprecated import (
-    vmap, grad, grad_and_value, vjp, jvp, jacrev, jacfwd, hessian, functionalize,
-    make_functional, make_functional_with_buffers, combine_state_for_ensemble,
+    combine_state_for_ensemble,
+    functionalize,
+    grad,
+    grad_and_value,
+    hessian,
+    jacfwd,
+    jacrev,
+    jvp,
+    make_functional,
+    make_functional_with_buffers,
+    vjp,
+    vmap,
 )
 
 # utilities. Maybe these should go in their own namespace in the future?
@@ -25,4 +26,13 @@
     FunctionalModuleWithBuffers,
 )
 
+# Top-level APIs. Please think carefully before adding something to the
+# top-level namespace:
+# - private helper functions should go into torch._functorch
+# - very experimental things should go into functorch.experimental
+# - compilation related things should go into functorch.compile
+
+# Was never documented
+from torch._functorch.python_key import make_fx
+
 __version__ = torch.__version__
diff --git a/functorch/_src/eager_transforms/__init__.py b/functorch/_src/eager_transforms/__init__.py
index e3e587c0978fad..6052b5548f4af3 100644
--- a/functorch/_src/eager_transforms/__init__.py
+++ b/functorch/_src/eager_transforms/__init__.py
@@ -2,6 +2,6 @@
 # If you are not a PyTorch developer and you are relying on the following
 # imports, please file an issue.
 from torch._functorch.eager_transforms import (
-    _unwrap_functional_tensor,
     _assert_wrapped_functional,
+    _unwrap_functional_tensor,
 )
diff --git a/functorch/_src/vmap/__init__.py b/functorch/_src/vmap/__init__.py
index 792a2fde38bb35..dc90517753e50f 100644
--- a/functorch/_src/vmap/__init__.py
+++ b/functorch/_src/vmap/__init__.py
@@ -4,13 +4,13 @@
 from torch._functorch.vmap import (
     _add_batch_dim,
     _broadcast_to_and_flatten,
+    _create_batched_inputs,
     _get_name,
+    _process_batched_inputs,
     _remove_batch_dim,
+    _unwrap_batched,
     _validate_and_get_batch_size,
     Tensor,
     tree_flatten,
     tree_unflatten,
-    _process_batched_inputs,
-    _create_batched_inputs,
-    _unwrap_batched,
 )
diff --git a/functorch/benchmarks/chrome_trace_parser.py b/functorch/benchmarks/chrome_trace_parser.py
index e03212f9af6ea2..ceb6ea58fbb915 100755
--- a/functorch/benchmarks/chrome_trace_parser.py
+++ b/functorch/benchmarks/chrome_trace_parser.py
@@ -1,8 +1,9 @@
 #!/usr/bin/env python3
 import argparse
+import logging
 import os
-import logging
+
 import pandas as pd
 
 from torch._functorch.benchmark_utils import compute_utilization
 
@@ -17,9 +18,10 @@ def get_model_name(filename):
     Get model name from a file in format {model_name}_chrome_trace_*.json
     """
     _, tail = os.path.split(filename)
-    modelname = tail[:tail.find("_chrome_trace")]
+    modelname = tail[: tail.find("_chrome_trace")]
     return modelname
 
+
 def get_total_length(run_times_df, modelname):
     return float(run_times_df[run_times_df["name"] == modelname]["runtime"])
 
@@ -31,14 +33,14 @@ def main():
         "--runtime", "-runf", help="file name of the runtime file", required=True
     )
     group.add_argument(
-        "--filename", "-f", action="append", help="a filename of the json file to process"
-    )
-    group.add_argument(
-        "--folder", "-fd", help="a folder of the json files to process"
+        "--filename",
+        "-f",
+        action="append",
+        help="a filename of the json file to process",
     )
+    group.add_argument("--folder", "-fd", help="a folder of the json files to process")
     args = parser.parse_args()
 
-
     if args.filename:
         filenames = args.filename
     elif args.folder:
@@ -58,11 +60,14 @@ def main():
         try:
             modelname = get_model_name(filename)
             total_length = get_total_length(run_times_df, modelname) * 1e6
-            utilization, mm_conv_utilization = compute_utilization(filenames, total_length)
+            utilization, mm_conv_utilization = compute_utilization(
+                filenames, total_length
+            )
             print(f"{modelname}, {utilization}, {mm_conv_utilization}")
         except BaseException:
             logging.exception("%s, ERROR", filename)
             print(f"{filename}, ERROR")
 
+
 if __name__ == "__main__":
     main()
diff --git a/functorch/benchmarks/cse.py b/functorch/benchmarks/cse.py
index 14cde14eb3085a..2cbb1411b7b9be 100644
--- a/functorch/benchmarks/cse.py
+++ b/functorch/benchmarks/cse.py
@@ -1,9 +1,10 @@
 import torch
 import torch.fx as fx
 
 from functorch import make_fx
-from torch.profiler import profile, ProfilerActivity
 from torch._functorch.compile_utils import fx_graph_cse
+from torch.profiler import profile, ProfilerActivity
+
 
 def profile_it(f, inp):
     for _ in range(5):
@@ -20,6 +21,7 @@ def profile_it(f, inp):
         cuda_time_total = cuda_time_total + e.cuda_time_total
     return cuda_time_total / itr
 
+
 def profile_function(name, f, inp):
     fx_g = make_fx(f)(inp)
 
@@ -34,17 +36,23 @@ def profile_function(name, f, inp):
     avg_cuda_time_g = profile_it(new_g, inp)
     num_node_decrease = len(fx_g.graph.nodes) - len(new_g.graph.nodes)
 
-    print(f"{name}, {avg_cuda_time_f}, {avg_cuda_time_g}, {num_node_decrease}, {len(fx_g.graph.nodes)}")
+    print(
+        f"{name}, {avg_cuda_time_f}, {avg_cuda_time_g}, {num_node_decrease}, {len(fx_g.graph.nodes)}"
+    )
+
 
-g_gpu = torch.Generator(device='cuda')
+g_gpu = torch.Generator(device="cuda")
 g_gpu.manual_seed(2147483647)
-inp = torch.randn(2**20, device='cuda', generator=g_gpu)
+inp = torch.randn(2**20, device="cuda", generator=g_gpu)
+
 
 def f1(x):
     return x.cos().cos()
 
+
 profile_function("f1", f1, inp)
 
+
 def fsum(x):
     a = x.sum()
     b = x.sum()
@@ -52,22 +60,29 @@ def fsum(x):
     d = x.sum()
     return a + b + c + d
 
+
 profile_function("fsum", fsum, inp)
 
+
 def fconcat(x):
     a = torch.cat((x, x))
     b = torch.cat((x, x))
     return a + b
+
+
 profile_function("fconcat", fconcat, inp)
 
+
 def fsum2(x):
     a = x.sum()
     for _ in range(30):
         a = a + x.sum()
     return a
 
+
 profile_function("fsum2", fsum2, inp)
 
+
 def fsummulti(x):
     a = 0
     for _ in range(3):
@@ -75,8 +90,10 @@ def fsummulti(x):
         a = a * x.sum()
     return a
 
+
 profile_function("fsummulti", fsummulti, inp)
 
+
 def fsummulti2(x):
     a = 0
     for _ in range(30):
@@ -84,20 +101,25 @@ def fsummulti2(x):
         a = a * x.sum()
     return a
 
+
 profile_function("fsummulti2", fsummulti2, inp)
 
+
 def fcos(x):
     a = 0
     for _ in range(3):
         a = a + x.cos()
     return a
 
+
 profile_function("fcos", fcos, inp)
 
+
 def fcos2(x):
     a = 0
     for _ in range(30):
         a = a + x.cos()
     return a
 
+
 profile_function("fcos2", fcos2, inp)
diff --git a/functorch/benchmarks/operator_authoring.py b/functorch/benchmarks/operator_authoring.py
index 456f5040d759f2..065c64297a0882 100644
--- a/functorch/benchmarks/operator_authoring.py
+++ b/functorch/benchmarks/operator_authoring.py
@@ -1,7 +1,8 @@
+import timeit
 from functools import partial
+
 import numpy as np
 import pandas as pd
-import timeit
 import torch
 
 from functorch.compile import pointwise_operator
diff --git a/functorch/benchmarks/per_sample_grads.py b/functorch/benchmarks/per_sample_grads.py
index e9e3524eca53be..3e4e032160caf6 100644
--- a/functorch/benchmarks/per_sample_grads.py
+++ b/functorch/benchmarks/per_sample_grads.py
@@ -1,14 +1,14 @@
+import time
+
 import torch
 import torch.nn as nn
 import torchvision.models as models
-from opacus.utils.module_modification import convert_batchnorm_modules
-import time
-from functorch import vmap, grad
-from functorch import make_functional
+
+from functorch import grad, make_functional, vmap
 from opacus import PrivacyEngine
+from opacus.utils.module_modification import convert_batchnorm_modules
 
-device = 'cuda'
+device = "cuda"
 batch_size = 128
 torch.manual_seed(0)
@@ -20,6 +20,7 @@ targets = torch.randint(0, 10, (batch_size,), device=device)
 
 func_model, weights = make_functional(model_functorch)
 
+
 def compute_loss(weights, image, target):
     images = image.unsqueeze(0)
     targets = target.unsqueeze(0)
@@ -27,11 +28,11 @@ def compute_loss(weights, image, target):
     loss = criterion(output, targets)
     return loss
 
+
 def functorch_per_sample_grad():
     compute_grad = grad(compute_loss)
     compute_per_sample_grad = vmap(compute_grad, (None, 0, 0))
-
     start = time.time()
     result = compute_per_sample_grad(weights, images, targets)
     torch.cuda.synchronize()
@@ -39,6 +40,7 @@ def functorch_per_sample_grad():
     return result, end - start  # end - start in seconds
 
+
 torch.manual_seed(0)
 model_opacus = convert_batchnorm_modules(models.resnet18(num_classes=10))
 model_opacus = model_opacus.to(device)
@@ -54,6 +56,7 @@ def functorch_per_sample_grad():
     max_grad_norm=10000.0,
 )
 
+
 def opacus_per_sample_grad():
     start = time.time()
     output = model_opacus(images)
@@ -63,7 +66,7 @@ def opacus_per_sample_grad():
     end = time.time()
     expected = [p.grad_sample for p in model_opacus.parameters()]
     for p in model_opacus.parameters():
-        delattr(p, 'grad_sample')
+        delattr(p, "grad_sample")
         p.grad = None
     return expected, end - start
diff --git a/functorch/benchmarks/pointwise_scorecard.py b/functorch/benchmarks/pointwise_scorecard.py
index 15863dc3510cfa..6b3250cc9ec46a 100644
--- a/functorch/benchmarks/pointwise_scorecard.py
+++ b/functorch/benchmarks/pointwise_scorecard.py
@@ -1,14 +1,16 @@
+import inspect
+import itertools
 import sys
 import time
+
 import torch
-import inspect
-import itertools
 
 from functorch import pointwise_operator
 
 torch.set_num_threads(1)
 torch._C._debug_set_fusion_group_inlining(False)
 
+
 def rand(*shape):
     return torch.rand(*shape).mul(16).add(1)
 
@@ -19,105 +21,139 @@ def rand(*shape):
 def scalar():
     return (rand(1), rand(1))
 
+
 def small():
     return (rand(32), rand(32))
 
+
 def small_2d():
     return (rand(1, 32), rand(1, 32))
 
+
 def small_broadcast():
     return (rand(4, 32), rand(32))
 
+
 def medium():
     return (rand(32, 12, 64, 64), rand(32, 12, 64, 64))
 
+
 def medium_sliced():
-    return (rand(32, 12, 64, 64)[..., ::2],
-            rand(32, 12, 64, 64)[..., ::2])
+    return (rand(32, 12, 64, 64)[..., ::2], rand(32, 12, 64, 64)[..., ::2])
 
+
 def medium_transpose():
-    return (rand(32, 12, 64, 64).transpose(-1, -2),
-            rand(32, 12, 64, 64).transpose(-1, -2))
+    return (
+        rand(32, 12, 64, 64).transpose(-1, -2),
+        rand(32, 12, 64, 64).transpose(-1, -2),
+    )
 
+
 def medium2():
     return (rand(32, 3, 224, 224), rand(32, 3, 224, 224))
 
+
 def medium3d():
     return (rand(16, 32, 64), rand(16, 32, 64))
 
+
 def medium_channels_last():
-    return (rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
-            rand(32, 3, 224, 224).to(memory_format=torch.channels_last))
+    return (
+        rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
+        rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
+    )
 
+
 def medium_broadcast():
     return (rand(32, 12, 64, 64), rand(64))
 
+
 def medium_broadcast_channels_last():
-    return (rand(32, 3, 223, 223).to(memory_format=torch.channels_last),
-            rand(3, 1, 1))
+    return (rand(32, 3, 223, 223).to(memory_format=torch.channels_last), rand(3, 1, 1))
 
+
 def large():
     return (rand(8192, 8192), rand(8192, 8192))
 
+
 def large_transpose():
-    return (rand(8192, 8192).transpose(0, 1),
-            rand(8192, 8192).transpose(0, 1))
+    return (rand(8192, 8192).transpose(0, 1), rand(8192, 8192).transpose(0, 1))
 
+
 def large_channels_last():
-    return (rand(32, 32, 256, 256).to(memory_format=torch.channels_last),
-            rand(32, 32, 256, 256).to(memory_format=torch.channels_last))
+    return (
+        rand(32, 32, 256, 256).to(memory_format=torch.channels_last),
+        rand(32, 32, 256, 256).to(memory_format=torch.channels_last),
+    )
 
+
 def pathological_broadcast():
     return (rand(1, 32, 32, 2), rand(1024, 1, 1, 2))
 
+
 # ------------------------------------------------------------------------------
 # Operator test cases
 # ------------------------------------------------------------------------------
 def add(a, b):
     return a + b
 
+
 def sub(a, b):
     return a - b
 
+
 def mul(a, b):
     return a * b
 
+
 def div(a, b):
     return a / b
 
+
 def relu(a):
     return a.relu()
 
+
 def sigmoid(a):
     return a.sigmoid()
 
+
 def tanh(a):
     return a.tanh()
 
+
 def log(a):
     return a.log()
 
+
 def exp(a):
     return a.exp()
 
+
 def square(a):
-    return a ** 2
+    return a**2
 
+
 def fma(a, b):
     return a * b + b
 
+
 def hardswish(a):
     return a * (a + 3.0).clamp(0.0, 6.0) / 6.0
 
+
 def native_hardswish(a):
     return torch._C._nn.hardswish(a)
 
+
 def softplus(a):
     return (a * 1.0).exp().log1p() / 1.0
 
+
 def mish(a):
     return a * ((a * 1.0).exp().log1p() / 1.0).tanh()
 
+
 # ------------------------------------------------------------------------------
 # Helpers
 # ------------------------------------------------------------------------------
@@ -128,6 +164,7 @@ def time_cpu(fn, args, iters):
     e = time.perf_counter()
     return e - s
 
+
 def time_cuda(fn, args, iters):
     start = torch.cuda.Event(enable_timing=True)
     end = torch.cuda.Event(enable_timing=True)
@@ -138,19 +175,23 @@ def time_cuda(fn, args, iters):
     torch.cuda.synchronize()
     return start.elapsed_time(end) / 1e3
 
+
 def benchmark_with_timer(fn, args, timer):
     timer(fn, args, 3)
     calibration = timer(fn, args, 1)
     iters = int(1.0 / calibration)
     return timer(fn, args, iters) / iters
 
+
 def benchmark(fn, args):
     timer = time_cpu if args[0].device.type == "cpu" else time_cuda
     return benchmark_with_timer(fn, args, timer)
 
+
 def micros(s):
     return f"{s * 1e6:.1f}"
 
+
 shapes = [
     scalar,
     small,
@@ -211,7 +252,17 @@ def micros(s):
         args = shape()[:nargs]
 
         result = benchmark(operator, args)
-        print(",".join(["eager", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
+        print(
+            ",".join(
+                [
+                    "eager",
+                    args[0].device.type,
+                    operator.__name__,
+                    shape.__name__,
+                    micros(result),
+                ]
+            )
+        )
         try:
             if shape == medium_transpose:
                 raise RuntimeError("pointwise_operator hangs on medium_transpose")
@@ -219,11 +270,41 @@ def micros(s):
                 raise RuntimeError("pointwise_operator fails on medium_transpose")
             pw_op = pointwise_operator(operator)
             result = benchmark(pw_op, args)
-            print(",".join(["pointwise", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
+            print(
+                ",".join(
+                    [
+                        "pointwise",
+                        args[0].device.type,
+                        operator.__name__,
+                        shape.__name__,
+                        micros(result),
+                    ]
+                )
+            )
         except Exception:
-            print(",".join(["pointwise", args[0].device.type, operator.__name__, shape.__name__, micros(float("nan"))]))
+            print(
+                ",".join(
+                    [
+                        "pointwise",
+                        args[0].device.type,
+                        operator.__name__,
+                        shape.__name__,
+                        micros(float("nan")),
+                    ]
+                )
+            )
 
         ts_op = torch.jit.script(operator)
         result = benchmark(ts_op, args)
-        print(",".join(["fuser", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
+        print(
+            ",".join(
+                [
+                    "fuser",
+                    args[0].device.type,
+                    operator.__name__,
+                    shape.__name__,
+                    micros(result),
+                ]
+            )
+        )
         sys.stdout.flush()
diff --git a/functorch/benchmarks/process_scorecard.py b/functorch/benchmarks/process_scorecard.py
index f95d879238a122..e535dcb5b5aa27 100644
--- a/functorch/benchmarks/process_scorecard.py
+++ b/functorch/benchmarks/process_scorecard.py
@@ -1,11 +1,13 @@
-import pandas
 import matplotlib.pyplot as plt
+import pandas
 
 df = pandas.read_csv("perf.csv")
 
 ops = pandas.unique(df["operator"])
 nops = len(ops)
-pivot_op_shape = df.pivot_table(values="time", index=["operator", "shape"], columns=["fuser"])
+pivot_op_shape = df.pivot_table(
+    values="time", index=["operator", "shape"], columns=["fuser"]
+)
 pivot_speedups = (pivot_op_shape.T / pivot_op_shape["eager"]).T
 
 plt.rcParams["figure.figsize"] = (20, 100)
diff --git a/functorch/compile/__init__.py b/functorch/compile/__init__.py
index 569c1b6819bddc..96b853cd2e27e9 100644
--- a/functorch/compile/__init__.py
+++ b/functorch/compile/__init__.py
@@ -1,31 +1,31 @@
-from torch._functorch.python_key import pythonkey_decompose
-from torch._functorch.fx_minifier import minifier
+from torch._functorch import config
 from torch._functorch.aot_autograd import (
     aot_function,
     aot_module,
+    aot_module_simplified,
     compiled_function,
     compiled_module,
-    aot_module_simplified,
-    get_graph_being_compiled,
-    get_aot_graph_name,
     get_aot_compilation_context,
+    get_aot_graph_name,
+    get_graph_being_compiled,
+    make_boxed_compiler,
     make_boxed_func,
-    make_boxed_compiler
 )
 from torch._functorch.compilers import (
-    ts_compile,
+    debug_compile,
+    default_decompositions,
     draw_graph_compile,
-    nop,
-    nnc_jit,
     memory_efficient_fusion,
-    debug_compile,
+    nnc_jit,
+    nop,
     print_compile,
-    default_decompositions
+    ts_compile,
 )
+from torch._functorch.fx_minifier import minifier
 from torch._functorch.partitioners import (
-    min_cut_rematerialization_partition,
     default_partition,
     draw_graph,
     draw_joint_graph,
+    min_cut_rematerialization_partition,
 )
-from torch._functorch import config
+from torch._functorch.python_key import pythonkey_decompose
diff --git a/functorch/dim/__init__.py b/functorch/dim/__init__.py
index 642d866903b34e..e8c6b0df0d5801 100644
--- a/functorch/dim/__init__.py
+++ b/functorch/dim/__init__.py
@@ -1,20 +1,26 @@
-import torch
-from typing import Union, Sequence
-import inspect
 import dis
-from .tree_map import tree_flatten, tree_map
-from .wrap_type import wrap_type
+import inspect
+from typing import Sequence, Union
+
+import torch
+
 import functorch._C
 from functorch._C import dim as _C
+from .tree_map import tree_flatten, tree_map
+from .wrap_type import wrap_type
+
 _C._patch_tensor_class()
 dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists
 
+
 class DimensionMismatchError(Exception):
     pass
 
+
 class DimensionBindError(Exception):
     pass
 
+
 from . import op_properties
 
 # use dict to avoid writing C++ bindings for set
@@ -24,11 +30,11 @@ class DimensionBindError(Exception):
 
 if not use_c:
     from . import reference
 
+
 class _Tensor:
     # fast path around slow wrapping/unwrapping logic for simply queries used
     # by the implementation...
-
     @property
     def dims(self):
         return tuple(d for d in self._levels if isinstance(d, Dim))
@@ -47,11 +53,12 @@ def dim(self):
 
     def __repr__(self):
         tensor, levels, ndim = self._tensor, self._levels, self.ndim
-        return f'{tensor}\nwith dims={tuple(l + ndim if isinstance(l, int) else l for l in levels)} sizes={tuple(tensor.size())}'
+        return f"{tensor}\nwith dims={tuple(l + ndim if isinstance(l, int) else l for l in levels)} sizes={tuple(tensor.size())}"
 
 
 TensorLike = (_Tensor, torch.Tensor)
 
+
 class Dim(_C.Dim, _Tensor):
     # note that _C.Dim comes before tensor because we want the Dim API for things like size to take precendence.
     # Tensor defines format, but we want to print Dims with special formatting
@@ -69,6 +76,7 @@ def cat(tensors, dim, new_dim):
     n = dims()
     return stack(tensors, n, dim).index([n, dim], new_dim)
 
+
 if use_c:
     _wrap = _C._wrap
 
@@ -107,41 +115,41 @@ def _def(name, *args, **kwargs):
 else:
     _Tensor.order = reference.positional
 
-_def('mean')
-_def('sum')
-_def('all')
-_def('amax')
-_def('amin')
-_def('aminmax')
-_def('any')
-_def('count_nonzero')
-_def('logsumexp')
-_def('nanmean')
-_def('nansum')
-_def('prod')
-_def('std', keepdim_offset=2)
-_def('var', keepdim_offset=2)
-_def('max', single_dim=True)
-_def('min', single_dim=True)
-_def('argmax', single_dim=True)
-_def('argmin', single_dim=True)
-_def('kthvalue', single_dim=True)
-_def('median', single_dim=True)
-_def('nanmedian', single_dim=True)
-_def('mode', single_dim=True)
-_def('sort', reduce=False)
-_def('argsort', reduce=False)
-_def('unbind', single_dim=True)
-_def('chunk', dim_offset=1, reduce=False)
-_def('cummax', single_dim=True, reduce=False)
-_def('cummin', single_dim=True, reduce=False)
-_def('cumprod', single_dim=True, reduce=False)
-_def('cumprod_', single_dim=True, reduce=False)
-_def('cumsum', single_dim=True, reduce=False)
-_def('cumsum_', single_dim=True, reduce=False)
-_def('logcumsumexp', single_dim=True, reduce=False)
-_def('renorm', dim_offset=1, single_dim=True, reduce=False)
-_def('softmax', single_dim=True, reduce=False)
+_def("mean")
+_def("sum")
+_def("all")
+_def("amax")
+_def("amin")
+_def("aminmax")
+_def("any")
+_def("count_nonzero")
+_def("logsumexp")
+_def("nanmean")
+_def("nansum")
+_def("prod")
+_def("std", keepdim_offset=2)
+_def("var", keepdim_offset=2)
+_def("max", single_dim=True)
+_def("min", single_dim=True)
+_def("argmax", single_dim=True)
+_def("argmin", single_dim=True)
+_def("kthvalue", single_dim=True)
+_def("median", single_dim=True)
+_def("nanmedian", single_dim=True)
+_def("mode", single_dim=True)
+_def("sort", reduce=False)
+_def("argsort", reduce=False)
+_def("unbind", single_dim=True)
+_def("chunk", dim_offset=1, reduce=False)
+_def("cummax", single_dim=True, reduce=False)
+_def("cummin", single_dim=True, reduce=False)
+_def("cumprod", single_dim=True, reduce=False)
+_def("cumprod_", single_dim=True, reduce=False)
+_def("cumsum", single_dim=True, reduce=False)
+_def("cumsum_", single_dim=True, reduce=False)
+_def("logcumsumexp", single_dim=True, reduce=False)
+_def("renorm", dim_offset=1, single_dim=True, reduce=False)
+_def("softmax", single_dim=True, reduce=False)
 softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False)
 
 # stuff to handle in the future, because they require special
diff --git a/functorch/dim/batch_tensor.py b/functorch/dim/batch_tensor.py
index f8b03648881433..0fc17f2492d534 100644
--- a/functorch/dim/batch_tensor.py
+++ b/functorch/dim/batch_tensor.py
@@ -3,14 +3,13 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from torch._C._functorch import (
-    _vmap_add_layers,
-    _vmap_remove_layers,
-)
-
 from contextlib import contextmanager
 
+from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers
+
 _enabled = False
+
+
 @contextmanager
 def _enable_layers(dims):
     global _enabled
diff --git a/functorch/dim/delayed_mul_tensor.py b/functorch/dim/delayed_mul_tensor.py
index 92082bb3fa6241..3984a063885907 100644
--- a/functorch/dim/delayed_mul_tensor.py
+++ b/functorch/dim/delayed_mul_tensor.py
@@ -4,9 +4,11 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import torch
+
 from . import _Tensor, Tensor
 from .reference import _dims, _enable_layers, llist, ltuple
 
+
 class DelayedMulTensor(_Tensor):
     def __init__(self, lhs, rhs):
         self._lhs, self._rhs = lhs, rhs
@@ -37,7 +39,9 @@ def _batchtensor(self):
     @property
     def _tensor(self):
         if self._tensor_data is None:
-            self._tensor_data = Tensor.from_batched(self._batchtensor, self._has_device)._tensor
+            self._tensor_data = Tensor.from_batched(
+                self._batchtensor, self._has_device
+            )._tensor
         return self._tensor_data
 
     @property
@@ -48,20 +52,26 @@ def ndim(self):
     def dims(self):
         return ltuple(super().dims)
 
-
     def sum(self, dim):
         dims = _dims(dim, 0, False, False)
-        n = ord('a')
+        n = ord("a")
         all_levels = self._levels
 
        def to_char(d):
            return chr(n + all_levels.index(d))
+
        plhs, levelslhs = self._lhs._tensor, self._lhs._levels
        prhs, levelsrhs = self._rhs._tensor, self._rhs._levels
        new_dims = tuple(d for d in self.dims if d not in dims)
        new_levels = [l for l in self._levels if l not in dims]
-        fmt = ''.join([*(to_char(d) for d in levelslhs), ',',
-                       *(to_char(d) for d in levelsrhs), '->',
-                       *(to_char(d) for d in new_levels)])
+        fmt = "".join(
+            [
+                *(to_char(d) for d in levelslhs),
+                ",",
+                *(to_char(d) for d in levelsrhs),
+                "->",
+                *(to_char(d) for d in new_levels),
+            ]
+        )
        result_data = torch.einsum(fmt, (plhs, prhs))
        return Tensor.from_positional(result_data, new_levels, True)
diff --git a/functorch/dim/dim.py b/functorch/dim/dim.py
index 312059a4c9f15a..c1a93ebe5c93af 100644
--- a/functorch/dim/dim.py
+++ b/functorch/dim/dim.py
@@ -4,11 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 _vmap_levels = []
+
+
 @dataclass
 class LevelInfo:
     level: int
     alive: bool = True
 
+
 class Dim:
     def __init__(self, name: str, size: Union[None, int] = None):
         self.name = name
@@ -20,7 +23,9 @@ def __init__(self, name: str, size: Union[None, int] = None):
     def __del__(self):
         if self._vmap_level is not None:
             _vmap_active_levels[self._vmap_stack].alive = False
-            while not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level:
+            while (
+                not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level
+            ):
                 _vmap_decrement_nesting()
                 _vmap_levels.pop()
 
@@ -33,13 +38,14 @@ def size(self):
 
     def size(self, size: int):
         if self._size is None:
             self._size = size
-            self._vmap_level = _vmap_increment_nesting(size, 'same')
+            self._vmap_level = _vmap_increment_nesting(size, "same")
             self._vmap_stack = len(_vmap_levels)
             _vmap_levels.append(LevelInfo(self._vmap_level))
         elif self._size != size:
             raise DimensionBindError(
-                f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}")
+                f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}"
+            )
 
     @property
     def is_bound(self):
@@ -50,10 +56,13 @@ def __repr__(self):
 
 
 def extract_name(inst):
-    assert inst.opname == 'STORE_FAST' or inst.opname == 'STORE_NAME'
+    assert inst.opname == "STORE_FAST" or inst.opname == "STORE_NAME"
     return inst.argval
 
+
 _cache = {}
+
+
 def dims(lists=0):
     frame = inspect.currentframe()
     assert frame is not None
@@ -66,17 +75,22 @@ def dims(lists=0):
 
     instructions = list(dis.get_instructions(calling_frame.f_code))
     unpack = instructions[first]
-    if unpack.opname == 'STORE_FAST' or unpack.opname == 'STORE_NAME':
+    if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME":
         # just a single dim, not a list
         name = unpack.argval
         ctor = Dim if lists == 0 else DimList
         _cache[key] = lambda: ctor(name=name)
     else:
-        assert unpack.opname == 'UNPACK_SEQUENCE'
+        assert unpack.opname == "UNPACK_SEQUENCE"
         ndims = unpack.argval
-        names = tuple(extract_name(instructions[first + 1 + i]) for i in range(ndims))
+        names = tuple(
+            extract_name(instructions[first + 1 + i]) for i in range(ndims)
+        )
         first_list = len(names) - lists
-        _cache[key] = lambda: tuple(Dim(n) if i < first_list else DimList(name=n) for i, n in enumerate(names))
+        _cache[key] = lambda: tuple(
+            Dim(n) if i < first_list else DimList(name=n)
+            for i, n in enumerate(names)
+        )
     return _cache[key]()
 
 
@@ -87,6 +101,7 @@ def convert(a):
         else:
             assert isinstance(a, int)
             return positional[a]
+
     if arg is None:
         return positional
     elif not isinstance(arg, (Dim, int)):
diff --git a/functorch/dim/magic_trace.py b/functorch/dim/magic_trace.py
index 8d4e5ec31ef897..5c962a898ca79c 100644
--- a/functorch/dim/magic_trace.py
+++ b/functorch/dim/magic_trace.py
@@ -3,25 +3,33 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from contextlib import contextmanager import os -import subprocess import signal +import subprocess +from contextlib import contextmanager + @contextmanager -def magic_trace(output='trace.fxt', magic_trace_cache='/tmp/magic-trace'): +def magic_trace(output="trace.fxt", magic_trace_cache="/tmp/magic-trace"): pid = os.getpid() if not os.path.exists(magic_trace_cache): print(f"Downloading magic_trace to: {magic_trace_cache}") - subprocess.run(['wget', '-O', magic_trace_cache, '-q', - 'https://github.com/janestreet/magic-trace/releases/download/v1.0.2/magic-trace']) - subprocess.run(['chmod', '+x', magic_trace_cache]) - args = [magic_trace_cache, 'attach', '-pid', str(pid), '-o', output] - p = subprocess.Popen(args, stderr=subprocess.PIPE, encoding='utf-8') + subprocess.run( + [ + "wget", + "-O", + magic_trace_cache, + "-q", + "https://github.com/janestreet/magic-trace/releases/download/v1.0.2/magic-trace", + ] + ) + subprocess.run(["chmod", "+x", magic_trace_cache]) + args = [magic_trace_cache, "attach", "-pid", str(pid), "-o", output] + p = subprocess.Popen(args, stderr=subprocess.PIPE, encoding="utf-8") while True: x = p.stderr.readline() print(x) - if 'Attached' in x: + if "Attached" in x: break try: yield @@ -31,4 +39,4 @@ def magic_trace(output='trace.fxt', magic_trace_cache='/tmp/magic-trace'): print(p.stderr.read()) p.stderr.close() if r != 0: - raise ValueError(f'magic_trace exited abnormally: {r}') + raise ValueError(f"magic_trace exited abnormally: {r}") diff --git a/functorch/dim/op_properties.py b/functorch/dim/op_properties.py index fdfb0b9ae91d32..3760f2cb0ea79c 100644 --- a/functorch/dim/op_properties.py +++ b/functorch/dim/op_properties.py @@ -4,29 +4,58 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import torch + # pointwise operators can go through a faster pathway -tensor_magic_methods = [ - 'add', - '' -] +tensor_magic_methods = ["add", ""] pointwise_magic_methods_with_reverse = ( - 'add', 'sub', 'mul', 'floordiv', 'div', 'truediv', 'mod', - 'pow', 'lshift', 'rshift', 'and', 'or', 'xor' + "add", + "sub", + "mul", + "floordiv", + "div", + "truediv", + "mod", + "pow", + "lshift", + "rshift", + "and", + "or", + "xor", ) pointwise_magic_methods = ( - *(x for m in pointwise_magic_methods_with_reverse for x in (m, 'r' + m)), - 'eq', 'gt', 'le', 'lt', 'ge', 'gt', 'ne', 'neg', 'pos', - 'abs', 'invert', - 'iadd', 'isub', 'imul', 'ifloordiv', 'idiv', - 'itruediv', 'imod', 'ipow', 'ilshift', 'irshift', 'iand', - 'ior', 'ixor', - 'int', 'long', 'float', 'complex', + *(x for m in pointwise_magic_methods_with_reverse for x in (m, "r" + m)), + "eq", + "gt", + "le", + "lt", + "ge", + "gt", + "ne", + "neg", + "pos", + "abs", + "invert", + "iadd", + "isub", + "imul", + "ifloordiv", + "idiv", + "itruediv", + "imod", + "ipow", + "ilshift", + "irshift", + "iand", + "ior", + "ixor", + "int", + "long", + "float", + "complex", ) -pointwise_methods = ( - *(f'__{m}__' for m in pointwise_magic_methods), -) +pointwise_methods = (*(f"__{m}__" for m in pointwise_magic_methods),) pointwise = ( *(getattr(torch.Tensor, m) for m in pointwise_methods), diff --git a/functorch/dim/reference.py b/functorch/dim/reference.py index ee351199c974a2..29d65b13ab3aac 100644 --- a/functorch/dim/reference.py +++ b/functorch/dim/reference.py @@ -6,23 +6,28 @@ # reference python implementations for C ops import torch -from .tree_map import tree_flatten, tree_map -from .batch_tensor import _enable_layers -from . 
import op_properties + from functorch._C import dim as _C +from . import op_properties +from .batch_tensor import _enable_layers +from .tree_map import tree_flatten, tree_map + DimList = _C.DimList -from functools import reduce import operator +from functools import reduce # use dict to avoid writing C++ bindings for set pointwise = set(op_properties.pointwise) + + def prod(x): return reduce(operator.mul, x, 1) def _wrap_dim(d, N, keepdim): from . import Dim + if isinstance(d, Dim): assert not keepdim, "cannot preserve first-class dimensions with keepdim=True" return d @@ -31,40 +36,52 @@ def _wrap_dim(d, N, keepdim): else: return d + def _dims(d, N, keepdim, single_dim): from . import Dim + if isinstance(d, (Dim, int)): return ltuple((_wrap_dim(d, N, keepdim),)) assert not single_dim, f"expected a single dimension or int but found: {d}" return ltuple(_wrap_dim(x, N, keepdim) for x in d) + def _bind_dims_to_size(lhs_size, rhs, lhs_debug): from . import DimensionMismatchError + not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound) if len(not_bound) == 1: idx, d = not_bound[0] rhs_so_far = prod(r.size for r in rhs if r.is_bound) if lhs_size % rhs_so_far != 0: - rhs_s = tuple('?' if not r.is_bound else str(r.size) for r in rhs) - raise DimensionMismatchError(f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}") + rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs) + raise DimensionMismatchError( + f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}" + ) new_size = lhs_size // rhs_so_far d.size = new_size elif len(not_bound) > 1: - rhs_s = tuple('?' if not r.is_bound else str(r.size) for r in rhs) - raise DimensionMismatchError(f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}") + rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs) + raise DimensionMismatchError( + f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}" + ) else: rhs_size = prod(r.size for r in rhs) if lhs_size != rhs_size: raise DimensionMismatchError( - f"Dimension sizes to do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}") + f"Dimension sizes to do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}" + ) + def _tensor_levels(inp): from . import _Tensor + if isinstance(inp, _Tensor): return inp._tensor, llist(inp._levels), inp._has_device else: return inp, llist(range(-inp.ndim, 0)), True + def _match_levels(v, from_levels, to_levels): view = [] permute = [] @@ -90,6 +107,7 @@ def _match_levels(v, from_levels, to_levels): # should not physically move if possible def _positional_no_permute(self, dim, expand_dim=False): from . import Tensor + ptensor, levels = self._tensor, llist(self._levels) try: idx = levels.index(dim) @@ -107,8 +125,10 @@ def _positional_no_permute(self, dim, expand_dim=False): levels[idx] = -idx_batched - 1 return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched + def seq(a, b): from . import Dim + if isinstance(a, Dim) != isinstance(b, Dim): return False if isinstance(a, Dim): @@ -116,6 +136,7 @@ def seq(a, b): else: return a == b + class isin: def __contains__(self, item): for x in self: @@ -133,18 +154,27 @@ def index(self, item): class llist(isin, list): pass + class ltuple(isin, tuple): pass + empty_dict = {} + + @classmethod def __torch_function__(self, orig, cls, args, kwargs=empty_dict): - from . import _Tensor, TensorLike, Tensor + from . 
import _Tensor, Tensor, TensorLike from .delayed_mul_tensor import DelayedMulTensor if orig is torch.Tensor.__mul__: lhs, rhs = args - if isinstance(lhs, _Tensor) and isinstance(rhs, _Tensor) and lhs.ndim == 0 and rhs.ndim == 0: + if ( + isinstance(lhs, _Tensor) + and isinstance(rhs, _Tensor) + and lhs.ndim == 0 + and rhs.ndim == 0 + ): return DelayedMulTensor(lhs, rhs) all_dims = llist() flat_args, unflatten = tree_flatten((args, kwargs)) @@ -172,7 +202,11 @@ def unwrap(t): for i, f in enumerate(flat_args): if isinstance(f, TensorLike): ptensor, levels, _ = _tensor_levels(f) - if isinstance(f, _Tensor) and not f._has_device and device_holding_tensor is not None: + if ( + isinstance(f, _Tensor) + and not f._has_device + and device_holding_tensor is not None + ): ptensor = ptensor.to(device=device_holding_tensor.device) flat_args[i] = ptensor for l in levels: @@ -187,14 +221,19 @@ def unwrap(t): def wrap(t): if isinstance(t, TensorLike): - return Tensor.from_positional(t, result_levels, device_holding_tensor is not None) + return Tensor.from_positional( + t, result_levels, device_holding_tensor is not None + ) return t + return tree_map(wrap, result) else: + def wrap(t): if isinstance(t, TensorLike): return Tensor.from_batched(t, device_holding_tensor is not None) return t + with _enable_layers(all_dims): print(f"batch_tensor for {orig}") args, kwargs = unflatten(unwrap(f) for f in flat_args) @@ -202,8 +241,10 @@ def wrap(t): # print("END", orig) return tree_map(wrap, result) + def positional(self, *dims): from . import Dim, Tensor + ptensor, levels = self._tensor, llist(self._levels) flat_dims = llist() view = [] @@ -231,7 +272,9 @@ def positional(self, *dims): try: idx = levels.index(d) except ValueError as e: - raise DimensionBindError(f'tensor of dimensions {self.dims} does not contain dim {d}') from e + raise DimensionBindError( + f"tensor of dimensions {self.dims} does not contain dim {d}" + ) from e p = permute[idx] del levels[idx] del permute[idx] @@ -245,15 +288,18 @@ def positional(self, *dims): levels[i] = -seen result = Tensor.from_positional(ptensor, levels, self._has_device) if needs_view: - result = result.reshape(*view, *result.size()[len(flat_dims):]) + result = result.reshape(*view, *result.size()[len(flat_dims) :]) return result + def _contains_dim(input): from . import Dim + for i in input: if isinstance(i, Dim): return True + def expand(self, *sizes): if not _contains_dim(sizes): return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes)) @@ -265,27 +311,36 @@ def expand(self, *sizes): _not_present = object() + def _getarg(name, offset, args, kwargs, default): if len(args) > offset: return args[offset] return kwargs.get(name, default) + def _patcharg(name, offset, args, kwargs, value): if len(args) > offset: args[offset] = value else: kwargs[name] = value -def _wrap(orig, dim_offset=0, keepdim_offset=1, dim_name='dim', single_dim=False, reduce=True): - from . import TensorLike, Dim, Tensor + +def _wrap( + orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True +): + from . 
import Dim, Tensor, TensorLike def fn(self, *args, **kwargs): dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present) if dim is _not_present or (single_dim and not isinstance(dim, Dim)): with _enable_layers(self.dims): print(f"dim fallback batch_tensor for {orig}") - return Tensor.from_batched(orig(self._batchtensor, *args, **kwargs), self._has_device) - keepdim = _getarg('keepdim', keepdim_offset, args, kwargs, False) if reduce else False + return Tensor.from_batched( + orig(self._batchtensor, *args, **kwargs), self._has_device + ) + keepdim = ( + _getarg("keepdim", keepdim_offset, args, kwargs, False) if reduce else False + ) t, levels = self._tensor, llist(self._levels) dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim) dim_indices = tuple(levels.index(d) for d in dims) @@ -295,7 +350,9 @@ def fn(self, *args, **kwargs): new_levels = levels if len(dim_indices) == 1: - dim_indices = dim_indices[0] # so that dims that really only take a single argument work... + dim_indices = dim_indices[ + 0 + ] # so that dims that really only take a single argument work... args = list(args) _patcharg(dim_name, dim_offset, args, kwargs, dim_indices) @@ -303,21 +360,27 @@ def wrap(t): if isinstance(t, TensorLike): return Tensor.from_positional(t, new_levels, self._has_device) return t + with _enable_layers(new_levels): print(f"dim used batch_tensor for {orig}") r = orig(t, *args, **kwargs) return tree_map(wrap, r) + return fn + def _def(name, *args, **kwargs): from . import _Tensor + orig = getattr(torch.Tensor, name) setattr(_Tensor, name, _wrap(orig, *args, **kwargs)) + no_slice = slice(None) _orig_getitem = torch.Tensor.__getitem__ + class dim_tracker: def __init__(self): self.dims = llist() @@ -331,8 +394,10 @@ def record(self, d): def __getitem__(self, d): return self.count[self.dims.index(d)] + def t__getitem__(self, input): - from . import Dim, DimensionBindError, _Tensor, TensorLike, DimList, Tensor + from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike + # * bail to original example if we have a single non-Dim tensor, or a non-tensor # * locate ... or an unbound tensor list, and determine its size, bind dim list # (remember that None does not count to the total dim count) @@ -345,10 +410,13 @@ def t__getitem__(self, input): # this handles bool indexing handling, as well as some other simple cases. - is_simple = (not isinstance(input, Dim) and - not isinstance(input, (tuple, list)) and - # WAR for functorch bug where zero time tensors in getitem are not handled correctly. - not (isinstance(input, TensorLike) and input.ndim == 0)) + is_simple = ( + not isinstance(input, Dim) + and not isinstance(input, (tuple, list)) + and + # WAR for functorch bug where zero time tensors in getitem are not handled correctly. + not (isinstance(input, TensorLike) and input.ndim == 0) + ) if is_simple: if isinstance(self, _Tensor): @@ -368,8 +436,10 @@ def t__getitem__(self, input): for i, s in enumerate(input): if s is ... or isinstance(s, DimList) and not s.is_bound: if expanding_object is not None: - msg = 'at most one ... or unbound dimension list can exist in indexing list but' \ - f' found 2 at offsets {i} and {expanding_object}' + msg = ( + "at most one ... 
or unbound dimension list can exist in indexing list but" + f" found 2 at offsets {i} and {expanding_object}" + ) raise DimensionBindError(msg) expanding_object = i @@ -381,17 +451,21 @@ def t__getitem__(self, input): ndim = self.ndim if dims_indexed > ndim: - raise IndexError(f'at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions.') + raise IndexError( + f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions." + ) if expanding_object is not None: expanding_ndims = ndim - dims_indexed obj = input[expanding_object] if obj is ...: - input[expanding_object:expanding_object + 1] = [no_slice] * expanding_ndims + input[expanding_object : expanding_object + 1] = [ + no_slice + ] * expanding_ndims else: obj.bind_len(expanding_ndims) # flatten the dimslists into the indexing for i in reversed(dimlists): - input[i:i + 1] = input[i] + input[i : i + 1] = input[i] dims_indexed = 0 requires_view = False size = self.size() @@ -420,7 +494,7 @@ def add_dims(t): elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim): for d in idx: dims_seen.record(d) - _bind_dims_to_size(sz, idx, f'offset {i}') + _bind_dims_to_size(sz, idx, f"offset {i}") view_sizes.extend(d.size for d in idx) requires_view = True dim_packs.append(i) @@ -431,7 +505,7 @@ def add_dims(t): if requires_view: self = self.view(*view_sizes) for i in reversed(dim_packs): - input[i:i + 1] = input[i] + input[i : i + 1] = input[i] # currently: # input is flat, containing either Dim, or Tensor, or something valid for standard indexing @@ -499,6 +573,7 @@ def add_dims(t): return Tensor.from_positional(result, result_levels, has_device) + # XXX - dim is optional and can be the outer-most dimension... def stack(tensors, new_dim, dim=0, out=None): if isinstance(dim, int): @@ -517,12 +592,20 @@ def stack(tensors, new_dim, dim=0, out=None): pr = torch.stack(ptensors, index, out=out) return pr.index((index, index + 1), (new_dim, dim)) + _orig_split = torch.Tensor.split + + def split(self, split_size_or_sections, dim=0): - from . import Dim, _Tensor - if isinstance(split_size_or_sections, int) or any(isinstance(t, int) for t in split_size_or_sections): + from . import _Tensor, Dim + + if isinstance(split_size_or_sections, int) or any( + isinstance(t, int) for t in split_size_or_sections + ): if isinstance(dim, Dim): - raise ValueError('when dim is specified as a Dim object, split sizes must also be dimensions.') + raise ValueError( + "when dim is specified as a Dim object, split sizes must also be dimensions." 
+ ) return _orig_split(self, split_size_or_sections, dim=dim) if isinstance(dim, Dim): @@ -542,8 +625,9 @@ def split(self, split_size_or_sections, dim=0): unbound.append(i) if unbound: - assert total_bound_size <= size, \ - f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})" + assert ( + total_bound_size <= size + ), f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})" remaining_size = size - total_bound_size chunk_size = -(-remaining_size // len(unbound)) for u in unbound: @@ -552,6 +636,10 @@ def split(self, split_size_or_sections, dim=0): sizes[u] = sz remaining_size -= sz else: - assert total_bound_size == size, \ - f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})" - return tuple(t.index(dim, d) for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim))) + assert ( + total_bound_size == size + ), f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})" + return tuple( + t.index(dim, d) + for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim)) + ) diff --git a/functorch/dim/tree_map.py b/functorch/dim/tree_map.py index 89aaad09eb3306..1f02f02656f288 100644 --- a/functorch/dim/tree_map.py +++ b/functorch/dim/tree_map.py @@ -5,8 +5,10 @@ # LICENSE file in the root directory of this source tree. from functorch._C import dim + tree_flatten = dim.tree_flatten + def tree_map(fn, tree): vs, unflatten = tree_flatten(tree) return unflatten(fn(v) for v in vs) diff --git a/functorch/dim/wrap_type.py b/functorch/dim/wrap_type.py index 8212836d3d6ae7..e2146c4a21a144 100644 --- a/functorch/dim/wrap_type.py +++ b/functorch/dim/wrap_type.py @@ -4,22 +4,35 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from types import FunctionType, BuiltinMethodType, MethodDescriptorType, WrapperDescriptorType, GetSetDescriptorType +from types import ( + BuiltinMethodType, + FunctionType, + GetSetDescriptorType, + MethodDescriptorType, + WrapperDescriptorType, +) + from functorch._C import dim as _C + _wrap_method = _C._wrap_method -FUNC_TYPES = (FunctionType, MethodDescriptorType, BuiltinMethodType, WrapperDescriptorType) +FUNC_TYPES = ( + FunctionType, + MethodDescriptorType, + BuiltinMethodType, + WrapperDescriptorType, +) PROPERTY_TYPES = (GetSetDescriptorType, property) + def _py_wrap_method(orig, __torch_function__): def impl(*args, **kwargs): return __torch_function__(orig, None, args, kwargs) - return impl + return impl def wrap_type(use_c, to_patch, pattern, __torch_function__): - if use_c: wrap_method = _wrap_method else: @@ -29,18 +42,27 @@ def wrap_type(use_c, to_patch, pattern, __torch_function__): for t in reversed(pattern.mro()[:-1]): # skip object all.update(t.__dict__) - def wrap_attr(orig): return property(wrap_method(orig.__get__, __torch_function__)) - for name, obj in all.items(): - if name in ('__dict__', '__new__', '__init__', '__repr__', '__weakref__', '__doc__', '__module__', '__dir__'): + if name in ( + "__dict__", + "__new__", + "__init__", + "__repr__", + "__weakref__", + "__doc__", + "__module__", + "__dir__", + ): continue # skip things that have been overloaded # things that come from object like `__eq__` still need to be patched, however. 
- if hasattr(to_patch, name) and getattr(to_patch, name) is not getattr(object, name, None): + if hasattr(to_patch, name) and getattr(to_patch, name) is not getattr( + object, name, None + ): continue if isinstance(obj, FUNC_TYPES): diff --git a/functorch/docs/source/conf.py b/functorch/docs/source/conf.py index 097482abda2d61..68f02c8b81094c 100644 --- a/functorch/docs/source/conf.py +++ b/functorch/docs/source/conf.py @@ -14,19 +14,22 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. # import os + +import functorch + # import sys # source code directory, relative to this file, for sphinx-autobuild # sys.path.insert(0, os.path.abspath('../..')) import torch -import functorch -RELEASE = os.environ.get('RELEASE', False) +RELEASE = os.environ.get("RELEASE", False) -import pytorch_sphinx_theme import sys +import pytorch_sphinx_theme + # -- General configuration ------------------------------------------------ # Required version of sphinx is set from docs/requirements.txt @@ -35,18 +38,18 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.napoleon', - 'sphinx.ext.viewcode', + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", # 'sphinxcontrib.katex', - 'sphinx.ext.autosectionlabel', - 'sphinx_copybutton', - 'myst_nb', + "sphinx.ext.autosectionlabel", + "sphinx_copybutton", + "myst_nb", ] # sys.path.insert(0, os.path.abspath('./notebooks')) @@ -75,21 +78,21 @@ autosummary_generate = True # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = 'functorch' -copyright = 'PyTorch Contributors' -author = 'PyTorch Contributors' +project = "functorch" +copyright = "PyTorch Contributors" +author = "PyTorch Contributors" functorch_version = str(functorch.__version__) # The version info for the project you're documenting, acts as replacement for @@ -98,16 +101,16 @@ # # The short X.Y version. # TODO: change to [:2] at v1.0 -version = 'nightly (' + functorch_version + ')' +version = "nightly (" + functorch_version + ")" # The full version, including alpha/beta/rc tags. # TODO: verify this works as expected -release = 'nightly' +release = "nightly" # Customized html_title here. # Default is " ".join(project, release, "documentation") if not set # TODO: I don't know if this flag works, please check before using it if RELEASE: - raise RuntimeError('NYI') + raise RuntimeError("NYI") # remove hash (start with 'a') from version number if any # version_end = functorch_version.find('a') # if version_end == -1: @@ -128,10 +131,10 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. 
# These patterns also affect html_static_path and html_extra_path -exclude_patterns = ['notebooks/colab**', 'notebooks/_src/**'] +exclude_patterns = ["notebooks/colab**", "notebooks/_src/**"] # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = True @@ -140,7 +143,7 @@ autodoc_inherit_docstrings = False # Disable displaying type annotations, these can be very verbose -autodoc_typehints = 'none' +autodoc_typehints = "none" # Enable overriding of function signatures in the first line of the docstring. autodoc_docstring_signature = True @@ -159,7 +162,7 @@ # # -html_theme = 'pytorch_sphinx_theme' +html_theme = "pytorch_sphinx_theme" html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] # Theme options are theme-specific and customize the look and feel of a theme @@ -178,10 +181,10 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] html_css_files = [ - 'css/custom.css', + "css/custom.css", ] @@ -191,19 +194,20 @@ def setup(app): # and can be moved outside of this function (and the setup(app) function # can be deleted). html_css_files = [ - 'https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css' + "https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css" ] # In Sphinx 1.8 it was renamed to `add_css_file`, 1.7 and prior it is # `add_stylesheet` (deprecated in 1.8). - add_css = getattr(app, 'add_css_file', app.add_stylesheet) + add_css = getattr(app, "add_css_file", app.add_stylesheet) for css_file in html_css_files: add_css(css_file) + # -- Options for HTMLHelp output ------------------------------------------ # Output file base name for HTML help builder. -htmlhelp_basename = 'PyTorchdoc' +htmlhelp_basename = "PyTorchdoc" # -- Options for LaTeX output --------------------------------------------- @@ -212,15 +216,12 @@ def setup(app): # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -230,8 +231,13 @@ def setup(app): # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'pytorch.tex', 'PyTorch Documentation', - 'Torch Contributors', 'manual'), + ( + master_doc, + "pytorch.tex", + "PyTorch Documentation", + "Torch Contributors", + "manual", + ), ] @@ -239,10 +245,7 @@ def setup(app): # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). 
-man_pages = [ - (master_doc, 'functorch', 'functorch Documentation', - [author], 1) -] +man_pages = [(master_doc, "functorch", "functorch Documentation", [author], 1)] # -- Options for Texinfo output ------------------------------------------- @@ -251,37 +254,44 @@ def setup(app): # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'functorch', 'functorch Documentation', - author, 'functorch', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "functorch", + "functorch Documentation", + author, + "functorch", + "One line description of project.", + "Miscellaneous", + ), ] # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { - 'python': ('https://docs.python.org/3', None), - 'numpy': ('https://numpy.org/doc/stable', None), + "python": ("https://docs.python.org/3", None), + "numpy": ("https://numpy.org/doc/stable", None), "torch": ("https://pytorch.org/docs/stable/", None), } +import sphinx.ext.doctest + # -- A patch that prevents Sphinx from cross-referencing ivar tags ------- # See http://stackoverflow.com/a/41184353/3343043 from docutils import nodes -from sphinx.util.docfields import TypedField from sphinx import addnodes -import sphinx.ext.doctest +from sphinx.util.docfields import TypedField # Without this, doctest adds any example with a `>>>` as a test -doctest_test_doctest_blocks = '' +doctest_test_doctest_blocks = "" doctest_default_flags = sphinx.ext.doctest.doctest.ELLIPSIS -doctest_global_setup = ''' +doctest_global_setup = """ import torch try: import torchvision except ImportError: torchvision = None -''' +""" def patched_make_field(self, types, domain, items, **kw): @@ -291,43 +301,51 @@ def patched_make_field(self, types, domain, items, **kw): # (List, unicode, Tuple) -> nodes.field def handle_item(fieldarg, content): par = nodes.paragraph() - par += addnodes.literal_strong('', fieldarg) # Patch: this line added + par += addnodes.literal_strong("", fieldarg) # Patch: this line added # par.extend(self.make_xrefs(self.rolename, domain, fieldarg, # addnodes.literal_strong)) if fieldarg in types: - par += nodes.Text(' (') + par += nodes.Text(" (") # NOTE: using .pop() here to prevent a single type node to be # inserted twice into the doctree, which leads to # inconsistencies later when references are resolved fieldtype = types.pop(fieldarg) if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text): - typename = u''.join(n.astext() for n in fieldtype) - typename = typename.replace('int', 'python:int') - typename = typename.replace('long', 'python:long') - typename = typename.replace('float', 'python:float') - typename = typename.replace('bool', 'python:bool') - typename = typename.replace('type', 'python:type') - par.extend(self.make_xrefs(self.typerolename, domain, typename, - addnodes.literal_emphasis, **kw)) + typename = "".join(n.astext() for n in fieldtype) + typename = typename.replace("int", "python:int") + typename = typename.replace("long", "python:long") + typename = typename.replace("float", "python:float") + typename = typename.replace("bool", "python:bool") + typename = typename.replace("type", "python:type") + par.extend( + self.make_xrefs( + self.typerolename, + domain, + typename, + addnodes.literal_emphasis, + **kw, + ) + ) else: par += fieldtype - par += nodes.Text(')') - par += nodes.Text(' -- ') + par += nodes.Text(")") + par += nodes.Text(" -- ") par += content return par - fieldname = 
nodes.field_name('', self.label) + fieldname = nodes.field_name("", self.label) if len(items) == 1 and self.can_collapse: fieldarg, content = items[0] bodynode = handle_item(fieldarg, content) else: bodynode = self.list_type() for fieldarg, content in items: - bodynode += nodes.list_item('', handle_item(fieldarg, content)) - fieldbody = nodes.field_body('', bodynode) - return nodes.field('', fieldname, fieldbody) + bodynode += nodes.list_item("", handle_item(fieldarg, content)) + fieldbody = nodes.field_body("", bodynode) + return nodes.field("", fieldname, fieldbody) + TypedField.make_field = patched_make_field -copybutton_prompt_text = r'>>> |\.\.\. ' +copybutton_prompt_text = r">>> |\.\.\. " copybutton_prompt_is_regexp = True diff --git a/functorch/einops/__init__.py b/functorch/einops/__init__.py index af51695bc717f1..b32751d6e2493a 100644 --- a/functorch/einops/__init__.py +++ b/functorch/einops/__init__.py @@ -1,3 +1,3 @@ from .rearrange import rearrange -__all__ = ['rearrange'] +__all__ = ["rearrange"] diff --git a/functorch/einops/_parsing.py b/functorch/einops/_parsing.py index 56d1ee0f93839d..63adcb6e5a64c7 100644 --- a/functorch/einops/_parsing.py +++ b/functorch/einops/_parsing.py @@ -40,7 +40,9 @@ class AnonymousAxis: def __init__(self, value: str) -> None: self.value = int(value) if self.value < 1: - raise ValueError(f'Anonymous axis should have positive length, not {self.value}') + raise ValueError( + f"Anonymous axis should have positive length, not {self.value}" + ) def __repr__(self) -> str: return f"{self.value}-axis" @@ -49,7 +51,13 @@ def __repr__(self) -> str: class ParsedExpression: """Structure containing information about one side of an `einops`-style pattern (e.g. 'b c (h w)').""" - def __init__(self, expression: str, *, allow_underscore: bool = False, allow_duplicates: bool = False) -> None: + def __init__( + self, + expression: str, + *, + allow_underscore: bool = False, + allow_duplicates: bool = False, + ) -> None: """Parse the expression and store relevant metadata. Args: @@ -66,10 +74,13 @@ def __init__(self, expression: str, *, allow_underscore: bool = False, allow_dup self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = [] if "." in expression: if "..." 
not in expression: - raise ValueError("Expression may contain dots only inside ellipsis (...)") + raise ValueError( + "Expression may contain dots only inside ellipsis (...)" + ) if str.count(expression, "...") != 1 or str.count(expression, ".") != 3: raise ValueError( - "Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor ") + "Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor " + ) expression = expression.replace("...", _ellipsis) self.has_ellipsis = True @@ -78,7 +89,9 @@ def __init__(self, expression: str, *, allow_underscore: bool = False, allow_dup def add_axis_name(x: str) -> None: if x in self.identifiers: if not (allow_underscore and x == "_") and not allow_duplicates: - raise ValueError(f"Indexing expression contains duplicate dimension '{x}'") + raise ValueError( + f"Indexing expression contains duplicate dimension '{x}'" + ) if x == _ellipsis: self.identifiers.add(_ellipsis) if bracket_group is None: @@ -96,10 +109,14 @@ def add_axis_name(x: str) -> None: else: pass # no need to think about 1s inside parenthesis return - is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore) + is_axis_name, reason = self.check_axis_name_return_reason( + x, allow_underscore=allow_underscore + ) if not (is_number or is_axis_name): raise ValueError(f"Invalid axis identifier: {x}\n{reason}") - axis_name: Union[str, AnonymousAxis] = AnonymousAxis(x) if is_number else x + axis_name: Union[str, AnonymousAxis] = ( + AnonymousAxis(x) if is_number else x + ) self.identifiers.add(axis_name) if is_number: self.has_non_unitary_anonymous_axes = True @@ -116,7 +133,9 @@ def add_axis_name(x: str) -> None: current_identifier = None if char == "(": if bracket_group is not None: - raise ValueError("Axis composition is one-level (brackets inside brackets not allowed)") + raise ValueError( + "Axis composition is one-level (brackets inside brackets not allowed)" + ) bracket_group = [] elif char == ")": if bracket_group is None: @@ -137,7 +156,9 @@ def add_axis_name(x: str) -> None: add_axis_name(current_identifier) @staticmethod - def check_axis_name_return_reason(name: str, allow_underscore: bool = False) -> Tuple[bool, str]: + def check_axis_name_return_reason( + name: str, allow_underscore: bool = False + ) -> Tuple[bool, str]: """Check if the given axis name is valid, and a message explaining why if not. Valid axes names are python identifiers except keywords, and should not start or end with an underscore. 
@@ -157,10 +178,14 @@ def check_axis_name_return_reason(name: str, allow_underscore: bool = False) -> return False, "axis name should not start or end with underscore" else: if keyword.iskeyword(name): - warnings.warn(f"It is discouraged to use axes names that are keywords: {name}", RuntimeWarning) + warnings.warn( + f"It is discouraged to use axes names that are keywords: {name}", + RuntimeWarning, + ) if name in ["axis"]: warnings.warn( - "It is discouraged to use 'axis' as an axis name and will raise an error in future", FutureWarning + "It is discouraged to use 'axis' as an axis name and will raise an error in future", + FutureWarning, ) return True, "" @@ -178,8 +203,9 @@ def check_axis_name(name: str) -> bool: return is_valid - -def parse_pattern(pattern: str, axes_lengths: Mapping[str, int]) -> Tuple[ParsedExpression, ParsedExpression]: +def parse_pattern( + pattern: str, axes_lengths: Mapping[str, int] +) -> Tuple[ParsedExpression, ParsedExpression]: """Parse an `einops`-style pattern into a left-hand side and right-hand side `ParsedExpression` object. Args: @@ -203,9 +229,13 @@ def parse_pattern(pattern: str, axes_lengths: Mapping[str, int]) -> Tuple[Parsed right = ParsedExpression(right_str) if not left.has_ellipsis and right.has_ellipsis: - raise ValueError(f"Ellipsis found in right side, but not left side of a pattern {pattern}") + raise ValueError( + f"Ellipsis found in right side, but not left side of a pattern {pattern}" + ) if left.has_ellipsis and left.has_ellipsis_parenthesized: - raise ValueError(f"Ellipsis inside parenthesis on the left side is not allowed: {pattern}") + raise ValueError( + f"Ellipsis inside parenthesis on the left side is not allowed: {pattern}" + ) return left, right @@ -222,18 +252,24 @@ def validate_rearrange_expressions( """ for length in axes_lengths.values(): if (length_type := type(length)) is not int: - raise TypeError(f"rearrange axis lengths must be integers, got: {length_type}") + raise TypeError( + f"rearrange axis lengths must be integers, got: {length_type}" + ) if left.has_non_unitary_anonymous_axes or right.has_non_unitary_anonymous_axes: raise ValueError("rearrange only supports unnamed axes of size 1") difference = set.symmetric_difference(left.identifiers, right.identifiers) if len(difference) > 0: - raise ValueError(f"Identifiers only on one side of rearrange expression (should be on both): {difference}") + raise ValueError( + f"Identifiers only on one side of rearrange expression (should be on both): {difference}" + ) unmatched_axes = axes_lengths.keys() - left.identifiers if len(unmatched_axes) > 0: - raise ValueError(f"Identifiers not found in rearrange expression: {unmatched_axes}") + raise ValueError( + f"Identifiers not found in rearrange expression: {unmatched_axes}" + ) def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str: @@ -259,6 +295,8 @@ def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str: '(d0,), (), (d1,), (d2,), (d3, d4)' """ return ", ".join( - item if isinstance(item, str) else f"({comma_separate(item)}{',' if len(item) == 1 else ''})" + item + if isinstance(item, str) + else f"({comma_separate(item)}{',' if len(item) == 1 else ''})" for item in collection ) diff --git a/functorch/einops/rearrange.py b/functorch/einops/rearrange.py index b35198cbec5cc7..0449bb7ed2c72e 100644 --- a/functorch/einops/rearrange.py +++ b/functorch/einops/rearrange.py @@ -4,8 +4,15 @@ from typing import Callable, Dict, List, Sequence, Tuple, Union import torch + from functorch._C import 
dim as _C -from ._parsing import AnonymousAxis, _ellipsis, comma_separate, parse_pattern, validate_rearrange_expressions +from ._parsing import ( + _ellipsis, + AnonymousAxis, + comma_separate, + parse_pattern, + validate_rearrange_expressions, +) __all__ = ["rearrange"] @@ -79,10 +86,12 @@ def _create_rearrange_callable( dims_i += 1 elif dimension == _ellipsis: identifier = _ellipsis - identifier_dim_map[identifier] = tuple(first_class_dims[dims_i + j] for j in range(n_ellipsis_dims)) + identifier_dim_map[identifier] = tuple( + first_class_dims[dims_i + j] for j in range(n_ellipsis_dims) + ) dims_i += n_ellipsis_dims else: - raise ValueError(f'Unexpected dimension: {dimension}') + raise ValueError(f"Unexpected dimension: {dimension}") def composition_to_dims( composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]] @@ -92,11 +101,17 @@ class dims.""" dim_composition: List[Union[str, Tuple[str, ...]]] = [] for dimension in composition: if isinstance(dimension, list): - dim_composition.append(tuple(dim for identifier in dimension for dim in identifier_dim_map[identifier])) + dim_composition.append( + tuple( + dim + for identifier in dimension + for dim in identifier_dim_map[identifier] + ) + ) elif dimension == _ellipsis: dim_composition.extend(identifier_dim_map[_ellipsis]) else: - raise ValueError(f'Unexpected dimension: {dimension}') + raise ValueError(f"Unexpected dimension: {dimension}") return dim_composition left_dims = composition_to_dims(left.composition) @@ -108,16 +123,22 @@ class dims.""" custom_rearrange_callable_name = "do_rearrange" custom_rearrange_callable_code = ( - f"def {custom_rearrange_callable_name}(tensor):\n" - f" {comma_separate(first_class_dims)} = dims({n_dims})\n" + ( + f"def {custom_rearrange_callable_name}(tensor):\n" + f" {comma_separate(first_class_dims)} = dims({n_dims})\n" + ) + ( - "".join(f" {dim}.size = {length}\n" for (dim, length) in specified_lengths) - if specified_lengths else "" + "".join( + f" {dim}.size = {length}\n" for (dim, length) in specified_lengths + ) + if specified_lengths + else "" ) + f" tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n" + ( f" return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n" - if anon_dims else " return tensor\n" + if anon_dims + else " return tensor\n" ) ) @@ -126,7 +147,9 @@ class dims.""" def rearrange( - tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]], pattern: str, **axes_lengths: int + tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]], + pattern: str, + **axes_lengths: int, ) -> torch.Tensor: r"""A native implementation of `einops.rearrange`, a reader-friendly smart element reordering for multidimensional tensors. 
This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze, @@ -177,6 +200,8 @@ def rearrange( if not isinstance(tensor, torch.Tensor): tensor = torch.stack(tensor) - rearrange_callable = _create_rearrange_callable(tensor.ndim, pattern, **axes_lengths) + rearrange_callable = _create_rearrange_callable( + tensor.ndim, pattern, **axes_lengths + ) return rearrange_callable(tensor) diff --git a/functorch/examples/compilation/eager_fusion.py b/functorch/examples/compilation/eager_fusion.py index cc43a5ce199703..bd24fd77fdd992 100644 --- a/functorch/examples/compilation/eager_fusion.py +++ b/functorch/examples/compilation/eager_fusion.py @@ -1,7 +1,8 @@ -from functorch.compile import aot_function, tvm_compile -import torch import time + +import torch import torch.utils +from functorch.compile import aot_function, tvm_compile a = torch.randn(2000, 1, 4, requires_grad=True) b = torch.randn(1, 2000, 4) @@ -11,8 +12,8 @@ def f(a): return (a * b).sum(dim=0) -fw_compiler = tvm_compile(target='llvm', tuning_logfile='fw_keops') -bw_compiler = tvm_compile(target='llvm', tuning_logfile='bw_keops') +fw_compiler = tvm_compile(target="llvm", tuning_logfile="fw_keops") +bw_compiler = tvm_compile(target="llvm", tuning_logfile="bw_keops") compiled_f = aot_function(f, fw_compiler, bw_compiler) # fw_compiler = lambda x, _: x @@ -32,13 +33,15 @@ def bench(func): def bench_jax(): - import jax.numpy as jnp import jax + import jax.numpy as jnp + jax_a = jnp.array(a.detach().numpy()) jax_b = jnp.array(b.detach().numpy()) def f(a): return jnp.sin((a * jax_b).sum(axis=[0])).sum() + jit_f = jax.jit(jax.grad(f)) jit_f(jax_a) begin = time.time() diff --git a/functorch/examples/compilation/fuse_module.py b/functorch/examples/compilation/fuse_module.py index 3d2f830485b927..a0eb60347714b3 100644 --- a/functorch/examples/compilation/fuse_module.py +++ b/functorch/examples/compilation/fuse_module.py @@ -1,15 +1,16 @@ import timeit -from functorch.compile import compiled_module, tvm_compile -import torch.nn as nn + import torch +import torch.nn as nn +from functorch.compile import compiled_module, tvm_compile def nop(f, _): return f -fw_compiler = tvm_compile(target='llvm', tuning_logfile='fw_keops') -bw_compiler = tvm_compile(target='llvm', tuning_logfile='bw_keops') +fw_compiler = tvm_compile(target="llvm", tuning_logfile="fw_keops") +bw_compiler = tvm_compile(target="llvm", tuning_logfile="bw_keops") fw_compiler = nop bw_compiler = nop diff --git a/functorch/examples/compilation/linear_train.py b/functorch/examples/compilation/linear_train.py index ee84347470835b..6bbf3f02337afb 100644 --- a/functorch/examples/compilation/linear_train.py +++ b/functorch/examples/compilation/linear_train.py @@ -4,11 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from functorch import make_functional -from functorch.compile import nnc_jit +import time + import torch import torch.nn as nn -import time +from functorch import make_functional +from functorch.compile import nnc_jit + torch._C._jit_override_can_fuse_on_cpu(True) @@ -30,7 +32,7 @@ def __init__(self, num_layers=3, features=100): self.mod = nn.Sequential(*mods) def forward(self, x): - return (self.mod(x)**2).sum() + return (self.mod(x) ** 2).sum() batch_size = 16 @@ -54,7 +56,9 @@ def functional_step(x, weights): return out, new_weights -optim = torch.optim.SGD(jit_mod.parameters(), lr=lr, momentum=0, dampening=0, weight_decay=0) +optim = torch.optim.SGD( + jit_mod.parameters(), lr=lr, momentum=0, dampening=0, weight_decay=0 +) def jit_step(x, weights): diff --git a/functorch/examples/compilation/simple_function.py b/functorch/examples/compilation/simple_function.py index 14731c7c66661e..d916cc5b6ee492 100644 --- a/functorch/examples/compilation/simple_function.py +++ b/functorch/examples/compilation/simple_function.py @@ -4,10 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import time + +import torch from functorch import grad, make_fx from functorch.compile import nnc_jit -import torch -import time def f(x): diff --git a/functorch/examples/dp_cifar10/cifar10_opacus.py b/functorch/examples/dp_cifar10/cifar10_opacus.py index 22cd3ed92022ac..bd8239b187e662 100644 --- a/functorch/examples/dp_cifar10/cifar10_opacus.py +++ b/functorch/examples/dp_cifar10/cifar10_opacus.py @@ -17,8 +17,8 @@ import torch.optim as optim import torch.utils.data import torchvision.transforms as transforms -from torchvision import models from opacus import PrivacyEngine +from torchvision import models from torchvision.datasets import CIFAR10 from tqdm import tqdm @@ -52,7 +52,6 @@ def train(args, model, train_loader, optimizer, privacy_engine, epoch, device): top1_acc = [] for i, (images, target) in enumerate(tqdm(train_loader)): - images = images.to(device) target = target.to(device) @@ -279,6 +278,7 @@ def main(): ) logger.info(metrics) + def parse_args(): parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training") parser.add_argument( @@ -309,7 +309,7 @@ def parse_args(): default=256, type=int, metavar="N", - help="mini-batch size for test dataset (default: 256)" + help="mini-batch size for test dataset (default: 256)", ) parser.add_argument( "--sample-rate", diff --git a/functorch/examples/dp_cifar10/cifar10_transforms.py b/functorch/examples/dp_cifar10/cifar10_transforms.py index 863896983a08a2..bbab7f46b5e228 100644 --- a/functorch/examples/dp_cifar10/cifar10_transforms.py +++ b/functorch/examples/dp_cifar10/cifar10_transforms.py @@ -17,12 +17,12 @@ import torch.optim as optim import torch.utils.data import torchvision.transforms as transforms + +from torch.func import functional_call, grad_and_value, vmap from torchvision import models from torchvision.datasets import CIFAR10 from tqdm import tqdm -from torch.func import vmap, grad_and_value, functional_call - logging.basicConfig( format="%(asctime)s:%(levelname)s:%(message)s", datefmt="%m/%d/%Y %H:%M:%S", @@ -44,12 +44,16 @@ def accuracy(preds, labels): def compute_norms(sample_grads): batch_size = sample_grads[0].shape[0] - norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads] + norms = [ + sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads + ] norms = torch.stack(norms, 
dim=0).norm(2, dim=0) return norms, batch_size -def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise_multiplier=1.0): +def clip_and_accumulate_and_add_noise( + model, max_per_sample_grad_norm=1.0, noise_multiplier=1.0 +): sample_grads = tuple(param.grad_sample for param in model.parameters()) # step 0: compute the norms @@ -60,18 +64,21 @@ def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise clip_factor = clip_factor.clamp(max=1.0) # step 2: clip - grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad) - for sample_grad in sample_grads) + grads = tuple( + torch.einsum("i,i...", clip_factor, sample_grad) for sample_grad in sample_grads + ) # step 3: add gaussian noise stddev = max_per_sample_grad_norm * noise_multiplier - noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device) - for grad_param in grads) + noises = tuple( + torch.normal(0, stddev, grad_param.shape, device=grad_param.device) + for grad_param in grads + ) grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads)) # step 4: assign the new grads, delete the sample grads for param, param_grad in zip(model.parameters(), grads): - param.grad = param_grad/batch_size + param.grad = param_grad / batch_size del param.grad_sample @@ -84,7 +91,6 @@ def train(args, model, train_loader, optimizer, epoch, device): top1_acc = [] for i, (images, target) in enumerate(tqdm(train_loader)): - images = images.to(device) target = target.to(device) @@ -120,8 +126,9 @@ def compute_loss_and_output(weights, image, target): # detaching weights since we don't need to track gradients outside of transforms # and this is more performant detached_weights = {k: v.detach() for k, v in weights.items()} - sample_grads, (sample_loss, output) = \ - vmap(grads_loss_output, (None, 0, 0))(detached_weights, images, target) + sample_grads, (sample_loss, output) = vmap(grads_loss_output, (None, 0, 0))( + detached_weights, images, target + ) loss = sample_loss.mean() for name, grad_sample in sample_grads.items(): @@ -129,7 +136,8 @@ def compute_loss_and_output(weights, image, target): # Step 2: Clip the per-sample-grads, sum them to form grads, and add noise clip_and_accumulate_and_add_noise( - model, args.max_per_sample_grad_norm, args.sigma) + model, args.max_per_sample_grad_norm, args.sigma + ) preds = np.argmax(output.detach().cpu().numpy(), axis=1) labels = target.detach().cpu().numpy() @@ -270,9 +278,7 @@ def main(): for param_group in optimizer.param_groups: param_group["lr"] = lr - train_duration = train( - args, model, train_loader, optimizer, epoch, device - ) + train_duration = train(args, model, train_loader, optimizer, epoch, device) top1_acc = test(args, model, test_loader, device) # remember best acc@1 and save checkpoint @@ -308,6 +314,7 @@ def main(): ) logger.info(metrics) + def parse_args(): parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training") parser.add_argument( @@ -338,7 +345,7 @@ def parse_args(): default=256, type=int, metavar="N", - help="mini-batch size for test dataset (default: 256)" + help="mini-batch size for test dataset (default: 256)", ) parser.add_argument( "--sample-rate", diff --git a/functorch/examples/ensembling/parallel_train.py b/functorch/examples/ensembling/parallel_train.py index 2077ffc2fb98a6..8fb0dae48e205a 100644 --- a/functorch/examples/ensembling/parallel_train.py +++ b/functorch/examples/ensembling/parallel_train.py @@ -1,9 +1,10 @@ import argparse import math + import torch import torch.nn 
as nn import torch.nn.functional as F -from torch.func import functional_call, grad_and_value, vmap, stack_module_state +from torch.func import functional_call, grad_and_value, stack_module_state, vmap # Adapted from http://willwhitney.com/parallel-training-jax.html , which is a # tutorial on Model Ensembling with JAX by Will Whitney. @@ -33,15 +34,21 @@ # Step 1: Make some spirals -def make_spirals(n_samples, noise_std=0., rotations=1.): +def make_spirals(n_samples, noise_std=0.0, rotations=1.0): ts = torch.linspace(0, 1, n_samples, device=DEVICE) - rs = ts ** 0.5 + rs = ts**0.5 thetas = rs * rotations * 2 * math.pi signs = torch.randint(0, 2, (n_samples,), device=DEVICE) * 2 - 1 labels = (signs > 0).to(torch.long).to(DEVICE) - xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std - ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std + xs = ( + rs * signs * torch.cos(thetas) + + torch.randn(n_samples, device=DEVICE) * noise_std + ) + ys = ( + rs * signs * torch.sin(thetas) + + torch.randn(n_samples, device=DEVICE) * noise_std + ) points = torch.stack([xs, ys], dim=1) return points, labels @@ -70,6 +77,7 @@ def forward(self, x): loss_fn = nn.NLLLoss() model = MLPClassifier().to(DEVICE) + def train_step_fn(weights, batch, targets, lr=0.2): def compute_loss(weights, batch, targets): output = functional_call(model, weights, batch) @@ -109,6 +117,7 @@ def init_fn(num_models): params, _ = stack_module_state(models) return params + # Step 6: Now, can we try multiple models at the same time? # The answer is: yes! `loss` is a 2-tuple, and we can see that the value keeps # on decreasing diff --git a/functorch/examples/lennard_jones/lennard_jones.py b/functorch/examples/lennard_jones/lennard_jones.py index 98d1e5edcde507..1b3d248ede11dc 100644 --- a/functorch/examples/lennard_jones/lennard_jones.py +++ b/functorch/examples/lennard_jones/lennard_jones.py @@ -4,15 +4,15 @@ import torch from torch import nn -from torch.nn.functional import mse_loss from torch.func import jacrev, vmap +from torch.nn.functional import mse_loss sigma = 0.5 -epsilon = 4. 
+epsilon = 4.0 def lennard_jones(r): - return epsilon * ((sigma / r)**12 - (sigma / r)**6) + return epsilon * ((sigma / r) ** 12 - (sigma / r) ** 6) def lennard_jones_force(r): @@ -29,7 +29,9 @@ def lennard_jones_force(r): # Create training energies training_energies = torch.stack(list(map(lennard_jones, norms))).reshape(-1, 1) # Create forces with random direction vectors -training_forces = torch.stack([force * dr for force, dr in zip(map(lennard_jones_force, norms), drs)]) +training_forces = torch.stack( + [force * dr for force, dr in zip(map(lennard_jones_force, norms), drs)] +) model = nn.Sequential( nn.Linear(1, 16), @@ -40,7 +42,7 @@ def lennard_jones_force(r): nn.Tanh(), nn.Linear(16, 16), nn.Tanh(), - nn.Linear(16, 1) + nn.Linear(16, 1), ) @@ -54,7 +56,10 @@ def make_prediction(model, drs): def loss_fn(energies, forces, predicted_energies, predicted_forces): - return mse_loss(energies, predicted_energies) + 0.01 * mse_loss(forces, predicted_forces) / 3 + return ( + mse_loss(energies, predicted_energies) + + 0.01 * mse_loss(forces, predicted_forces) / 3 + ) optimiser = torch.optim.Adam(model.parameters(), lr=1e-3) diff --git a/functorch/examples/maml_omniglot/maml-omniglot-higher.py b/functorch/examples/maml_omniglot/maml-omniglot-higher.py index 17a882dd33702d..fa38fd616ee392 100755 --- a/functorch/examples/maml_omniglot/maml-omniglot-higher.py +++ b/functorch/examples/maml_omniglot/maml-omniglot-higher.py @@ -27,38 +27,43 @@ https://github.com/bamos/HowToTrainYourMAMLPytorch """ -from support.omniglot_loaders import OmniglotNShot -import higher -import torch.optim as optim -import torch.nn.functional as F -from torch import nn -import torch -import matplotlib.pyplot as plt import argparse import time -import pandas as pd -import numpy as np +import higher import matplotlib as mpl -mpl.use('Agg') -plt.style.use('bmh') +import matplotlib.pyplot as plt +import numpy as np + +import pandas as pd +import torch +import torch.nn.functional as F +import torch.optim as optim +from support.omniglot_loaders import OmniglotNShot +from torch import nn + +mpl.use("Agg") +plt.style.use("bmh") def main(): argparser = argparse.ArgumentParser() - argparser.add_argument('--n-way', '--n_way', type=int, help='n way', default=5) - argparser.add_argument( - '--k-spt', '--k_spt', type=int, help='k shot for support set', default=5) + argparser.add_argument("--n-way", "--n_way", type=int, help="n way", default=5) argparser.add_argument( - '--k-qry', '--k_qry', type=int, help='k shot for query set', default=15) + "--k-spt", "--k_spt", type=int, help="k shot for support set", default=5 + ) argparser.add_argument( - '--device', type=str, help='device', default='cuda') + "--k-qry", "--k_qry", type=int, help="k shot for query set", default=15 + ) + argparser.add_argument("--device", type=str, help="device", default="cuda") argparser.add_argument( - '--task-num', '--task_num', + "--task-num", + "--task_num", type=int, - help='meta batch size, namely task num', - default=32) - argparser.add_argument('--seed', type=int, help='random seed', default=1) + help="meta batch size, namely task num", + default=32, + ) + argparser.add_argument("--seed", type=int, help="random seed", default=1) args = argparser.parse_args() torch.manual_seed(args.seed) @@ -69,7 +74,7 @@ def main(): # Set up the Omniglot loader. 
device = args.device db = OmniglotNShot( - '/tmp/omniglot-data', + "/tmp/omniglot-data", batchsz=args.task_num, n_way=args.n_way, k_shot=args.k_spt, @@ -97,7 +102,8 @@ def main(): nn.ReLU(inplace=True), nn.MaxPool2d(2, 2), Flatten(), - nn.Linear(64, args.n_way)).to(device) + nn.Linear(64, args.n_way), + ).to(device) # We will use Adam to (meta-)optimize the initial parameters # to be adapted. @@ -134,9 +140,10 @@ def train(db, net, device, meta_opt, epoch, log): qry_accs = [] meta_opt.zero_grad() for i in range(task_num): - with higher.innerloop_ctx( - net, inner_opt, copy_initial_weights=False - ) as (fnet, diffopt): + with higher.innerloop_ctx(net, inner_opt, copy_initial_weights=False) as ( + fnet, + diffopt, + ): # Optimize the likelihood of the support set by taking # gradient steps w.r.t. the model's parameters. # This adapts the model's meta-parameters to the task. @@ -153,8 +160,7 @@ def train(db, net, device, meta_opt, epoch, log): qry_logits = fnet(x_qry[i]) qry_loss = F.cross_entropy(qry_logits, y_qry[i]) qry_losses.append(qry_loss.detach()) - qry_acc = (qry_logits.argmax( - dim=1) == y_qry[i]).sum().item() / querysz + qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz qry_accs.append(qry_acc) # print([b.shape for b in fnet[1].buffers()]) @@ -166,21 +172,23 @@ def train(db, net, device, meta_opt, epoch, log): meta_opt.step() qry_losses = sum(qry_losses) / task_num - qry_accs = 100. * sum(qry_accs) / task_num + qry_accs = 100.0 * sum(qry_accs) / task_num i = epoch + float(batch_idx) / n_train_iter iter_time = time.time() - start_time if batch_idx % 4 == 0: print( - f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}' + f"[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}" ) - log.append({ - 'epoch': i, - 'loss': qry_losses, - 'acc': qry_accs, - 'mode': 'train', - 'time': time.time(), - }) + log.append( + { + "epoch": i, + "loss": qry_losses, + "acc": qry_accs, + "mode": "train", + "time": time.time(), + } + ) def test(db, net, device, epoch, log): @@ -196,7 +204,7 @@ def test(db, net, device, epoch, log): qry_accs = [] for _ in range(n_test_iter): - x_spt, y_spt, x_qry, y_qry = db.next('test') + x_spt, y_spt, x_qry, y_qry = db.next("test") task_num, setsz, c_, h, w = x_spt.size() @@ -206,7 +214,10 @@ def test(db, net, device, epoch, log): inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1) for i in range(task_num): - with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (fnet, diffopt): + with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as ( + fnet, + diffopt, + ): # Optimize the likelihood of the support set by taking # gradient steps w.r.t. the model's parameters. # This adapts the model's meta-parameters to the task. @@ -217,24 +228,22 @@ def test(db, net, device, epoch, log): # The query loss and acc induced by these parameters. qry_logits = fnet(x_qry[i]).detach() - qry_loss = F.cross_entropy( - qry_logits, y_qry[i], reduction='none') + qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction="none") qry_losses.append(qry_loss.detach()) - qry_accs.append( - (qry_logits.argmax(dim=1) == y_qry[i]).detach()) + qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach()) qry_losses = torch.cat(qry_losses).mean().item() - qry_accs = 100. 
* torch.cat(qry_accs).float().mean().item() - print( - f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}' + qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item() + print(f"[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}") + log.append( + { + "epoch": epoch + 1, + "loss": qry_losses, + "acc": qry_accs, + "mode": "test", + "time": time.time(), + } ) - log.append({ - 'epoch': epoch + 1, - 'loss': qry_losses, - 'acc': qry_accs, - 'mode': 'test', - 'time': time.time(), - }) def plot(log): @@ -243,17 +252,17 @@ def plot(log): df = pd.DataFrame(log) fig, ax = plt.subplots(figsize=(6, 4)) - train_df = df[df['mode'] == 'train'] - test_df = df[df['mode'] == 'test'] - ax.plot(train_df['epoch'], train_df['acc'], label='Train') - ax.plot(test_df['epoch'], test_df['acc'], label='Test') - ax.set_xlabel('Epoch') - ax.set_ylabel('Accuracy') + train_df = df[df["mode"] == "train"] + test_df = df[df["mode"] == "test"] + ax.plot(train_df["epoch"], train_df["acc"], label="Train") + ax.plot(test_df["epoch"], test_df["acc"], label="Test") + ax.set_xlabel("Epoch") + ax.set_ylabel("Accuracy") ax.set_ylim(70, 100) - fig.legend(ncol=2, loc='lower right') + fig.legend(ncol=2, loc="lower right") fig.tight_layout() - fname = 'maml-accs.png' - print(f'--- Plotting accuracy to {fname}') + fname = "maml-accs.png" + print(f"--- Plotting accuracy to {fname}") fig.savefig(fname) plt.close(fig) @@ -265,5 +274,5 @@ def forward(self, input): return input.view(input.size(0), -1) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/functorch/examples/maml_omniglot/maml-omniglot-ptonly.py b/functorch/examples/maml_omniglot/maml-omniglot-ptonly.py index 3040df681ab123..5881f7963ab025 100755 --- a/functorch/examples/maml_omniglot/maml-omniglot-ptonly.py +++ b/functorch/examples/maml_omniglot/maml-omniglot-ptonly.py @@ -27,38 +27,43 @@ https://github.com/bamos/HowToTrainYourMAMLPytorch """ -from support.omniglot_loaders import OmniglotNShot -from functorch import make_functional_with_buffers -import torch.optim as optim -import torch.nn.functional as F -from torch import nn -import torch -import matplotlib.pyplot as plt import argparse import time -import pandas as pd -import numpy as np import matplotlib as mpl -mpl.use('Agg') -plt.style.use('bmh') +import matplotlib.pyplot as plt +import numpy as np + +import pandas as pd +import torch +import torch.nn.functional as F +import torch.optim as optim +from functorch import make_functional_with_buffers +from support.omniglot_loaders import OmniglotNShot +from torch import nn + +mpl.use("Agg") +plt.style.use("bmh") def main(): argparser = argparse.ArgumentParser() - argparser.add_argument('--n-way', '--n_way', type=int, help='n way', default=5) - argparser.add_argument( - '--k-spt', '--k_spt', type=int, help='k shot for support set', default=5) + argparser.add_argument("--n-way", "--n_way", type=int, help="n way", default=5) argparser.add_argument( - '--k-qry', '--k_qry', type=int, help='k shot for query set', default=15) + "--k-spt", "--k_spt", type=int, help="k shot for support set", default=5 + ) argparser.add_argument( - '--device', type=str, help='device', default='cuda') + "--k-qry", "--k_qry", type=int, help="k shot for query set", default=15 + ) + argparser.add_argument("--device", type=str, help="device", default="cuda") argparser.add_argument( - '--task-num', '--task_num', + "--task-num", + "--task_num", type=int, - help='meta batch size, namely task num', - default=32) - 
argparser.add_argument('--seed', type=int, help='random seed', default=1) + help="meta batch size, namely task num", + default=32, + ) + argparser.add_argument("--seed", type=int, help="random seed", default=1) args = argparser.parse_args() torch.manual_seed(args.seed) @@ -69,7 +74,7 @@ def main(): # Set up the Omniglot loader. device = args.device db = OmniglotNShot( - '/tmp/omniglot-data', + "/tmp/omniglot-data", batchsz=args.task_num, n_way=args.n_way, k_shot=args.k_spt, @@ -97,7 +102,8 @@ def main(): nn.ReLU(inplace=True), nn.MaxPool2d(2, 2), Flatten(), - nn.Linear(64, args.n_way)).to(device) + nn.Linear(64, args.n_way), + ).to(device) net.train() fnet, params, buffers = make_functional_with_buffers(net) @@ -153,8 +159,7 @@ def train(db, net, device, meta_opt, epoch, log): qry_logits = fnet(new_params, buffers, x_qry[i]) qry_loss = F.cross_entropy(qry_logits, y_qry[i]) qry_losses.append(qry_loss.detach()) - qry_acc = (qry_logits.argmax( - dim=1) == y_qry[i]).sum().item() / querysz + qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz qry_accs.append(qry_acc) # Update the model's meta-parameters to optimize the query @@ -164,21 +169,23 @@ def train(db, net, device, meta_opt, epoch, log): meta_opt.step() qry_losses = sum(qry_losses) / task_num - qry_accs = 100. * sum(qry_accs) / task_num + qry_accs = 100.0 * sum(qry_accs) / task_num i = epoch + float(batch_idx) / n_train_iter iter_time = time.time() - start_time if batch_idx % 4 == 0: print( - f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}' + f"[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}" ) - log.append({ - 'epoch': i, - 'loss': qry_losses, - 'acc': qry_accs, - 'mode': 'train', - 'time': time.time(), - }) + log.append( + { + "epoch": i, + "loss": qry_losses, + "acc": qry_accs, + "mode": "train", + "time": time.time(), + } + ) def test(db, net, device, epoch, log): @@ -194,7 +201,7 @@ def test(db, net, device, epoch, log): qry_accs = [] for batch_idx in range(n_test_iter): - x_spt, y_spt, x_qry, y_qry = db.next('test') + x_spt, y_spt, x_qry, y_qry = db.next("test") task_num, setsz, c_, h, w = x_spt.size() # TODO: Maybe pull this out into a separate module so it @@ -211,24 +218,22 @@ def test(db, net, device, epoch, log): # The query loss and acc induced by these parameters. qry_logits = fnet(new_params, buffers, x_qry[i]).detach() - qry_loss = F.cross_entropy( - qry_logits, y_qry[i], reduction='none') + qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction="none") qry_losses.append(qry_loss.detach()) - qry_accs.append( - (qry_logits.argmax(dim=1) == y_qry[i]).detach()) + qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach()) qry_losses = torch.cat(qry_losses).mean().item() - qry_accs = 100. 
* torch.cat(qry_accs).float().mean().item() - print( - f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}' + qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item() + print(f"[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}") + log.append( + { + "epoch": epoch + 1, + "loss": qry_losses, + "acc": qry_accs, + "mode": "test", + "time": time.time(), + } ) - log.append({ - 'epoch': epoch + 1, - 'loss': qry_losses, - 'acc': qry_accs, - 'mode': 'test', - 'time': time.time(), - }) def plot(log): @@ -237,17 +242,17 @@ def plot(log): df = pd.DataFrame(log) fig, ax = plt.subplots(figsize=(6, 4)) - train_df = df[df['mode'] == 'train'] - test_df = df[df['mode'] == 'test'] - ax.plot(train_df['epoch'], train_df['acc'], label='Train') - ax.plot(test_df['epoch'], test_df['acc'], label='Test') - ax.set_xlabel('Epoch') - ax.set_ylabel('Accuracy') + train_df = df[df["mode"] == "train"] + test_df = df[df["mode"] == "test"] + ax.plot(train_df["epoch"], train_df["acc"], label="Train") + ax.plot(test_df["epoch"], test_df["acc"], label="Test") + ax.set_xlabel("Epoch") + ax.set_ylabel("Accuracy") ax.set_ylim(70, 100) - fig.legend(ncol=2, loc='lower right') + fig.legend(ncol=2, loc="lower right") fig.tight_layout() - fname = 'maml-accs.png' - print(f'--- Plotting accuracy to {fname}') + fname = "maml-accs.png" + print(f"--- Plotting accuracy to {fname}") fig.savefig(fname) plt.close(fig) @@ -259,5 +264,5 @@ def forward(self, input): return input.view(input.size(0), -1) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/functorch/examples/maml_omniglot/maml-omniglot-transforms.py b/functorch/examples/maml_omniglot/maml-omniglot-transforms.py index 7883d77aaff768..ba55e580478872 100755 --- a/functorch/examples/maml_omniglot/maml-omniglot-transforms.py +++ b/functorch/examples/maml_omniglot/maml-omniglot-transforms.py @@ -27,39 +27,44 @@ https://github.com/bamos/HowToTrainYourMAMLPytorch """ -from support.omniglot_loaders import OmniglotNShot -from torch.func import vmap, grad, functional_call -import torch.optim as optim -import torch.nn.functional as F -from torch import nn -import torch -import matplotlib.pyplot as plt import argparse -import time import functools +import time -import pandas as pd -import numpy as np import matplotlib as mpl -mpl.use('Agg') -plt.style.use('bmh') +import matplotlib.pyplot as plt +import numpy as np + +import pandas as pd +import torch +import torch.nn.functional as F +import torch.optim as optim +from support.omniglot_loaders import OmniglotNShot +from torch import nn +from torch.func import functional_call, grad, vmap + +mpl.use("Agg") +plt.style.use("bmh") def main(): argparser = argparse.ArgumentParser() - argparser.add_argument('--n-way', '--n_way', type=int, help='n way', default=5) - argparser.add_argument( - '--k-spt', '--k_spt', type=int, help='k shot for support set', default=5) + argparser.add_argument("--n-way", "--n_way", type=int, help="n way", default=5) argparser.add_argument( - '--k-qry', '--k_qry', type=int, help='k shot for query set', default=15) + "--k-spt", "--k_spt", type=int, help="k shot for support set", default=5 + ) argparser.add_argument( - '--device', type=str, help='device', default='cuda') + "--k-qry", "--k_qry", type=int, help="k shot for query set", default=15 + ) + argparser.add_argument("--device", type=str, help="device", default="cuda") argparser.add_argument( - '--task-num', '--task_num', + "--task-num", + "--task_num", type=int, - help='meta batch size, namely task num', 
- default=32) - argparser.add_argument('--seed', type=int, help='random seed', default=1) + help="meta batch size, namely task num", + default=32, + ) + argparser.add_argument("--seed", type=int, help="random seed", default=1) args = argparser.parse_args() torch.manual_seed(args.seed) @@ -70,7 +75,7 @@ def main(): # Set up the Omniglot loader. device = args.device db = OmniglotNShot( - '/tmp/omniglot-data', + "/tmp/omniglot-data", batchsz=args.task_num, n_way=args.n_way, k_shot=args.k_spt, @@ -95,7 +100,8 @@ def main(): nn.ReLU(inplace=inplace_relu), nn.MaxPool2d(2, 2), nn.Flatten(), - nn.Linear(64, args.n_way)).to(device) + nn.Linear(64, args.n_way), + ).to(device) net.train() @@ -132,8 +138,7 @@ def compute_loss(new_params, buffers, x, y): # These will be used to update the model's meta-parameters. qry_logits = functional_call(net, (new_params, buffers), x_qry) qry_loss = F.cross_entropy(qry_logits, y_qry) - qry_acc = (qry_logits.argmax( - dim=1) == y_qry).sum() / querysz + qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz return qry_loss, qry_acc @@ -163,21 +168,23 @@ def train(db, net, device, meta_opt, epoch, log): meta_opt.step() qry_losses = qry_losses.detach().sum() / task_num - qry_accs = 100. * qry_accs.sum() / task_num + qry_accs = 100.0 * qry_accs.sum() / task_num i = epoch + float(batch_idx) / n_train_iter iter_time = time.time() - start_time if batch_idx % 4 == 0: print( - f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}' + f"[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}" ) - log.append({ - 'epoch': i, - 'loss': qry_losses, - 'acc': qry_accs, - 'mode': 'train', - 'time': time.time(), - }) + log.append( + { + "epoch": i, + "loss": qry_losses, + "acc": qry_accs, + "mode": "train", + "time": time.time(), + } + ) def test(db, net, device, epoch, log): @@ -194,7 +201,7 @@ def test(db, net, device, epoch, log): qry_accs = [] for batch_idx in range(n_test_iter): - x_spt, y_spt, x_qry, y_qry = db.next('test') + x_spt, y_spt, x_qry, y_qry = db.next("test") task_num, setsz, c_, h, w = x_spt.size() # TODO: Maybe pull this out into a separate module so it @@ -207,28 +214,28 @@ def test(db, net, device, epoch, log): spt_logits = functional_call(net, (new_params, buffers), x_spt[i]) spt_loss = F.cross_entropy(spt_logits, y_spt[i]) grads = torch.autograd.grad(spt_loss, new_params.values()) - new_params = {k: new_params[k] - g * 1e-1 for k, g, in zip(new_params, grads)} + new_params = { + k: new_params[k] - g * 1e-1 for k, g, in zip(new_params, grads) + } # The query loss and acc induced by these parameters. qry_logits = functional_call(net, (new_params, buffers), x_qry[i]).detach() - qry_loss = F.cross_entropy( - qry_logits, y_qry[i], reduction='none') + qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction="none") qry_losses.append(qry_loss.detach()) - qry_accs.append( - (qry_logits.argmax(dim=1) == y_qry[i]).detach()) + qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach()) qry_losses = torch.cat(qry_losses).mean().item() - qry_accs = 100. 
* torch.cat(qry_accs).float().mean().item() - print( - f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}' + qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item() + print(f"[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}") + log.append( + { + "epoch": epoch + 1, + "loss": qry_losses, + "acc": qry_accs, + "mode": "test", + "time": time.time(), + } ) - log.append({ - 'epoch': epoch + 1, - 'loss': qry_losses, - 'acc': qry_accs, - 'mode': 'test', - 'time': time.time(), - }) def plot(log): @@ -237,20 +244,20 @@ def plot(log): df = pd.DataFrame(log) fig, ax = plt.subplots(figsize=(6, 4)) - train_df = df[df['mode'] == 'train'] - test_df = df[df['mode'] == 'test'] - ax.plot(train_df['epoch'], train_df['acc'], label='Train') - ax.plot(test_df['epoch'], test_df['acc'], label='Test') - ax.set_xlabel('Epoch') - ax.set_ylabel('Accuracy') + train_df = df[df["mode"] == "train"] + test_df = df[df["mode"] == "test"] + ax.plot(train_df["epoch"], train_df["acc"], label="Train") + ax.plot(test_df["epoch"], test_df["acc"], label="Test") + ax.set_xlabel("Epoch") + ax.set_ylabel("Accuracy") ax.set_ylim(70, 100) - fig.legend(ncol=2, loc='lower right') + fig.legend(ncol=2, loc="lower right") fig.tight_layout() - fname = 'maml-accs.png' - print(f'--- Plotting accuracy to {fname}') + fname = "maml-accs.png" + print(f"--- Plotting accuracy to {fname}") fig.savefig(fname) plt.close(fig) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/functorch/examples/maml_omniglot/support/omniglot_loaders.py b/functorch/examples/maml_omniglot/support/omniglot_loaders.py index 4aa9e3b96f4710..8d7a95659a8b25 100644 --- a/functorch/examples/maml_omniglot/support/omniglot_loaders.py +++ b/functorch/examples/maml_omniglot/support/omniglot_loaders.py @@ -17,38 +17,38 @@ # https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot.py # https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglotNShot.py -import torchvision.transforms as transforms -from PIL import Image +import errno +import os +import os.path + import numpy as np import torch import torch.utils.data as data -import os -import os.path -import errno +import torchvision.transforms as transforms +from PIL import Image class Omniglot(data.Dataset): urls = [ - 'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip', - 'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip' + "https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip", + "https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip", ] - raw_folder = 'raw' - processed_folder = 'processed' - training_file = 'training.pt' - test_file = 'test.pt' + raw_folder = "raw" + processed_folder = "processed" + training_file = "training.pt" + test_file = "test.pt" - ''' + """ The items are (filename,category). 
The index of all the categories can be found in self.idx_classes Args: - root: the directory where the dataset will be stored - transform: how to transform the input - target_transform: how to transform the target - download: need to download the dataset - ''' + """ - def __init__(self, root, transform=None, target_transform=None, - download=False): + def __init__(self, root, transform=None, target_transform=None, download=False): self.root = root self.transform = transform self.target_transform = target_transform @@ -57,14 +57,16 @@ def __init__(self, root, transform=None, target_transform=None, if download: self.download() else: - raise RuntimeError('Dataset not found.' + ' You can use download=True to download it') + raise RuntimeError( + "Dataset not found." + " You can use download=True to download it" + ) self.all_items = find_classes(os.path.join(self.root, self.processed_folder)) self.idx_classes = index_classes(self.all_items) def __getitem__(self, index): filename = self.all_items[index][0] - img = str.join('/', [self.all_items[index][2], filename]) + img = str.join("/", [self.all_items[index][2], filename]) target = self.idx_classes[self.all_items[index][1]] if self.transform is not None: @@ -78,8 +80,11 @@ def __len__(self): return len(self.all_items) def _check_exists(self): - return os.path.exists(os.path.join(self.root, self.processed_folder, "images_evaluation")) and \ - os.path.exists(os.path.join(self.root, self.processed_folder, "images_background")) + return os.path.exists( + os.path.join(self.root, self.processed_folder, "images_evaluation") + ) and os.path.exists( + os.path.join(self.root, self.processed_folder, "images_background") + ) def download(self): import urllib @@ -99,15 +104,15 @@ def download(self): raise for url in self.urls: - print('== Downloading ' + url) + print("== Downloading " + url) data = urllib.request.urlopen(url) - filename = url.rpartition('/')[2] + filename = url.rpartition("/")[2] file_path = os.path.join(self.root, self.raw_folder, filename) - with open(file_path, 'wb') as f: + with open(file_path, "wb") as f: f.write(data.read()) file_processed = os.path.join(self.root, self.processed_folder) print("== Unzip from " + file_path + " to " + file_processed) - zip_ref = zipfile.ZipFile(file_path, 'r') + zip_ref = zipfile.ZipFile(file_path, "r") zip_ref.extractall(file_processed) zip_ref.close() print("Download finished.") @@ -115,10 +120,10 @@ def download(self): def find_classes(root_dir): retour = [] - for (root, dirs, files) in os.walk(root_dir): + for root, dirs, files in os.walk(root_dir): for f in files: - if (f.endswith("png")): - r = root.split('/') + if f.endswith("png"): + r = root.split("/") lr = len(r) retour.append((f, r[lr - 2] + "/" + r[lr - 1], root)) print(f"== Found {len(retour)} items ") @@ -135,7 +140,6 @@ def index_classes(items): class OmniglotNShot: - def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, device=None): """ Different from mnistNShot, the @@ -149,41 +153,52 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, device=None): self.resize = imgsz self.device = device - if not os.path.isfile(os.path.join(root, 'omniglot.npy')): + if not os.path.isfile(os.path.join(root, "omniglot.npy")): # if root/data.npy does not exist, just download it self.x = Omniglot( - root, download=True, + root, + download=True, transform=transforms.Compose( - [lambda x: Image.open(x).convert('L'), - lambda x: x.resize((imgsz, imgsz)), - lambda x: np.reshape(x, (imgsz, imgsz, 1)), - lambda x: np.transpose(x, [2, 
0, 1]), - lambda x: x / 255.]), + [ + lambda x: Image.open(x).convert("L"), + lambda x: x.resize((imgsz, imgsz)), + lambda x: np.reshape(x, (imgsz, imgsz, 1)), + lambda x: np.transpose(x, [2, 0, 1]), + lambda x: x / 255.0, + ] + ), ) - temp = {} # {label:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label} - for (img, label) in self.x: + temp = ( + {} + ) # {label:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label} + for img, label in self.x: if label in temp.keys(): temp[label].append(img) else: temp[label] = [img] self.x = [] - for label, imgs in temp.items(): # labels info deserted , each label contains 20imgs + for ( + label, + imgs, + ) in temp.items(): # labels info deserted , each label contains 20imgs self.x.append(np.array(imgs)) # as different class may have different number of imgs - self.x = np.array(self.x).astype(np.float) # [[20 imgs],..., 1623 classes in total] + self.x = np.array(self.x).astype( + np.float64 + ) # [[20 imgs],..., 1623 classes in total] # each character contains 20 imgs - print('data shape:', self.x.shape) # [1623, 20, 84, 84, 1] + print("data shape:", self.x.shape) # [1623, 20, 84, 84, 1] temp = [] # Free memory # save all dataset into npy file. - np.save(os.path.join(root, 'omniglot.npy'), self.x) - print('write into omniglot.npy.') + np.save(os.path.join(root, "omniglot.npy"), self.x) + print("write into omniglot.npy.") else: # if data.npy exists, just load it. - self.x = np.load(os.path.join(root, 'omniglot.npy')) - print('load from omniglot.npy.') + self.x = np.load(os.path.join(root, "omniglot.npy")) + print("load from omniglot.npy.") # [1623, 20, 84, 84, 1] # TODO: can not shuffle here, we must keep training and test set distinct! @@ -200,11 +215,18 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, device=None): # save pointer of current read batch in total cache self.indexes = {"train": 0, "test": 0} - self.datasets = {"train": self.x_train, "test": self.x_test} # original data cached + self.datasets = { + "train": self.x_train, + "test": self.x_test, + } # original data cached print("DB: train", self.x_train.shape, "test", self.x_test.shape) - self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]), # current epoch data cached - "test": self.load_data_cache(self.datasets["test"])} + self.datasets_cache = { + "train": self.load_data_cache( + self.datasets["train"] + ), # current epoch data cached + "test": self.load_data_cache(self.datasets["test"]), + } def normalization(self): """ @@ -238,29 +260,32 @@ def load_data_cache(self, data_pack): # print('preload next 50 caches of batchsz of batch.') for sample in range(10): # num of episodes - x_spts, y_spts, x_qrys, y_qrys = [], [], [], [] for i in range(self.batchsz): # one batch means one set - x_spt, y_spt, x_qry, y_qry = [], [], [], [] selected_cls = np.random.choice(data_pack.shape[0], self.n_way, False) for j, cur_class in enumerate(selected_cls): - - selected_img = np.random.choice(20, self.k_shot + self.k_query, False) + selected_img = np.random.choice( + 20, self.k_shot + self.k_query, False + ) # meta-training and meta-test - x_spt.append(data_pack[cur_class][selected_img[:self.k_shot]]) - x_qry.append(data_pack[cur_class][selected_img[self.k_shot:]]) + x_spt.append(data_pack[cur_class][selected_img[: self.k_shot]]) + x_qry.append(data_pack[cur_class][selected_img[self.k_shot :]]) y_spt.append([j for _ in range(self.k_shot)]) y_qry.append([j for _ in range(self.k_query)]) # shuffle inside a batch perm =
np.random.permutation(self.n_way * self.k_shot) - x_spt = np.array(x_spt).reshape(self.n_way * self.k_shot, 1, self.resize, self.resize)[perm] + x_spt = np.array(x_spt).reshape( + self.n_way * self.k_shot, 1, self.resize, self.resize + )[perm] y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm] perm = np.random.permutation(self.n_way * self.k_query) - x_qry = np.array(x_qry).reshape(self.n_way * self.k_query, 1, self.resize, self.resize)[perm] + x_qry = np.array(x_qry).reshape( + self.n_way * self.k_query, 1, self.resize, self.resize + )[perm] y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm] # append [sptsz, 1, 84, 84] => [b, setsz, 1, 84, 84] @@ -270,22 +295,30 @@ def load_data_cache(self, data_pack): y_qrys.append(y_qry) # [b, setsz, 1, 84, 84] - x_spts = np.array(x_spts).astype(np.float32).reshape(self.batchsz, setsz, 1, self.resize, self.resize) + x_spts = ( + np.array(x_spts) + .astype(np.float32) + .reshape(self.batchsz, setsz, 1, self.resize, self.resize) + ) y_spts = np.array(y_spts).astype(int).reshape(self.batchsz, setsz) # [b, qrysz, 1, 84, 84] - x_qrys = np.array(x_qrys).astype(np.float32).reshape(self.batchsz, querysz, 1, self.resize, self.resize) + x_qrys = ( + np.array(x_qrys) + .astype(np.float32) + .reshape(self.batchsz, querysz, 1, self.resize, self.resize) + ) y_qrys = np.array(y_qrys).astype(int).reshape(self.batchsz, querysz) x_spts, y_spts, x_qrys, y_qrys = ( - torch.from_numpy(z).to(self.device) for z in - [x_spts, y_spts, x_qrys, y_qrys] + torch.from_numpy(z).to(self.device) + for z in [x_spts, y_spts, x_qrys, y_qrys] ) data_cache.append([x_spts, y_spts, x_qrys, y_qrys]) return data_cache - def next(self, mode='train'): + def next(self, mode="train"): """ Gets next batch from the dataset with name. :param mode: The name of the splitting (one of "train", "val", "test") diff --git a/functorch/examples/maml_regression/evjang.py b/functorch/examples/maml_regression/evjang.py index fcd7a3b2924066..cd0ee575c4958a 100644 --- a/functorch/examples/maml_regression/evjang.py +++ b/functorch/examples/maml_regression/evjang.py @@ -2,13 +2,15 @@ # (https://github.com/ericjang/maml-jax). # We translated his implementation from JAX to PyTorch. -import matplotlib.pyplot as plt import math -import torch + +import matplotlib as mpl +import matplotlib.pyplot as plt import numpy as np +import torch from torch.nn import functional as F -import matplotlib as mpl -mpl.use('Agg') + +mpl.use("Agg") def net(x, params): @@ -23,13 +25,15 @@ def net(x, params): params = [ - torch.Tensor(40, 1).uniform_(-1., 1.).requires_grad_(), + torch.Tensor(40, 1).uniform_(-1.0, 1.0).requires_grad_(), torch.Tensor(40).zero_().requires_grad_(), - - torch.Tensor(40, 40).uniform_(-1. / math.sqrt(40), 1. / math.sqrt(40)).requires_grad_(), + torch.Tensor(40, 40) + .uniform_(-1.0 / math.sqrt(40), 1.0 / math.sqrt(40)) + .requires_grad_(), torch.Tensor(40).zero_().requires_grad_(), - - torch.Tensor(1, 40).uniform_(-1. / math.sqrt(40), 1. 
/ math.sqrt(40)).requires_grad_(), + torch.Tensor(1, 40) + .uniform_(-1.0 / math.sqrt(40), 1.0 / math.sqrt(40)) + .requires_grad_(), torch.Tensor(1).zero_().requires_grad_(), ] @@ -46,17 +50,18 @@ def sample_tasks(outer_batch_size, inner_batch_size): As = [] phases = [] for _ in range(outer_batch_size): - As.append(np.random.uniform(low=0.1, high=.5)) - phases.append(np.random.uniform(low=0., high=np.pi)) + As.append(np.random.uniform(low=0.1, high=0.5)) + phases.append(np.random.uniform(low=0.0, high=np.pi)) def get_batch(): xs, ys = [], [] for A, phase in zip(As, phases): - x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1)) + x = np.random.uniform(low=-5.0, high=5.0, size=(inner_batch_size, 1)) y = A * np.sin(x + phase) xs.append(x) ys.append(y) return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float) + x1, y1 = get_batch() x2, y2 = get_batch() return x1, y1, x2, y2 @@ -80,14 +85,17 @@ def get_loss_for_task(x1, y1, x2, y2): return F.mse_loss(v_f, y2) task = sample_tasks(num_tasks, K) - inner_losses = [get_loss_for_task(task[0][i], task[1][i], task[2][i], task[3][i]) for i in range(num_tasks)] + inner_losses = [ + get_loss_for_task(task[0][i], task[1][i], task[2][i], task[3][i]) + for i in range(num_tasks) + ] loss2 = sum(inner_losses) / len(inner_losses) loss2.backward() opt.step() if it % 100 == 0: - print('Iteration %d -- Outer Loss: %.4f' % (it, loss2)) + print("Iteration %d -- Outer Loss: %.4f" % (it, loss2)) losses.append(loss2.detach()) t_A = torch.tensor(0.0).uniform_(0.1, 0.5) @@ -112,11 +120,11 @@ def get_loss_for_task(x1, y1, x2, y2): test_f = net(test_x, t_params) -plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)') -plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)') -plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples') +plt.plot(test_x.data.numpy(), test_y.data.numpy(), label="sin(x)") +plt.plot(test_x.data.numpy(), test_f.data.numpy(), label="net(x)") +plt.plot(t_x.data.numpy(), t_y.data.numpy(), "o", label="Examples") plt.legend() -plt.savefig('maml-sine.png') +plt.savefig("maml-sine.png") plt.figure() -plt.plot(np.convolve(losses, [.05] * 20)) -plt.savefig('losses.png') +plt.plot(np.convolve(losses, [0.05] * 20)) +plt.savefig("losses.png") diff --git a/functorch/examples/maml_regression/evjang_transforms.py b/functorch/examples/maml_regression/evjang_transforms.py index 13c2027a450c6e..c70fe5e1cde91c 100644 --- a/functorch/examples/maml_regression/evjang_transforms.py +++ b/functorch/examples/maml_regression/evjang_transforms.py @@ -2,14 +2,16 @@ # (https://github.com/ericjang/maml-jax). # We translated his implementation from JAX to PyTorch. -from torch.func import grad, vmap -import matplotlib.pyplot as plt import math -import torch + +import matplotlib as mpl +import matplotlib.pyplot as plt import numpy as np +import torch +from torch.func import grad, vmap from torch.nn import functional as F -import matplotlib as mpl -mpl.use('Agg') + +mpl.use("Agg") def net(params, x): @@ -24,13 +26,15 @@ def net(params, x): params = [ - torch.Tensor(40, 1).uniform_(-1., 1.).requires_grad_(), + torch.Tensor(40, 1).uniform_(-1.0, 1.0).requires_grad_(), torch.Tensor(40).zero_().requires_grad_(), - - torch.Tensor(40, 40).uniform_(-1. / math.sqrt(40), 1. / math.sqrt(40)).requires_grad_(), + torch.Tensor(40, 40) + .uniform_(-1.0 / math.sqrt(40), 1.0 / math.sqrt(40)) + .requires_grad_(), torch.Tensor(40).zero_().requires_grad_(), - - torch.Tensor(1, 40).uniform_(-1. / math.sqrt(40), 1. 
/ math.sqrt(40)).requires_grad_(), + torch.Tensor(1, 40) + .uniform_(-1.0 / math.sqrt(40), 1.0 / math.sqrt(40)) + .requires_grad_(), torch.Tensor(1).zero_().requires_grad_(), ] @@ -54,17 +58,18 @@ def sample_tasks(outer_batch_size, inner_batch_size): As = [] phases = [] for _ in range(outer_batch_size): - As.append(np.random.uniform(low=0.1, high=.5)) - phases.append(np.random.uniform(low=0., high=np.pi)) + As.append(np.random.uniform(low=0.1, high=0.5)) + phases.append(np.random.uniform(low=0.0, high=np.pi)) def get_batch(): xs, ys = [], [] for A, phase in zip(As, phases): - x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1)) + x = np.random.uniform(low=-5.0, high=5.0, size=(inner_batch_size, 1)) y = A * np.sin(x + phase) xs.append(x) ys.append(y) return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float) + x1, y1 = get_batch() x2, y2 = get_batch() return x1, y1, x2, y2 @@ -94,7 +99,7 @@ def inner_loss(params, x1, y1): opt.step() if it % 100 == 0: - print('Iteration %d -- Outer Loss: %.4f' % (it, loss2)) + print("Iteration %d -- Outer Loss: %.4f" % (it, loss2)) losses.append(loss2.detach()) t_A = torch.tensor(0.0).uniform_(0.1, 0.5) @@ -119,11 +124,11 @@ def inner_loss(params, x1, y1): test_f = net(t_params, test_x) -plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)') -plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)') -plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples') +plt.plot(test_x.data.numpy(), test_y.data.numpy(), label="sin(x)") +plt.plot(test_x.data.numpy(), test_f.data.numpy(), label="net(x)") +plt.plot(t_x.data.numpy(), t_y.data.numpy(), "o", label="Examples") plt.legend() -plt.savefig('maml-sine.png') +plt.savefig("maml-sine.png") plt.figure() -plt.plot(np.convolve(losses, [.05] * 20)) -plt.savefig('losses.png') +plt.plot(np.convolve(losses, [0.05] * 20)) +plt.savefig("losses.png") diff --git a/functorch/examples/maml_regression/evjang_transforms_module.py b/functorch/examples/maml_regression/evjang_transforms_module.py index cc333ba46077e4..3a2b64f77ec500 100644 --- a/functorch/examples/maml_regression/evjang_transforms_module.py +++ b/functorch/examples/maml_regression/evjang_transforms_module.py @@ -2,15 +2,17 @@ # (https://github.com/ericjang/maml-jax). # We translated his implementation from JAX to PyTorch. 
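
# Aside: both regression scripts above implement the same MAML step, an
# inner adaptation written with torch.func.grad, then vmap to evaluate the
# post-adaptation loss for a whole batch of tasks in one call. A minimal,
# self-contained sketch of that pattern (the linear `net`, shapes, and inner
# learning rate below are assumptions, not taken from this diff):
import torch
from torch.func import grad, vmap

params = [torch.randn(1, 1).requires_grad_(), torch.zeros(1).requires_grad_()]

def net(params, x):
    w, b = params
    return x @ w + b

def inner_loss(params, x, y):
    return ((net(params, x) - y) ** 2).mean()

def loss_for_task(x1, y1, x2, y2):
    grads = grad(inner_loss)(params, x1, y1)                # one inner SGD step
    adapted = [p - 0.1 * g for p, g in zip(params, grads)]
    return inner_loss(adapted, x2, y2)                      # query loss after the step

num_tasks, K = 4, 10
x1, y1, x2, y2 = (torch.randn(num_tasks, K, 1) for _ in range(4))
outer_loss = vmap(loss_for_task)(x1, y1, x2, y2).mean()
outer_loss.backward()                                       # meta-gradients land on `params`
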
-from functorch import grad, vmap, make_functional -import matplotlib.pyplot as plt import math -import torch + +import matplotlib as mpl +import matplotlib.pyplot as plt import numpy as np +import torch +from functorch import grad, make_functional, vmap from torch import nn from torch.nn import functional as F -import matplotlib as mpl -mpl.use('Agg') + +mpl.use("Agg") class ThreeLayerNet(nn.Module): @@ -30,6 +32,7 @@ def forward(self, x): x = self.fc3(x) return x + # TODO: Use F.mse_loss @@ -51,17 +54,18 @@ def sample_tasks(outer_batch_size, inner_batch_size): As = [] phases = [] for _ in range(outer_batch_size): - As.append(np.random.uniform(low=0.1, high=.5)) - phases.append(np.random.uniform(low=0., high=np.pi)) + As.append(np.random.uniform(low=0.1, high=0.5)) + phases.append(np.random.uniform(low=0.0, high=np.pi)) def get_batch(): xs, ys = [], [] for A, phase in zip(As, phases): - x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1)) + x = np.random.uniform(low=-5.0, high=5.0, size=(inner_batch_size, 1)) y = A * np.sin(x + phase) xs.append(x) ys.append(y) return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float) + x1, y1 = get_batch() x2, y2 = get_batch() return x1, y1, x2, y2 @@ -91,7 +95,7 @@ def inner_loss(params, x1, y1): opt.step() if it % 100 == 0: - print('Iteration %d -- Outer Loss: %.4f' % (it, loss2)) + print("Iteration %d -- Outer Loss: %.4f" % (it, loss2)) losses.append(loss2.detach()) t_A = torch.tensor(0.0).uniform_(0.1, 0.5) @@ -116,11 +120,11 @@ def inner_loss(params, x1, y1): test_f = net(t_params, test_x) -plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)') -plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)') -plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples') +plt.plot(test_x.data.numpy(), test_y.data.numpy(), label="sin(x)") +plt.plot(test_x.data.numpy(), test_f.data.numpy(), label="net(x)") +plt.plot(t_x.data.numpy(), t_y.data.numpy(), "o", label="Examples") plt.legend() -plt.savefig('maml-sine.png') +plt.savefig("maml-sine.png") plt.figure() -plt.plot(np.convolve(losses, [.05] * 20)) -plt.savefig('losses.png') +plt.plot(np.convolve(losses, [0.05] * 20)) +plt.savefig("losses.png") diff --git a/functorch/experimental/__init__.py b/functorch/experimental/__init__.py index 655534c2a3756f..4b0a8d0fe430df 100644 --- a/functorch/experimental/__init__.py +++ b/functorch/experimental/__init__.py @@ -1,5 +1,6 @@ # PyTorch forward-mode is not mature yet +from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_ from torch._functorch.eager_transforms import hessian, jacfwd, jvp from torch._functorch.vmap import chunk_vmap -from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_ + from functorch import functionalize diff --git a/functorch/experimental/_cond.py b/functorch/experimental/_cond.py index 6b2e5c844c3c07..9e521cf02d1026 100644 --- a/functorch/experimental/_cond.py +++ b/functorch/experimental/_cond.py @@ -1,26 +1,31 @@ from dataclasses import dataclass + import torch -from torch.multiprocessing.reductions import StorageWeakRef import torch.utils._pytree as pytree -from torch._C import DispatchKey, DispatchKeySet, _ExcludeDispatchKeyGuard -from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize +from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet +from torch._dynamo.exc import CondOpArgsMismatchError +from 
torch._functorch.eager_transforms import ( + _unwrap_all_tensors_from_functional, + _wrap_all_tensors_to_functional, + functionalize, +) from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import ( disable_proxy_modes_tracing, - ProxyTorchDispatchMode, make_fx, + ProxyTorchDispatchMode, track_tensor_tree, ) from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.multiprocessing.reductions import StorageWeakRef from torch.utils._python_dispatch import ( _get_current_dispatch_mode, _pop_mode_temporarily, ) from torch.utils._pytree import tree_flatten -from torch._dynamo.exc import CondOpArgsMismatchError @dataclass @@ -34,9 +39,14 @@ class UnsupportedAliasMutationException(RuntimeError): """ cond = HigherOrderOperator("cond") + def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): - assert isinstance(operands, (list, tuple)), "Cond operands must be a list or tuple of tensors" - assert all(isinstance(o, torch.Tensor) for o in operands), "Cond operands must be a list of tensors" + assert isinstance( + operands, (list, tuple) + ), "Cond operands must be a list or tuple of tensors" + assert all( + isinstance(o, torch.Tensor) for o in operands + ), "Cond operands must be a list of tensors" with disable_proxy_modes_tracing(): true_graph = make_fx(true_fn)(*operands) @@ -45,11 +55,11 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): true_outs = [] false_outs = [] for node in true_graph.graph.nodes: - if node.op == 'output': + if node.op == "output": true_outs.extend(node.args) for node in false_graph.graph.nodes: - if node.op == 'output': + if node.op == "output": false_outs.extend(node.args) flat_true_outs, _ = pytree.tree_flatten(true_outs) @@ -64,7 +74,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): for i in range(0, len(flat_true_outs)): true_out = flat_true_outs[i] false_out = flat_false_outs[i] - if true_out.meta['tensor_meta'] != false_out.meta['tensor_meta']: + if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]: raise CondOpArgsMismatchError( f"Expected each tensor to have same metadata but got:" f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}" @@ -85,7 +95,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): true_name = next_name false_name = f"false_graph_{i}" - assert(not hasattr(proxy_mode.tracer.root, false_name)) + assert not hasattr(proxy_mode.tracer.root, false_name) proxy_mode.tracer.root.register_module(true_name, true_graph) proxy_mode.tracer.root.register_module(false_name, false_graph) @@ -94,8 +104,9 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) - out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {}, - name="conditional") + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="conditional" + ) # At this point, we're *guaranteed* that whether an output came from the # true or false branch is indistinguishable. 
So, as this is just for tracing @@ -112,7 +123,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): @cond.py_impl(DispatchKey.CompositeExplicitAutograd) def cond_dense(pred, true_fn, false_fn, operands): mode = _get_current_dispatch_mode() - assert (mode is None), "Mode should never be enabled for CPU/CUDA key" + assert mode is None, "Mode should never be enabled for CPU/CUDA key" if pred: return true_fn(*operands) else: @@ -125,8 +136,7 @@ def cond_autograd(pred, true_fn, false_fn, *operands): flat_operands, _ = tree_flatten([true_fn, false_fn] + [operands]) requires_grad = any( - isinstance(arg, torch.Tensor) and arg.requires_grad - for arg in flat_operands + isinstance(arg, torch.Tensor) and arg.requires_grad for arg in flat_operands ) with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU)): @@ -148,6 +158,7 @@ def fake_requires_grad(var): var = var.detach() var.requires_grad = True return var + return err_fn(fake_requires_grad(result)) return result @@ -156,7 +167,7 @@ def fake_requires_grad(var): @cond.py_impl(ProxyTorchDispatchMode) def inner(pred, true_fn, false_fn, operands): mode = _get_current_dispatch_mode() - assert (mode is not None), "Mode should always be enabled for python fallback key" + assert mode is not None, "Mode should always be enabled for python fallback key" with _pop_mode_temporarily() as mode: if mode.enable_tracing: return trace_cond(mode, cond, pred, true_fn, false_fn, operands) @@ -177,7 +188,8 @@ def cond_fake_tensor_mode(pred, true_fn, false_fn, operands): false_meta = _extract_tensor_metadata(false_out) if true_meta != false_meta: raise RuntimeError( - f"Unmatched tensor metadata from cond() branches.\ntrue branch: {true_meta}, false branch: {false_meta}") + f"Unmatched tensor metadata from cond() branches.\ntrue branch: {true_meta}, false branch: {false_meta}" + ) return true_outs @@ -203,7 +215,10 @@ def _detect_input_mutation(gm): input_nodes.add(node) if node.op == "call_function": target = node.target - if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable: + if ( + isinstance(target, torch._ops.OpOverload) + and target._schema.is_mutable + ): for arg in node.args: if arg in input_nodes: return True @@ -241,13 +256,15 @@ def _detect_input_alias(gm): # for map operator, where num_mapped_args is a scalar # and doesn't have a "val" meta. 
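
# Aside: the two helpers here decide whether a traced branch mutates or
# aliases its inputs; the alias check compares output storages against input
# storages via StorageWeakRef. A small illustration of the storage-identity
# idea it relies on (plain tensors here rather than FX "val" metadata;
# purely a sketch):
import torch
from torch.multiprocessing.reductions import StorageWeakRef

base = torch.randn(4)
view = base[:2]        # shares base's storage
fresh = base.clone()   # gets its own storage

input_storages = {StorageWeakRef(base._typed_storage())}
assert StorageWeakRef(view._typed_storage()) in input_storages
assert StorageWeakRef(fresh._typed_storage()) not in input_storages
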
if node.op == "placeholder" and "val" in node.meta: - input_storages.add(StorageWeakRef(node.meta['val']._typed_storage())) + input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage())) if node.op == "output": + def check_alias(out): if out is not None and "val" in out.meta: - out_storage = StorageWeakRef(out.meta['val']._typed_storage()) + out_storage = StorageWeakRef(out.meta["val"]._typed_storage()) return out_storage in input_storages return False + if any(pytree.tree_flatten(pytree.tree_map(check_alias, node.args))[0]): return True @@ -263,22 +280,30 @@ def check_alias(out): @cond.py_impl(DispatchKey.Functionalize) def cond_func(pred, true_fn, false_fn, inputs): reapply_views = torch._C._functionalization_reapply_views_tls() - unwrapped_inputs = _unwrap_all_tensors_from_functional(inputs, reapply_views=reapply_views) - unwrapped_pred = _unwrap_all_tensors_from_functional(pred, reapply_views=reapply_views) - mode = 'mutations_and_views' if reapply_views else 'mutations' + unwrapped_inputs = _unwrap_all_tensors_from_functional( + inputs, reapply_views=reapply_views + ) + unwrapped_pred = _unwrap_all_tensors_from_functional( + pred, reapply_views=reapply_views + ) + mode = "mutations_and_views" if reapply_views else "mutations" with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)): functional_true = functionalize(true_fn, remove=mode) functional_false = functionalize(false_fn, remove=mode) for branch in [true_fn, false_fn]: if _has_potential_branch_input_mutation(branch, unwrapped_inputs): - raise UnsupportedAliasMutationException("One of torch.cond branch " - "might be modifying the input!") + raise UnsupportedAliasMutationException( + "One of torch.cond branch " "might be modifying the input!" + ) if _has_potential_branch_input_alias(branch, unwrapped_inputs): - raise UnsupportedAliasMutationException("One of torch.cond branch " - "might be aliasing the input!") + raise UnsupportedAliasMutationException( + "One of torch.cond branch " "might be aliasing the input!" + ) - cond_return = cond(unwrapped_pred, functional_true, functional_false, unwrapped_inputs) + cond_return = cond( + unwrapped_pred, functional_true, functional_false, unwrapped_inputs + ) return _wrap_all_tensors_to_functional(cond_return, level=0) @@ -290,10 +315,14 @@ def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs): 2. 
Our check for above condition is not exhaustive """ reapply_views = interpreter.functionalize_add_back_views() - mode = 'mutations_and_views' if reapply_views else 'mutations' + mode = "mutations_and_views" if reapply_views else "mutations" # At this point, we will see functionalized tensors, so need to unwrap them first - unwrapped_inputs = _unwrap_all_tensors_from_functional(inputs, reapply_views=reapply_views) - unwrapped_pred = _unwrap_all_tensors_from_functional(pred, reapply_views=reapply_views) + unwrapped_inputs = _unwrap_all_tensors_from_functional( + inputs, reapply_views=reapply_views + ) + unwrapped_pred = _unwrap_all_tensors_from_functional( + pred, reapply_views=reapply_views + ) functional_true_fn = functionalize(true_fn, remove=mode) functional_false_fn = functionalize(false_fn, remove=mode) @@ -301,16 +330,21 @@ def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs): with interpreter.lower(): for branch in [functional_true_fn, functional_false_fn]: if _has_potential_branch_input_mutation(branch, unwrapped_inputs): - raise UnsupportedAliasMutationException("One of torch.cond branch " - "might be modifying the input!") + raise UnsupportedAliasMutationException( + "One of torch.cond branch " "might be modifying the input!" + ) for branch in [true_fn, false_fn]: if _has_potential_branch_input_alias(branch, unwrapped_inputs): - raise UnsupportedAliasMutationException("One of torch.cond branch " - "might be aliasing the input!") + raise UnsupportedAliasMutationException( + "One of torch.cond branch " "might be aliasing the input!" + ) - cond_return = cond(unwrapped_pred, functional_true_fn, functional_false_fn, unwrapped_inputs) + cond_return = cond( + unwrapped_pred, functional_true_fn, functional_false_fn, unwrapped_inputs + ) return _wrap_all_tensors_to_functional(cond_return, level=interpreter.level()) + # TODO(voz): Make this automatic for keys, this is very ugly atm cond.fallthrough(DispatchKey.PythonDispatcher) cond.fallthrough(DispatchKey.PythonTLSSnapshot) diff --git a/functorch/experimental/_map.py b/functorch/experimental/_map.py index d9add6f3ad15a6..38017a9eb7d2e6 100644 --- a/functorch/experimental/_map.py +++ b/functorch/experimental/_map.py @@ -1,23 +1,32 @@ import torch import torch.utils._pytree as pytree -from torch._C import DispatchKey, DispatchKeySet, _ExcludeDispatchKeyGuard -from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize -from torch._functorch.aot_autograd import create_joint, AOTConfig +from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet +from torch._dispatch.python import suspend_functionalization +from torch._functorch.aot_autograd import AOTConfig, create_joint +from torch._functorch.eager_transforms import ( + _unwrap_all_tensors_from_functional, + _wrap_all_tensors_to_functional, + functionalize, +) from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode -from torch.multiprocessing.reductions import StorageWeakRef from torch.fx.experimental.proxy_tensor import ( disable_proxy_modes_tracing, make_fx, ProxyTorchDispatchMode, track_tensor_tree, ) +from torch.multiprocessing.reductions import StorageWeakRef from torch.utils._python_dispatch import ( _get_current_dispatch_mode, _pop_mode_temporarily, ) -from torch._dispatch.python import suspend_functionalization -from ._cond import _has_potential_branch_input_alias, _has_potential_branch_input_mutation, UnsupportedAliasMutationException + 
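
# Aside: the `map` operator reworked below applies a function over the
# leading dimension of the mapped tensors and stacks the per-slice results,
# while remaining traceable; extra positional arguments are passed unmapped
# to every call. A small usage sketch of the public wrapper (eager behavior;
# shapes are illustrative):
import torch
from functorch.experimental.control_flow import map as control_flow_map

def body(x, w):
    return (x * w).sum(dim=-1)

xs = torch.randn(5, 3, 4)            # mapped over the leading dim of size 5
w = torch.randn(4)                   # shared by every slice, not mapped
out = control_flow_map(body, xs, w)  # ~ torch.stack([body(x, w) for x in xs])
assert out.shape == (5, 3)
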
+from ._cond import ( + _has_potential_branch_input_alias, + _has_potential_branch_input_mutation, + UnsupportedAliasMutationException, +) # TODO: We add this to prevent dymamo from tracing into map_wrapper, @@ -26,16 +35,19 @@ class MapWrapper(HigherOrderOperator): def __call__(self, xs, *args): return map_wrapper(xs, *args) + map = MapWrapper("map", _deprecated_global_ns=True) map_impl = HigherOrderOperator("map_impl", _deprecated_global_ns=True) -dummy_aot_config = AOTConfig(fw_compiler=None, - bw_compiler=None, - partition_fn=None, - decompositions={}, - num_params_buffers=0, - aot_id=0, - keep_inference_input_mutations=False) +dummy_aot_config = AOTConfig( + fw_compiler=None, + bw_compiler=None, + partition_fn=None, + decompositions={}, + num_params_buffers=0, + aot_id=0, + keep_inference_input_mutations=False, +) def create_fw_bw_graph(f, num_mapped_args, *args): @@ -59,20 +71,33 @@ def create_fw_bw_graph(f, num_mapped_args, *args): with suspend_functionalization(): with disable_proxy_modes_tracing(): + def from_fun(t): if isinstance(t, torch.Tensor): - return torch.empty_strided(t.size(), t.stride(), requires_grad=t.requires_grad) + return torch.empty_strided( + t.size(), t.stride(), requires_grad=t.requires_grad + ) return t example_xs = [from_fun(xs) for xs in _unstack_pytree(mapped_xs)[0]] - example_pos_args = [from_fun(arg) if isinstance(arg, torch.Tensor) else arg for arg in pos_args] - example_flat_out = pytree.tree_map(from_fun, f(*example_xs, *example_pos_args)) - if any(not isinstance(out, torch.Tensor) for out in example_flat_out if out is not None): - raise RuntimeError("Expect outputs of map only contains tensors or None. " - f"Got types {[type(out) for out in example_flat_out]}.") + example_pos_args = [ + from_fun(arg) if isinstance(arg, torch.Tensor) else arg + for arg in pos_args + ] + example_flat_out = pytree.tree_map( + from_fun, f(*example_xs, *example_pos_args) + ) + if any( + not isinstance(out, torch.Tensor) + for out in example_flat_out + if out is not None + ): + raise RuntimeError( + "Expect outputs of map only contains tensors or None. " + f"Got types {[type(out) for out in example_flat_out]}." 
+ ) example_grad = [from_fun(out) for out in example_flat_out] - fw_graph = make_fx(f)(*example_xs, *example_pos_args) def joint_f(*example_args): @@ -84,20 +109,39 @@ def joint_f(*example_args): def fw_with_masks(*args): fw_out = f(*args) - return fw_out, [True if isinstance(ret, torch.Tensor) and ret.requires_grad else False for ret in fw_out] + return fw_out, [ + True + if isinstance(ret, torch.Tensor) and ret.requires_grad + else False + for ret in fw_out + ] joint = create_joint(fw_with_masks, aot_config=dummy_aot_config) - _, grads = joint(list(mapped_input) + list(args), - [grad for grad in mapped_grads if grad is not None and grad.requires_grad]) + _, grads = joint( + list(mapped_input) + list(args), + [ + grad + for grad in mapped_grads + if grad is not None and grad.requires_grad + ], + ) # In order to keep map functional for backward graph, # we clone outputs that are aliasing inputs - input_storage = {StorageWeakRef(arg._typed_storage()) for arg in example_args if isinstance(arg, torch.Tensor)} + input_storage = { + StorageWeakRef(arg._typed_storage()) + for arg in example_args + if isinstance(arg, torch.Tensor) + } def maybe_clone(t): - if isinstance(t, torch.Tensor) and StorageWeakRef(t._typed_storage()) in input_storage: + if ( + isinstance(t, torch.Tensor) + and StorageWeakRef(t._typed_storage()) in input_storage + ): return t.clone() return t + return pytree.tree_map(maybe_clone, grads) joint_num_mapped = len(example_grad) + len(example_xs) @@ -114,12 +158,12 @@ def map_wrapper(f, xs, *args): shapes = [xs.shape for xs in flat_xs] leading_dim_size = shapes[0][0] if leading_dim_size == 0: - raise RuntimeError( - "Leading dimensions of mapped xs cannot be 0.") + raise RuntimeError("Leading dimensions of mapped xs cannot be 0.") if any(cur_shape[0] != leading_dim_size for cur_shape in shapes): raise RuntimeError( - f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}.") + f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}." 
+ ) out_spec = None @@ -131,7 +175,11 @@ def flat_fn(*flat_args): nonlocal out_spec out_spec = tmp_out_spec return flat_out - return pytree.tree_unflatten(map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec) + + return pytree.tree_unflatten( + map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec + ) + class MapAutogradOp(torch.autograd.Function): @staticmethod @@ -140,17 +188,24 @@ def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args): ctx._joint_graph = joint_graph ctx._num_mapped_args = num_mapped_args with torch._C._AutoDispatchBelowAutograd(): - return (*map_impl(fw_graph, num_mapped_args, *flat_args), ) + return (*map_impl(fw_graph, num_mapped_args, *flat_args),) @staticmethod def backward(ctx, *flat_grads): fw_args = ctx.saved_tensors - fw_mapped_args = fw_args[:ctx._num_mapped_args] - pos_args = fw_args[ctx._num_mapped_args:] - - grads = map_impl(ctx._joint_graph, ctx._num_mapped_args + len(flat_grads), *fw_mapped_args, *flat_grads, *pos_args) + fw_mapped_args = fw_args[: ctx._num_mapped_args] + pos_args = fw_args[ctx._num_mapped_args :] + + grads = map_impl( + ctx._joint_graph, + ctx._num_mapped_args + len(flat_grads), + *fw_mapped_args, + *flat_grads, + *pos_args, + ) return None, None, None, *grads + def trace_map(proxy_mode, func_overload, f, num_mapped, *args): xs = list(args[:num_mapped]) pos_args = list(args[num_mapped:]) @@ -168,6 +223,7 @@ def expand_tensor(t): if isinstance(t, torch.Tensor): return t.expand(leading_dim_size, *t.shape) return t + expanded_outs = pytree.tree_map(expand_tensor, example_outs) next_name = None @@ -182,9 +238,13 @@ def expand_tensor(t): proxy_mode.tracer.root.register_module(next_name, body_graph) node_args = (body_graph, num_mapped, *args) proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) - out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {}, - name="map_impl") - return track_tensor_tree(expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="map_impl" + ) + return track_tensor_tree( + expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer + ) + def _unstack_pytree(xs): flat_xs, inspec = pytree.tree_flatten(xs) @@ -192,7 +252,9 @@ def _unstack_pytree(xs): raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}") if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs): - raise RuntimeError(f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}") + raise RuntimeError( + f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}" + ) a = zip(*flat_xs) pytrees = [] @@ -200,6 +262,7 @@ def _unstack_pytree(xs): pytrees.append(pytree.tree_unflatten(tuple, inspec)) return pytrees + def _stack_pytree(pytrees): flat_out = [] out_spec = None @@ -220,6 +283,7 @@ def _stack_pytree(pytrees): raise RuntimeError(f"Cannot stack {leaves}.") return pytree.tree_unflatten(stacked_out, out_spec) + @map_impl.py_impl(DispatchKey.CompositeExplicitAutograd) def map_dense(f, num_mapped_args, *args): xs = args[:num_mapped_args] @@ -240,7 +304,7 @@ def map_autograd(f, num_mapped_args, *args): @map_impl.py_impl(ProxyTorchDispatchMode) def map_proxy_torch_dispatch_mode(f, num_mapped, *args): mode = _get_current_dispatch_mode() - assert (mode is not None), "Mode should always be enabled for python fallback key" + assert mode is not None, "Mode should always be enabled for python fallback key" with 
_pop_mode_temporarily() as mode: if mode.enable_tracing: return trace_map(mode, map_impl, f, num_mapped, *args) @@ -259,8 +323,10 @@ def map_func(f, num_mapped, *args): xs = args[:num_mapped] pos_args = args[num_mapped:] unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views) - unwrapped_args = _unwrap_all_tensors_from_functional(pos_args, reapply_views=reapply_views) - mode = 'mutations_and_views' if reapply_views else 'mutations' + unwrapped_args = _unwrap_all_tensors_from_functional( + pos_args, reapply_views=reapply_views + ) + mode = "mutations_and_views" if reapply_views else "mutations" with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)): functional_map_fn = functionalize(f, remove=mode) @@ -268,18 +334,17 @@ def map_func(f, num_mapped, *args): example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args) if _has_potential_branch_input_mutation(f, example_inputs): - raise UnsupportedAliasMutationException( - "torch.map is mutating the input!" - ) + raise UnsupportedAliasMutationException("torch.map is mutating the input!") if _has_potential_branch_input_alias(f, example_inputs): - raise UnsupportedAliasMutationException( - "torch.map is aliasing the input!" - ) + raise UnsupportedAliasMutationException("torch.map is aliasing the input!") - map_return = map_impl(functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args) + map_return = map_impl( + functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args + ) return _wrap_all_tensors_to_functional(map_return, level=0) + @map_impl.py_impl(torch._C._functorch.TransformType.Functionalize) def map_functionalize(interpreter, f, num_mapped, *args): """ @@ -290,10 +355,12 @@ def map_functionalize(interpreter, f, num_mapped, *args): xs = args[:num_mapped] pos_args = args[num_mapped:] reapply_views = interpreter.functionalize_add_back_views() - mode = 'mutations_and_views' if reapply_views else 'mutations' + mode = "mutations_and_views" if reapply_views else "mutations" # At this point, we will see functionalized tensors, so need to unwrap them first unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views) - unwrapped_args = _unwrap_all_tensors_from_functional(pos_args, reapply_views=reapply_views) + unwrapped_args = _unwrap_all_tensors_from_functional( + pos_args, reapply_views=reapply_views + ) functional_map_fn = functionalize(f, remove=mode) @@ -301,18 +368,17 @@ def map_functionalize(interpreter, f, num_mapped, *args): with disable_proxy_modes_tracing(): example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args) if _has_potential_branch_input_mutation(f, example_inputs): - raise UnsupportedAliasMutationException( - "torch.map is mutating the input!" - ) + raise UnsupportedAliasMutationException("torch.map is mutating the input!") if _has_potential_branch_input_alias(f, example_inputs): - raise UnsupportedAliasMutationException( - "torch.map is aliasing the input!" 
- ) + raise UnsupportedAliasMutationException("torch.map is aliasing the input!") - map_return = map_impl(functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args) + map_return = map_impl( + functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args + ) return _wrap_all_tensors_to_functional(map_return, level=interpreter.level()) + # TODO(voz) Make this automatic for keys, this is very ugly atm map_impl.fallthrough(DispatchKey.PythonDispatcher) map_impl.fallthrough(DispatchKey.PythonTLSSnapshot) diff --git a/functorch/experimental/control_flow.py b/functorch/experimental/control_flow.py index 5d42598c757aa0..095ada0c5152d7 100644 --- a/functorch/experimental/control_flow.py +++ b/functorch/experimental/control_flow.py @@ -1,2 +1,2 @@ -from ._map import map # noqa: F401 from ._cond import cond, UnsupportedAliasMutationException # noqa: F401 +from ._map import map # noqa: F401 diff --git a/functorch/notebooks/_src/plot_ensembling.py b/functorch/notebooks/_src/plot_ensembling.py index 7bce421ddfd6d3..f3627e425031b5 100644 --- a/functorch/notebooks/_src/plot_ensembling.py +++ b/functorch/notebooks/_src/plot_ensembling.py @@ -19,8 +19,10 @@ import torch import torch.nn as nn import torch.nn.functional as F + torch.manual_seed(0) + # Here's a simple CNN class SimpleCNN(nn.Module): def __init__(self): @@ -44,11 +46,12 @@ def forward(self, x): output = x return output + # Let's generate some dummy data. Pretend that we're working with an MNIST dataset # where the images are 28 by 28. # Furthermore, let's say we wish to combine the predictions from 10 different # models. -device = 'cuda' +device = "cuda" num_models = 10 data = torch.randn(100, 64, 1, 28, 28, device=device) targets = torch.randint(10, (6400,), device=device) @@ -81,6 +84,7 @@ def forward(self, x): # functorch offers the following convenience function to do that. It returns a # stateless version of the model (fmodel) and stacked parameters and buffers. from functorch import combine_state_for_ensemble + fmodel, params, buffers = combine_state_for_ensemble(models) [p.requires_grad_() for p in params] @@ -92,15 +96,20 @@ def forward(self, x): print([p.size(0) for p in params]) assert minibatches.shape == (num_models, 64, 1, 28, 28) from functorch import vmap + predictions1_vmap = vmap(fmodel)(params, buffers, minibatches) -assert torch.allclose(predictions1_vmap, torch.stack(predictions1), atol=1e-6, rtol=1e-6) +assert torch.allclose( + predictions1_vmap, torch.stack(predictions1), atol=1e-6, rtol=1e-6 +) # Option 2: get predictions using the same minibatch of data # vmap has an in_dims arg that specify which dimensions to map over. # Using ``None``, we tell vmap we want the same minibatch to apply for all of # the 10 models. predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch) -assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-6, rtol=1e-6) +assert torch.allclose( + predictions2_vmap, torch.stack(predictions2), atol=1e-6, rtol=1e-6 +) # A quick note: there are limitations around what types of functions can be # transformed by vmap. 
The best functions to transform are ones that are diff --git a/functorch/notebooks/_src/plot_jacobians_and_hessians.py b/functorch/notebooks/_src/plot_jacobians_and_hessians.py index 99db81556830d3..ca6e160bad25b3 100644 --- a/functorch/notebooks/_src/plot_jacobians_and_hessians.py +++ b/functorch/notebooks/_src/plot_jacobians_and_hessians.py @@ -8,11 +8,14 @@ efficiently using a standard autodiff system like PyTorch Autograd; functorch provides ways of computing various higher-order autodiff quantities efficiently. """ +from functools import partial + import torch import torch.nn.functional as F -from functools import partial + torch.manual_seed(0) + ###################################################################### # Setup: Comparing functorch vs the naive approach # -------------------------------------------------------------------- @@ -21,6 +24,7 @@ def predict(weight, bias, x): return F.linear(x, weight, bias).tanh() + # Here's some dummy data: a weight, a bias, and a feature vector. D = 16 weight = torch.randn(D, D) @@ -34,19 +38,24 @@ def predict(weight, bias, x): xp = x.clone().requires_grad_() unit_vectors = torch.eye(D) + def compute_jac(xp): - jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0] - for vec in unit_vectors] + jacobian_rows = [ + torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0] + for vec in unit_vectors + ] return torch.stack(jacobian_rows) + jacobian = compute_jac(xp) # Instead of computing the jacobian row-by-row, we can use ``vmap`` to get rid # of the for-loop and vectorize the computation. We can't directly apply vmap # to PyTorch Autograd; instead, functorch provides a ``vjp`` transform: -from functorch import vmap, vjp +from functorch import vjp, vmap + _, vjp_fn = vjp(partial(predict, weight, bias), x) -ft_jacobian, = vmap(vjp_fn)(unit_vectors) +(ft_jacobian,) = vmap(vjp_fn)(unit_vectors) assert torch.allclose(ft_jacobian, jacobian) # In another tutorial a composition of reverse-mode AD and vmap gave us @@ -59,6 +68,7 @@ def compute_jac(xp): # argument that says which argument we would like to compute Jacobians with # respect to. from functorch import jacrev + ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x) assert torch.allclose(ft_jacobian, jacobian) @@ -67,6 +77,7 @@ def compute_jac(xp): # there are). In general, we expect that vectorization via ``vmap`` can help # eliminate overhead and give better utilization of your hardware. from torch.utils.benchmark import Timer + without_vmap = Timer(stmt="compute_jac(xp)", globals=globals()) with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals()) print(without_vmap.timeit(500)) @@ -95,7 +106,7 @@ def compute_jac(xp): # In reverse-mode AD, we are computing the jacobian row-by-row, while in # forward-mode AD (which computes Jacobian-vector products), we are computing # it column-by-column. The Jacobian matrix has M rows and N columns. 
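
# Aside, making the comment above concrete: jacrev builds the Jacobian one
# row per output (M vjp passes) while jacfwd builds it one column per input
# (N jvp passes), so forward-mode tends to win for "tall" Jacobians (M >> N)
# and reverse-mode for "wide" ones (N >> M); both agree numerically. Sizes
# below are illustrative:
import torch
from functorch import jacfwd, jacrev

W = torch.randn(3, 7)

def g(x):
    return torch.tanh(x @ W)  # R^3 -> R^7, i.e. N=3 inputs, M=7 outputs

x0 = torch.randn(3)
J_fwd = jacfwd(g)(x0)
J_rev = jacrev(g)(x0)
assert J_fwd.shape == J_rev.shape == (7, 3)  # M rows, N columns
assert torch.allclose(J_fwd, J_rev, atol=1e-6)
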
-from functorch import jacrev, jacfwd +from functorch import jacfwd, jacrev # Benchmark with more inputs than outputs Din = 32 @@ -106,8 +117,8 @@ def compute_jac(xp): using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals()) using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals()) -print(f'jacfwd time: {using_fwd.timeit(500)}') -print(f'jacrev time: {using_bwd.timeit(500)}') +print(f"jacfwd time: {using_fwd.timeit(500)}") +print(f"jacrev time: {using_bwd.timeit(500)}") # Benchmark with more outputs than inputs Din = 2048 @@ -118,8 +129,8 @@ def compute_jac(xp): using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals()) using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals()) -print(f'jacfwd time: {using_fwd.timeit(500)}') -print(f'jacrev time: {using_bwd.timeit(500)}') +print(f"jacfwd time: {using_fwd.timeit(500)}") +print(f"jacrev time: {using_bwd.timeit(500)}") ###################################################################### # Hessian computation with functorch.hessian @@ -132,6 +143,7 @@ def compute_jac(xp): # Depending on your model, you may want to use ``jacfwd(jacfwd(f))`` or # ``jacrev(jacrev(f))`` instead to compute hessians. from functorch import hessian + # # TODO: make sure PyTorch has tanh_backward implemented for jvp!! # hess0 = hessian(predict, argnums=2)(weight, bias, x) # hess1 = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x) @@ -148,9 +160,11 @@ def compute_jac(xp): # The easiest way to do this is to sum over the batch dimension and then # compute the Jacobian of that function: + def predict_with_output_summed(weight, bias, x): return predict(weight, bias, x).sum(0) + batch_size = 64 Din = 31 Dout = 33 diff --git a/functorch/notebooks/_src/plot_per_sample_gradients.py b/functorch/notebooks/_src/plot_per_sample_gradients.py index 668e089f821c4b..b0d10bcf484c6b 100644 --- a/functorch/notebooks/_src/plot_per_sample_gradients.py +++ b/functorch/notebooks/_src/plot_per_sample_gradients.py @@ -12,8 +12,10 @@ import torch import torch.nn as nn import torch.nn.functional as F + torch.manual_seed(0) + # Here's a simple CNN class SimpleCNN(nn.Module): def __init__(self): @@ -37,12 +39,14 @@ def forward(self, x): output = x return output + def loss_fn(predictions, targets): return F.nll_loss(predictions, targets) + # Let's generate a batch of dummy data. Pretend that we're working with an # MNIST dataset where the images are 28 by 28 and we have a minibatch of size 64. -device = 'cuda' +device = "cuda" num_models = 10 batch_size = 64 data = torch.randn(batch_size, 1, 28, 28, device=device) @@ -56,6 +60,7 @@ def loss_fn(predictions, targets): loss = loss_fn(predictions, targets) loss.backward() + # Conceptually, per-sample-gradient computation is equivalent to: for each sample # of the data, perform a forward and a backward pass to get a gradient. 
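
# Aside: the naive loop is spelled out next (compute_grad followed by
# compute_sample_grads); the transform-based version this file builds toward
# composes grad with vmap so that Python loop disappears. A minimal,
# self-contained sketch of the composition (toy linear model; the names
# below are illustrative, not this file's):
import torch
import torch.nn.functional as F
from functorch import grad, make_functional, vmap

toy_model = torch.nn.Linear(4, 3)
toy_fmodel, toy_params = make_functional(toy_model)
toy_data, toy_targets = torch.randn(8, 4), torch.randint(3, (8,))

def toy_loss(params, sample, target):
    logits = toy_fmodel(params, sample.unsqueeze(0))  # one sample as a batch of 1
    return F.cross_entropy(logits, target.unsqueeze(0))

# in_dims=(None, 0, 0): share params, map over the sample dimension
toy_grads = vmap(grad(toy_loss), in_dims=(None, 0, 0))(toy_params, toy_data, toy_targets)
assert toy_grads[0].shape == (8, 3, 4)  # one weight gradient per sample
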
 def compute_grad(sample, target):
@@ -65,12 +70,14 @@ def compute_grad(sample, target):
     loss = loss_fn(prediction, target)
     return torch.autograd.grad(loss, list(model.parameters()))
 
+
 def compute_sample_grads(data, targets):
     sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]
     sample_grads = zip(*sample_grads)
     sample_grads = [torch.stack(shards) for shards in sample_grads]
     return sample_grads
 
+
 per_sample_grads = compute_sample_grads(data, targets)
 
 # sample_grads[0] is the per-sample-grad for model.conv1.weight
@@ -85,9 +92,11 @@ def compute_sample_grads(data, targets):
 # We can compute per-sample-gradients efficiently by using function transforms.
 # First, let's create a stateless functional version of ``model`` by using
 # ``functorch.make_functional_with_buffers``.
-from functorch import make_functional_with_buffers, vmap, grad
+from functorch import grad, make_functional_with_buffers, vmap
+
 fmodel, params, buffers = make_functional_with_buffers(model)
 
+
 # Next, let's define a function to compute the loss of the model given a single
 # input rather than a batch of inputs. It is important that this function accepts the
 # parameters, the input, and the target, because we will be transforming over them.
@@ -100,6 +109,7 @@ def compute_loss(params, buffers, sample, target):
     loss = loss_fn(predictions, targets)
     return loss
 
+
 # Now, let's use ``grad`` to create a new function that computes the gradient
 # with respect to the first argument of compute_loss (i.e. the params).
 ft_compute_grad = grad(compute_loss)
diff --git a/functorch/op_analysis/gen_data.py b/functorch/op_analysis/gen_data.py
index ab1f3a79125c20..a364a05f86a426 100644
--- a/functorch/op_analysis/gen_data.py
+++ b/functorch/op_analysis/gen_data.py
@@ -1,8 +1,9 @@
-import yaml
 import csv
-import torch
 from collections import defaultdict
 
+import torch
+import yaml
+
 
 def get_ops_for_key(key):
     # Needs modified PyTorch C++ code to work
@@ -12,7 +13,7 @@ def get_ops_for_key(key):
     ops = torch._C._dispatch_get_registrations_for_dispatch_key(key)
     cleaned_ops = []
     for i in ops:
-        if 'aten::' not in i:
+        if "aten::" not in i:
             continue
         cleaned_ops.append(i[6:].strip())
     return set(cleaned_ops)
@@ -20,12 +21,17 @@
 
 def gen_data(special_op_lists, analysis_name):
     all_ops = get_ops_for_key(None)
-    composite_ops = get_ops_for_key('CompositeImplicitAutograd')
+    composite_ops = get_ops_for_key("CompositeImplicitAutograd")
     noncomposite_ops = all_ops - composite_ops
 
-    ops = yaml.load(open('../../aten/src/ATen/native/native_functions.yaml').read(), Loader=yaml.CLoader)
+    ops = yaml.load(
+        open("../../aten/src/ATen/native/native_functions.yaml").read(),
+        Loader=yaml.CLoader,
+    )
 
-    annotated_ops = {a.strip(): b.strip() for a, b in list(csv.reader(open('annotated_ops')))}
+    annotated_ops = {
+        a.strip(): b.strip() for a, b in list(csv.reader(open("annotated_ops")))
+    }
     from collections import defaultdict
 
     uniq_ops = []
@@ -33,18 +39,18 @@
     overload_types = defaultdict(list)
     cnt = 0
     for op in ops:
-        func_str = op['func']
-        name = func_str[:func_str.index('(')]
-        if '.' in name:
-            uniq_name = name[:name.index('.')]
-            overload_types[name[name.index('.') + 1:]].append(name)
+        func_str = op["func"]
+        name = func_str[: func_str.index("(")]
+        if "." in name:
+            uniq_name = name[: name.index(".")]
+            overload_types[name[name.index(".") + 1 :]].append(name)
         else:
             uniq_name = name
-        op['name'] = uniq_name
-        full_name = func_str[:func_str.index('(')]
-        op['full_name'] = full_name
-        ret_type = func_str[func_str.index('->') + 3:]
-        op['ret_type'] = ret_type
+        op["name"] = uniq_name
+        full_name = func_str[: func_str.index("(")]
+        op["full_name"] = full_name
+        ret_type = func_str[func_str.index("->") + 3 :]
+        op["ret_type"] = ret_type
         cnt += 1
         if uniq_name in uniq_names:
             continue
@@ -54,104 +60,123 @@
     def annotate_ops(ops, is_unique):
         categorization = defaultdict(int)
         for op in ops:
-            if op['name'][-1] == '_':
-                categorization['inplace'] += 1
-                op['meta'] = 'inplace'
+            if op["name"][-1] == "_":
+                categorization["inplace"] += 1
+                op["meta"] = "inplace"
                 continue
-            if not is_unique and 'a!' in op['func'].lower():
-                categorization['out'] += 1
-                op['meta'] = 'out'
+            if not is_unique and "a!" in op["func"].lower():
+                categorization["out"] += 1
+                op["meta"] = "out"
                 continue
-            if 'conv' in op['name']:
-                categorization['conv'] += 1
-                op['meta'] = 'conv'
+            if "conv" in op["name"]:
+                categorization["conv"] += 1
+                op["meta"] = "conv"
                 continue
-            if 'pool' in op['name']:
-                categorization['pool'] += 1
-                op['meta'] = 'pool'
+            if "pool" in op["name"]:
+                categorization["pool"] += 1
+                op["meta"] = "pool"
                 continue
-            if 'backward' in op['name']:
-                categorization['backward'] += 1
-                op['meta'] = 'backward'
+            if "backward" in op["name"]:
+                categorization["backward"] += 1
+                op["meta"] = "backward"
                 continue
-            if op['name'][0] == '_' and op['name'][1] != '_':
-                categorization['private'] += 1
-                op['meta'] = 'private'
+            if op["name"][0] == "_" and op["name"][1] != "_":
+                categorization["private"] += 1
+                op["meta"] = "private"
                 continue
-            if 'batch_norm' in op['name']:
-                categorization['batch_norm'] += 1
-                op['meta'] = 'batch_norm'
+            if "batch_norm" in op["name"]:
+                categorization["batch_norm"] += 1
+                op["meta"] = "batch_norm"
                 continue
-            if 'Tensor' not in op['func'] or 'Tensor' not in op['ret_type']:
-                categorization['non_tensor'] += 1
-                op['meta'] = 'non_tensor'
+            if "Tensor" not in op["func"] or "Tensor" not in op["ret_type"]:
+                categorization["non_tensor"] += 1
+                op["meta"] = "non_tensor"
                 continue
-            if 'cudnn' in op['name'] or 'mkldnn' in op['name'] or 'miopen' in op['name'] or \
-                    'native' in op['name'] or 'thnn' in op['name'] or 'slow' in op['name']:
-                categorization['backend'] += 1
-                op['meta'] = 'backend'
+            if (
+                "cudnn" in op["name"]
+                or "mkldnn" in op["name"]
+                or "miopen" in op["name"]
+                or "native" in op["name"]
+                or "thnn" in op["name"]
+                or "slow" in op["name"]
+            ):
+                categorization["backend"] += 1
+                op["meta"] = "backend"
                 continue
-            if op['name'] in annotated_ops:
-                categorization['core'] += 1
-                op['meta'] = 'core ' + annotated_ops[op['name']]
+            if op["name"] in annotated_ops:
+                categorization["core"] += 1
+                op["meta"] = "core " + annotated_ops[op["name"]]
                 continue
-            categorization['core'] += 1
-            op['meta'] = 'core unknown'
+            categorization["core"] += 1
+            op["meta"] = "core unknown"
         return categorization
 
     annotate_ops(ops, is_unique=False)
-    with open(f"{analysis_name}", 'w') as f:
+    with open(f"{analysis_name}", "w") as f:
        for op in ops:
            info = [
-                op['full_name'], op['meta'], op['full_name'] not in noncomposite_ops
+                op["full_name"],
+                op["meta"],
+                op["full_name"] not in noncomposite_ops,
            ] + [check(op) for check in special_op_lists]
-            f.write(','.join([str(i) for i in info]) + '\n')
+            f.write(",".join([str(i) for i in info]) + "\n")
 
 
 def name_check(lst):
-    return lambda x: x['name'] in lst
+    return lambda x: x["name"] in lst
 
 
 def full_name_check(lst):
-    return lambda x: x['full_name'] in lst
+    return lambda x: x["full_name"] in lst
 
 
 # Generates batching rule data
-gen_data([full_name_check(get_ops_for_key('FuncTorchBatched'))], 'vmap.txt')
+gen_data([full_name_check(get_ops_for_key("FuncTorchBatched"))], "vmap.txt")
 
 
 def remove_suffix(input_string, suffix):
     if suffix and input_string.endswith(suffix):
-        return input_string[:-len(suffix)]
+        return input_string[: -len(suffix)]
     return input_string
 
+
 def remove_prefix(input_string, prefix):
     if prefix and input_string.startswith(prefix):
-        return input_string[len(prefix):]
+        return input_string[len(prefix) :]
     return input_string
 
 
 if True:
-    with open('run_ops.txt') as f:
-        opinfo_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
-    with open('count_ops.txt') as f:
+    with open("run_ops.txt") as f:
+        opinfo_ops = [remove_suffix(i.strip(), ".default") for i in f.readlines()]
+    with open("count_ops.txt") as f:
         opinfo_counts = [i.strip() for i in f.readlines()]
     opinfo_counts = defaultdict(int, dict(zip(opinfo_ops, opinfo_counts)))
 
     def count_fn(x):
-        return opinfo_counts[x['full_name']]
+        return opinfo_counts[x["full_name"]]
 
-    with open('run_decompositions.txt') as f:
-        decomposed_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
+    with open("run_decompositions.txt") as f:
+        decomposed_ops = [remove_suffix(i.strip(), ".default") for i in f.readlines()]
 
-    with open('public_api') as f:
+    with open("public_api") as f:
         ref_api = [i.strip() for i in f.readlines()]
 
     def has_ref_impl(x):
-        name = x['name']
+        name = x["name"]
         for prefix in ["linalg_", "special_"]:
             name = remove_prefix(name, prefix)
-        prefixes = ['nn.functional', 'fft', 'special', 'linalg']
-        return any(f"{prefix}.{name}" in ref_api for prefix in prefixes) or name in ref_api
-
-    gen_data([full_name_check(opinfo_ops), full_name_check(decomposed_ops), count_fn, has_ref_impl], 'decompositions.txt')
+        prefixes = ["nn.functional", "fft", "special", "linalg"]
+        return (
+            any(f"{prefix}.{name}" in ref_api for prefix in prefixes) or name in ref_api
+        )
+
+    gen_data(
+        [
+            full_name_check(opinfo_ops),
+            full_name_check(decomposed_ops),
+            count_fn,
+            has_ref_impl,
+        ],
+        "decompositions.txt",
+    )
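
Editor's note, not part of the patch above: the jacfwd/jacrev benchmarks in the plot_jacobians_and_hessians.py hunks rest on one rule of thumb -- reverse mode (jacrev) costs roughly one backward pass per output, while forward mode (jacfwd) costs roughly one forward pass per input, so jacfwd wins when a function has many more outputs than inputs and jacrev wins in the opposite regime. A minimal sketch of that tradeoff follows; the 32-input/2048-output shape is borrowed from the benchmark, but the function f, the matrix W, and the tolerance are illustrative choices made here, not code from the patch.

import torch
from functorch import jacfwd, jacrev

W = torch.randn(2048, 32)  # many outputs (2048), few inputs (32)


def f(x):
    return torch.tanh(W @ x)


x = torch.randn(32)
J_fwd = jacfwd(f)(x)  # ~32 forward-mode passes: cheap for this shape
J_rev = jacrev(f)(x)  # ~2048 reverse-mode passes: expensive for this shape
assert J_fwd.shape == (2048, 32)
assert torch.allclose(J_fwd, J_rev, atol=1e-5)  # both compute the same Jacobian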
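
Editor's note, not part of the patch above: the plot_per_sample_gradients.py diff ends at ft_compute_grad = grad(compute_loss); the recipe its comments describe is completed by vmap-ing that gradient function over the batch. Below is a self-contained sketch of the same grad-then-vmap composition; the toy linear model, dimensions, and loss are assumptions invented for this example rather than the tutorial's CNN.

import torch
import torch.nn.functional as F
from functorch import grad, make_functional, vmap

model = torch.nn.Linear(4, 3)
fmodel, params = make_functional(model)  # stateless version of the model


def compute_loss(params, sample, target):
    # The functional model expects a batch, but vmap feeds one sample at a
    # time, so add a leading batch dimension of size 1.
    prediction = fmodel(params, sample.unsqueeze(0))
    return F.cross_entropy(prediction, target.unsqueeze(0))


data = torch.randn(8, 4)
targets = torch.randint(3, (8,))

# grad differentiates w.r.t. the first argument (params); vmap maps over the
# leading dim of data/targets while broadcasting params (in_dims=None).
ft_compute_sample_grad = vmap(grad(compute_loss), in_dims=(None, 0, 0))
per_sample_grads = ft_compute_sample_grad(params, data, targets)
# One gradient per sample: the weight grad gains a leading batch dim of 8.
assert per_sample_grads[0].shape == (8, 3, 4)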