From 02c4d877b458f6f6273be5dfe5ae088e286aae35 Mon Sep 17 00:00:00 2001 From: Antonio Kim Date: Tue, 24 May 2022 19:29:23 +0000 Subject: [PATCH] Codegen Non-Native IR Nodes (#76535) Add codegen infrastructure to generate IR nodes for non-native ops. The proposed change is to add a `non_native` key to the `{backend}_native_functions.yaml` file that contains schema definitions similar to what is found in `native_functions.yaml`. e.g. ``` non_native: ... - func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor ... ``` these definitions are parsed into a `LazyIrSchema` that can be used for generating IR nodes using `GenLazyIR`. Fixes #74628 CC: @wconstab @desertfire @henrytwo Pull Request resolved: https://github.com/pytorch/pytorch/pull/76535 Approved by: https://github.com/wconstab --- BUILD.bazel | 1 + aten/src/ATen/native/ts_native_functions.yaml | 38 ++++ aten/src/ATen/templates/LazyNonNativeIr.h | 11 ++ build.bzl | 2 + caffe2/CMakeLists.txt | 2 + tools/build_variables.bzl | 22 +-- tools/test/test_gen_backend_stubs.py | 2 +- torch/csrc/lazy/core/ir.cpp | 6 +- torch/csrc/lazy/core/ir.h | 1 + torch/csrc/lazy/core/ops/utils.h | 16 +- torch/csrc/lazy/core/permutation_util.h | 7 +- torch/csrc/lazy/core/shape_inference.cpp | 63 +++++++ torch/csrc/lazy/core/shape_inference.h | 22 +++ torch/csrc/lazy/ts_backend/ir_builder.h | 24 +-- torch/csrc/lazy/ts_backend/ops/cast.cpp | 42 ----- torch/csrc/lazy/ts_backend/ops/cast.h | 38 ---- torch/csrc/lazy/ts_backend/ops/expand.cpp | 29 --- torch/csrc/lazy/ts_backend/ops/expand.h | 37 ---- torch/csrc/lazy/ts_backend/ops/scalar.cpp | 40 ----- torch/csrc/lazy/ts_backend/ops/scalar.h | 35 ---- torch/csrc/lazy/ts_backend/ts_node.cpp | 30 ++-- .../csrc/lazy/ts_backend/ts_node_lowering.cpp | 63 +++---- .../lazy/ts_backend/view_ops/as_strided.cpp | 52 ------ .../lazy/ts_backend/view_ops/as_strided.h | 48 ----- .../view_ops/as_strided_view_update.cpp | 37 ---- .../view_ops/as_strided_view_update.h | 45 ----- .../lazy/ts_backend/view_ops/diagonal.cpp | 35 ---- .../csrc/lazy/ts_backend/view_ops/diagonal.h | 37 ---- .../view_ops/diagonal_view_update.cpp | 32 ---- .../view_ops/diagonal_view_update.h | 43 ----- .../csrc/lazy/ts_backend/view_ops/narrow.cpp | 33 ---- torch/csrc/lazy/ts_backend/view_ops/narrow.h | 35 ---- .../view_ops/narrow_view_update.cpp | 29 --- .../ts_backend/view_ops/narrow_view_update.h | 31 ---- .../csrc/lazy/ts_backend/view_ops/permute.cpp | 28 --- torch/csrc/lazy/ts_backend/view_ops/permute.h | 28 --- .../csrc/lazy/ts_backend/view_ops/resize.cpp | 29 --- torch/csrc/lazy/ts_backend/view_ops/resize.h | 27 --- .../csrc/lazy/ts_backend/view_ops/select.cpp | 44 ----- torch/csrc/lazy/ts_backend/view_ops/select.h | 49 ----- .../view_ops/select_view_update.cpp | 36 ---- .../ts_backend/view_ops/select_view_update.h | 49 ----- .../csrc/lazy/ts_backend/view_ops/squeeze.cpp | 28 --- torch/csrc/lazy/ts_backend/view_ops/squeeze.h | 26 --- .../lazy/ts_backend/view_ops/unsqueeze.cpp | 30 ---- .../csrc/lazy/ts_backend/view_ops/unsqueeze.h | 27 --- torch/csrc/lazy/ts_backend/view_ops/view.cpp | 35 ---- torch/csrc/lazy/ts_backend/view_ops/view.h | 29 --- torchgen/api/lazy.py | 119 ++++++++++-- torchgen/dest/__init__.py | 3 + torchgen/dest/lazy_ir.py | 169 +++++++++++++----- torchgen/dest/lazy_ts_lowering.py | 9 +- torchgen/gen_backend_stubs.py | 4 + torchgen/gen_lazy_tensor.py | 74 +++++--- torchgen/model.py | 14 +- 55 files changed, 497 insertions(+), 1348 deletions(-) create mode 100644 aten/src/ATen/templates/LazyNonNativeIr.h delete mode 100644 torch/csrc/lazy/ts_backend/ops/cast.cpp delete mode 100644 torch/csrc/lazy/ts_backend/ops/cast.h delete mode 100644 torch/csrc/lazy/ts_backend/ops/expand.cpp delete mode 100644 torch/csrc/lazy/ts_backend/ops/expand.h delete mode 100644 torch/csrc/lazy/ts_backend/ops/scalar.cpp delete mode 100644 torch/csrc/lazy/ts_backend/ops/scalar.h delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/as_strided.cpp delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/as_strided.h delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.cpp delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.h delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/diagonal.cpp delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/diagonal.h delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/diagonal_view_update.cpp delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/diagonal_view_update.h delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/narrow.cpp delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/narrow.h delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/narrow_view_update.cpp delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/narrow_view_update.h delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/permute.cpp delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/permute.h delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/resize.cpp delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/resize.h delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/select.cpp delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/select.h delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/select_view_update.cpp delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/select_view_update.h delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/squeeze.cpp delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/squeeze.h delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/unsqueeze.cpp delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/unsqueeze.h delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/view.cpp delete mode 100644 torch/csrc/lazy/ts_backend/view_ops/view.h diff --git a/BUILD.bazel b/BUILD.bazel index d373a84f64d98..84bd6c53a5687 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1865,6 +1865,7 @@ test_suite( "aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp", "aten/src/ATen/templates/DispatchKeyNativeFunctions.h", "aten/src/ATen/templates/LazyIr.h", + "aten/src/ATen/templates/LazyNonNativeIr.h", "aten/src/ATen/templates/RegisterDispatchKey.cpp", "aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml", diff --git a/aten/src/ATen/native/ts_native_functions.yaml b/aten/src/ATen/native/ts_native_functions.yaml index ba05aca4248e7..cba705bd0edd1 100644 --- a/aten/src/ATen/native/ts_native_functions.yaml +++ b/aten/src/ATen/native/ts_native_functions.yaml @@ -178,3 +178,41 @@ supported: - _unsafe_view autograd: - max_pool3d + +# Ops that don't have a native schema definitions and are dispatched within Lazy Tensor Core +non_native: + - func: scalar(Scalar value, ScalarType type) -> Tensor + opkind: at::prim::Constant + properties: + - ShapeCompute + - TreatScalarsAsConstants + - func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor + - func: view(Tensor input, int[] output_size) -> Tensor + properties: + - ShapeCompute + - func: cast(Tensor input, ScalarType dtype, ScalarType? stype) -> Tensor + opkind: ltc_cast + properties: + - ShapeCompute + + # View ops only required until proper functionalization pass is introduced into LTC + - func: as_strided_view_update(Tensor target, Tensor input, int[] size, int[] stride, int storage_offset) -> Tensor + opkind: ltc_as_strided_view_update + - func: as_strided(Tensor input, int[] size, int[] stride, int storage_offset) -> Tensor + - func: diagonal_view_update(Tensor target, Tensor input, int offset, int dim1, int dim2) -> Tensor + opkind: ltc_diagonal_view_update + properties: + - ShapeCompute + - func: diagonal(Tensor input, int offset, int dim1, int dim2) -> Tensor + - func: narrow_view_update(Tensor input, Tensor source, int[] base_indices) -> Tensor + opkind: ltc_narrow_view_update + - func: narrow(Tensor input, int[] base_indices, int[] sizes) -> Tensor + - func: permute(Tensor input, int[] dims) -> Tensor + - func: resize(Tensor input, int[] size) -> Tensor + - func: select_view_update(Tensor target, Tensor source, int dim, int start, int end, int stride) -> Tensor + opkind: ltc_select_view_update + properties: + - ShapeCompute + - func: select(Tensor input, int dim, int start, int end, int stride) -> Tensor + - func: squeeze(Tensor input, int dim) -> Tensor + - func: unsqueeze(Tensor input, int dim) -> Tensor diff --git a/aten/src/ATen/templates/LazyNonNativeIr.h b/aten/src/ATen/templates/LazyNonNativeIr.h new file mode 100644 index 0000000000000..18eaf6da52e4b --- /dev/null +++ b/aten/src/ATen/templates/LazyNonNativeIr.h @@ -0,0 +1,11 @@ +#pragma once + +${lazy_non_native_ir_inc} + +// This file contains autogenerated LazyTensor Non Native IR nodes + +${namespace_prologue} + +${non_native_ir_nodes} + +${namespace_epilogue} diff --git a/build.bzl b/build.bzl index 191d3cc393b94..1678fd646b09f 100644 --- a/build.bzl +++ b/build.bzl @@ -28,6 +28,7 @@ def define_targets(rules): ":DispatchKeyNativeFunctions.cpp", ":DispatchKeyNativeFunctions.h", ":LazyIr.h", + ":LazyNonNativeIr.h", ":RegisterDispatchKey.cpp", ":native_functions.yaml", ":shape_inference.h", @@ -88,6 +89,7 @@ GENERATED_TESTING_PY = [ GENERATED_LAZY_H = [ "torch/csrc/lazy/generated/LazyIr.h", + "torch/csrc/lazy/generated/LazyNonNativeIr.h", "torch/csrc/lazy/generated/LazyNativeFunctions.h", ] diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 4e9a90ef944d3..e21f0f34640c7 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -380,6 +380,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) list(APPEND GENERATED_H_TORCH "${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType.h" "${TORCH_SRC_DIR}/csrc/lazy/generated/LazyIr.h" + "${TORCH_SRC_DIR}/csrc/lazy/generated/LazyNonNativeIr.h" "${TORCH_SRC_DIR}/csrc/lazy/generated/LazyNativeFunctions.h" ) endif() @@ -444,6 +445,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) "${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.h" "${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp" "${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h" + "${TORCH_ROOT}/aten/src/ATen/templates/LazyNonNativeIr.h" "${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp" "${TOOLS_PATH}/autograd/templates/VariableType.h" "${TOOLS_PATH}/autograd/templates/VariableType.cpp" diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 605bf82f48f1c..09f13d4128f32 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -417,35 +417,19 @@ lazy_tensor_core_sources = [ # We can't build all of the ts backend under certain build configurations, e.g. mobile, # since it depends on things like autograd, meta functions, which may be disabled lazy_tensor_ts_sources = [ - "torch/csrc/lazy/ts_backend/config.cpp", "torch/csrc/lazy/ts_backend/dynamic_ir.cpp", + "torch/csrc/lazy/ts_backend/config.cpp", "torch/csrc/lazy/ts_backend/ops/batch_norm_ops.cpp", - "torch/csrc/lazy/ts_backend/ops/random_ops.cpp", - "torch/csrc/lazy/ts_backend/ops/cast.cpp", "torch/csrc/lazy/ts_backend/ops/device_data.cpp", - "torch/csrc/lazy/ts_backend/ops/expand.cpp", + "torch/csrc/lazy/ts_backend/ops/random_ops.cpp", "torch/csrc/lazy/ts_backend/ops/generic.cpp", - "torch/csrc/lazy/ts_backend/ops/scalar.cpp", - "torch/csrc/lazy/ts_backend/view_ops/as_strided.cpp", - "torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.cpp", - "torch/csrc/lazy/ts_backend/view_ops/diagonal.cpp", - "torch/csrc/lazy/ts_backend/view_ops/diagonal_view_update.cpp", - "torch/csrc/lazy/ts_backend/view_ops/narrow.cpp", - "torch/csrc/lazy/ts_backend/view_ops/narrow_view_update.cpp", - "torch/csrc/lazy/ts_backend/view_ops/permute.cpp", - "torch/csrc/lazy/ts_backend/view_ops/resize.cpp", - "torch/csrc/lazy/ts_backend/view_ops/select.cpp", - "torch/csrc/lazy/ts_backend/view_ops/squeeze.cpp", - "torch/csrc/lazy/ts_backend/view_ops/unsqueeze.cpp", - "torch/csrc/lazy/ts_backend/view_ops/select_view_update.cpp", - "torch/csrc/lazy/ts_backend/view_ops/view.cpp", - "torch/csrc/lazy/ts_backend/ts_node.cpp", "torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp", "torch/csrc/lazy/ts_backend/ts_autograd_functions.cpp", "torch/csrc/lazy/ts_backend/ts_backend_impl.cpp", "torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp", "torch/csrc/lazy/ts_backend/ts_lowering_context.cpp", "torch/csrc/lazy/ts_backend/ts_native_functions.cpp", + "torch/csrc/lazy/ts_backend/ts_node.cpp", "torch/csrc/lazy/ts_backend/ts_node_lowering.cpp", ] diff --git a/tools/test/test_gen_backend_stubs.py b/tools/test/test_gen_backend_stubs.py index 168ae8b1d7c73..dbe32ecec169b 100644 --- a/tools/test/test_gen_backend_stubs.py +++ b/tools/test/test_gen_backend_stubs.py @@ -237,7 +237,7 @@ def test_unrecognized_key(self) -> None: output_error = self.get_errors_from_gen_backend_stubs(yaml_str) self.assertExpectedInline( output_error, - """ contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen""", # noqa: B950 + """ contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen, non_native""", # noqa: B950 ) # if use_out_as_primary is provided, it must be a bool diff --git a/torch/csrc/lazy/core/ir.cpp b/torch/csrc/lazy/core/ir.cpp index ee11b3b47ff22..2a884b6a7c318 100644 --- a/torch/csrc/lazy/core/ir.cpp +++ b/torch/csrc/lazy/core/ir.cpp @@ -19,6 +19,10 @@ hash_t Output::hash() const { return HashCombine(node->hash(), Hash(index)); } +hash_t Output::shapeHash() const { + return HashCombine(node->shapeHash(), Hash(index)); +} + std::string Output::ToString() const { std::stringstream ss; ss << node->ToString() << ", index=" << index; @@ -144,7 +148,7 @@ std::string Node::ToString() const { void Node::AddOperand(NodePtr node, size_t index) { CHECK_LT(index, node->num_outputs()); - operands_.push_back(std::move(node)); + operands_.push_back(node); operands_as_outputs_.emplace_back(operands_.back().get(), index); } diff --git a/torch/csrc/lazy/core/ir.h b/torch/csrc/lazy/core/ir.h index 5ecbf5e5abe07..e583bfb2fe342 100644 --- a/torch/csrc/lazy/core/ir.h +++ b/torch/csrc/lazy/core/ir.h @@ -214,6 +214,7 @@ struct TORCH_API Output { : node(node), index(index) {} hash_t hash() const; + hash_t shapeHash() const; bool operator==(const Output& rhs) const { return node == rhs.node && index == rhs.index; diff --git a/torch/csrc/lazy/core/ops/utils.h b/torch/csrc/lazy/core/ops/utils.h index 9c902bb7244f9..15ba1d2642874 100644 --- a/torch/csrc/lazy/core/ops/utils.h +++ b/torch/csrc/lazy/core/ops/utils.h @@ -6,33 +6,33 @@ namespace torch { namespace lazy { -bool StrideIsSupported(c10::ArrayRef stride); +TORCH_API bool StrideIsSupported(c10::ArrayRef stride); -std::vector GetArrayStridePermutation(c10::ArrayRef stride); +TORCH_API std::vector GetArrayStridePermutation(c10::ArrayRef stride); -Shape MakeDiagonalShape( +TORCH_API Shape MakeDiagonalShape( const Shape& shape, int64_t offset, int64_t dim1, int64_t dim2); -Shape MakePermuteShape( +TORCH_API Shape MakePermuteShape( const Shape& source_shape, c10::ArrayRef permutation); -Shape MakeSelectShape( +TORCH_API Shape MakeSelectShape( const Shape& shape, int64_t dim, int64_t start, int64_t end, int64_t stride); -int64_t GetStride(int64_t start, int64_t end, int64_t stride); +TORCH_API int64_t GetStride(int64_t start, int64_t end, int64_t stride); -std::vector BuildSqueezedDimensions(c10::ArrayRef dimensions, +TORCH_API std::vector BuildSqueezedDimensions(c10::ArrayRef dimensions, int64_t squeeze_dim); -std::vector BuildUnsqueezedDimensions( +TORCH_API std::vector BuildUnsqueezedDimensions( c10::ArrayRef dimensions, int64_t squeeze_dim); diff --git a/torch/csrc/lazy/core/permutation_util.h b/torch/csrc/lazy/core/permutation_util.h index 06b932007195d..fd5a862d9160a 100644 --- a/torch/csrc/lazy/core/permutation_util.h +++ b/torch/csrc/lazy/core/permutation_util.h @@ -23,8 +23,11 @@ std::vector PermuteDimensions( const Container& dimensions) { using T = typename Container::value_type; TORCH_CHECK( - dimensions.size() == permutation.size() && IsPermutation(permutation), - "Invalid permutation specified"); + dimensions.size() == permutation.size(), + "Invalid permutation specified. dimensions.size() != permutation.size() (", dimensions.size(), " vs. ", permutation.size(), ")"); + TORCH_CHECK( + IsPermutation(permutation), + "Invalid permutation specified. Permutation is not permutation"); std::vector output(dimensions.size()); for (const auto i : c10::irange(permutation.size())) { output[i] = dimensions[permutation[i]]; diff --git a/torch/csrc/lazy/core/shape_inference.cpp b/torch/csrc/lazy/core/shape_inference.cpp index a922afde6c657..d5ee7fefc0bb9 100644 --- a/torch/csrc/lazy/core/shape_inference.cpp +++ b/torch/csrc/lazy/core/shape_inference.cpp @@ -45,10 +45,12 @@ #include +#include #include #include #include #include +#include #include #include #include @@ -629,6 +631,67 @@ std::vector compute_shape_narrow_copy(const at::Tensor & self, int64_t di return {Shape(self.scalar_type(), self.sizes().vec())}; } + +// Non-Native Ops +std::vector compute_shape_scalar(const at::Scalar& value, const at::ScalarType& type) { + return { Shape(type, {}) }; +} +std::vector compute_shape_expand(const Output& input, const std::vector& size, const bool& is_scalar_expand) { + return { Shape(input.shape().scalar_type(), size) }; +} +std::vector compute_shape_view(const Output& input, const std::vector& output_sizes) { + const Shape& input_shape = input.shape(); + const auto complete_output_sizes = + at::infer_size(output_sizes, input_shape.numel()); + return { Shape(input_shape.scalar_type(), complete_output_sizes) }; +} +std::vector compute_shape_cast(const Output& input, const at::ScalarType& dtype, const c10::optional& stype) { + Shape shape = input.shape(); + shape.set_scalar_type(dtype); + return { shape }; +} + + +// View Ops +std::vector compute_shape_as_strided_view_update(const Output& target, const Output& input, const std::vector& size, const std::vector& stride, const int64_t& storage_offset) { + return { Shape(target.shape().scalar_type(), size) }; +} +std::vector compute_shape_as_strided(const Output& input, const std::vector& size, const std::vector& stride, const int64_t& storage_offset) { + return { Shape(input.shape().scalar_type(), size) }; +} +std::vector compute_shape_diagonal_view_update(const Output& target, const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) { + return { target.shape() }; +} +std::vector compute_shape_diagonal(const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) { + return { MakeDiagonalShape(input.shape(), offset, dim1, dim2) }; +} +std::vector compute_shape_narrow_view_update(const Output& input, const Output& source, const std::vector& base_indices) { + return { input.shape() }; +} +std::vector compute_shape_narrow(const Output& input, const std::vector& base_indices, const std::vector& sizes) { + return { Shape(input.shape().scalar_type(), sizes) }; +} +std::vector compute_shape_permute(const Output& input, const std::vector& dims) { + return { MakePermuteShape(input.shape(), dims) }; +} +std::vector compute_shape_resize(const Output& input, const std::vector& size) { + return { Shape(input.shape().scalar_type(), size) }; +} +std::vector compute_shape_select_view_update(const Output& target, const Output& source, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) { + return { target.shape() }; +} +std::vector compute_shape_select(const Output& input, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) { + return { MakeSelectShape(input.shape(), dim, start, end, stride) }; +} +std::vector compute_shape_squeeze(const Output& input, const int& dim) { + const auto& input_shape = input.shape(); + return { torch::lazy::Shape(input_shape.scalar_type(), BuildSqueezedDimensions(input_shape.sizes(), dim)) }; +} +std::vector compute_shape_unsqueeze(const Output& input, const int& dim) { + const auto& input_shape = input.shape(); + return { torch::lazy::Shape(input_shape.scalar_type(), BuildUnsqueezedDimensions(input_shape.sizes(), dim)) }; +} + // Restore unused-parameters warnings #pragma GCC diagnostic pop diff --git a/torch/csrc/lazy/core/shape_inference.h b/torch/csrc/lazy/core/shape_inference.h index a2962c9a3e45f..adbe6193d57b9 100644 --- a/torch/csrc/lazy/core/shape_inference.h +++ b/torch/csrc/lazy/core/shape_inference.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -68,5 +69,26 @@ TORCH_API std::vector compute_shape__to_copy(const at::Tenso TORCH_API std::vector compute_shape_trace(const at::Tensor & self); TORCH_API std::vector compute_shape_zero_functional(const at::Tensor & self); TORCH_API std::vector compute_shape_narrow_copy(const at::Tensor & self, int64_t dim, int64_t start, c10::SymInt length); + +// Non-Native ops +TORCH_API std::vector compute_shape_scalar(const at::Scalar& value, const at::ScalarType& type); +TORCH_API std::vector compute_shape_expand(const Output& input0, const std::vector& size, const bool& is_scalar_expand); +TORCH_API std::vector compute_shape_view(const Output& input0, const std::vector& output_sizes); +TORCH_API std::vector compute_shape_cast(const Output& input0, const at::ScalarType& dtype, const c10::optional& stype); + +// View Ops +TORCH_API std::vector compute_shape_as_strided_view_update(const Output& target, const Output& input, const std::vector& size, const std::vector& stride, const int64_t& storage_offset); +TORCH_API std::vector compute_shape_as_strided(const Output& input, const std::vector& size, const std::vector& stride, const int64_t& storage_offset); +TORCH_API std::vector compute_shape_diagonal_view_update(const Output& target, const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2); +TORCH_API std::vector compute_shape_diagonal(const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2); +TORCH_API std::vector compute_shape_narrow_view_update(const Output& input, const Output& source, const std::vector& base_indices); +TORCH_API std::vector compute_shape_narrow(const Output& input, const std::vector& base_indices, const std::vector& sizes); +TORCH_API std::vector compute_shape_permute(const Output& input, const std::vector& dims); +TORCH_API std::vector compute_shape_resize(const Output& input, const std::vector& size); +TORCH_API std::vector compute_shape_select_view_update(const Output& target, const Output& source, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride); +TORCH_API std::vector compute_shape_select(const Output& input, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride); +TORCH_API std::vector compute_shape_squeeze(const Output& input, const int& dim); +TORCH_API std::vector compute_shape_unsqueeze(const Output& input, const int& dim); + } // namespace lazy } // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ir_builder.h b/torch/csrc/lazy/ts_backend/ir_builder.h index dafd20178bf23..d1ebac0c7822b 100644 --- a/torch/csrc/lazy/ts_backend/ir_builder.h +++ b/torch/csrc/lazy/ts_backend/ir_builder.h @@ -4,31 +4,11 @@ #include #include #include +#include #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include #include -#include -#include -#include -#include -#include - -// This file contains the TorchScript IrBuilder +#include namespace torch { namespace lazy { diff --git a/torch/csrc/lazy/ts_backend/ops/cast.cpp b/torch/csrc/lazy/ts_backend/ops/cast.cpp deleted file mode 100644 index e1adcaa5c5356..0000000000000 --- a/torch/csrc/lazy/ts_backend/ops/cast.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include - -#include - -namespace torch { -namespace lazy { - -namespace { - -Shape NodeOutputShape(const Value& input, c10::ScalarType type) { - Shape shape = input.shape(); - shape.set_scalar_type(type); - return shape; -} - -} // namespace - -Cast::Cast( - const Value& input, - at::ScalarType dtype, - c10::optional stype) - : TsNode( - ClassOpKind(), - {input}, - {NodeOutputShape(input, dtype)}, - /*num_outputs=*/1, - MHash(101, static_cast(dtype), OptionalOr(stype, -1))), - dtype_(dtype), - stype_(stype) {} - -std::string Cast::ToString() const { - std::stringstream ss; - ss << TsNode::ToString(); - ss << ", dtype=" << dtype_; - if (stype_) { - ss << ", stype=" << *stype_; - } - return ss.str(); -} - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ops/cast.h b/torch/csrc/lazy/ts_backend/ops/cast.h deleted file mode 100644 index b94821aa3037f..0000000000000 --- a/torch/csrc/lazy/ts_backend/ops/cast.h +++ /dev/null @@ -1,38 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -namespace torch { -namespace lazy { - -class TORCH_API Cast : public TsNode { - public: - static OpKind ClassOpKind() { - return ltc_cast; - } - - Cast( - const Value& input, - at::ScalarType dtype, - c10::optional stype = c10::nullopt); - - std::string ToString() const override; - - at::ScalarType dtype() const { - return dtype_; - } - - const c10::optional& stype() const { - return stype_; - } - - private: - at::ScalarType dtype_; - c10::optional stype_; -}; - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ops/expand.cpp b/torch/csrc/lazy/ts_backend/ops/expand.cpp deleted file mode 100644 index b7bf90631bfad..0000000000000 --- a/torch/csrc/lazy/ts_backend/ops/expand.cpp +++ /dev/null @@ -1,29 +0,0 @@ -#include - -namespace torch { -namespace lazy { - -Expand::Expand( - const Value& input, - std::vector size, - bool is_scalar_expand) - : TsNode( - ClassOpKind(), - {input}, - /*num_outputs=*/1, - MHash(size, is_scalar_expand)), - size_(std::move(size)), - is_scalar_expand_(is_scalar_expand) { - addComputedShape( - [&]() { return Shape(input.shape().scalar_type(), size_); }); -} - -std::string Expand::ToString() const { - std::stringstream ss; - ss << TsNode::ToString() << ", size=(" << c10::Join(", ", size_) - << "), is_scalar_expand=" << is_scalar_expand_; - return ss.str(); -} - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ops/expand.h b/torch/csrc/lazy/ts_backend/ops/expand.h deleted file mode 100644 index ceea8365b4295..0000000000000 --- a/torch/csrc/lazy/ts_backend/ops/expand.h +++ /dev/null @@ -1,37 +0,0 @@ -#pragma once - -#include - -#include - -namespace torch { -namespace lazy { - -class TORCH_API Expand : public TsNode { - public: - static OpKind ClassOpKind() { - return OpKind(at::aten::expand); - } - - Expand(const Value& input, std::vector size, bool is_scalar_expand); - - std::string ToString() const override; - - const std::vector& size() const { - return size_; - } - - bool is_scalar_expand() const { - return is_scalar_expand_; - } - - private: - std::vector size_; - // True iff the input was a scalar and this was generated internally by a - // lowering and not by user action. For some backends, this difference can be - // material (for example setting strides according to eager semantics). - bool is_scalar_expand_; -}; - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ops/scalar.cpp b/torch/csrc/lazy/ts_backend/ops/scalar.cpp deleted file mode 100644 index 114e8c26926d5..0000000000000 --- a/torch/csrc/lazy/ts_backend/ops/scalar.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include - -#include -#include - -#include - -namespace torch { -namespace lazy { - -using at::operator<<; - -Scalar::Scalar(const at::Scalar& value, Shape shape) - : TsNode( - ClassOpKind(), - std::move(shape), - /*num_outputs=*/1, - ScalarHash(value)), - value_(value) {} - -Scalar::Scalar(const at::Scalar& value, c10::ScalarType type) - : TsNode( - ClassOpKind(), - {Shape(type, {})}, - /*num_outputs=*/1, - ScalarHash(value)), - value_(value) {} - -std::string Scalar::ToString() const { - std::stringstream ss; - ss << TsNode::ToString() << ", value=" << value_; - return ss.str(); -} - -hash_t ScalarHash(const at::Scalar& s) { - return s.isFloatingPoint() ? Hash(s.toDouble()) : Hash(s.toLong()); -} - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ops/scalar.h b/torch/csrc/lazy/ts_backend/ops/scalar.h deleted file mode 100644 index 1092735a3a569..0000000000000 --- a/torch/csrc/lazy/ts_backend/ops/scalar.h +++ /dev/null @@ -1,35 +0,0 @@ -#pragma once - -#include -#include - -namespace torch { -namespace lazy { - -// Differently from Constant, this is a scalar value broadcasted to a shape. -// Even though a Constant could have been used, for simple scalars broadcasted -// to big shapes, the Constant leads to big literals expanded within the -// computation graph. -class TORCH_API Scalar : public TsNode { - public: - static OpKind ClassOpKind() { - return OpKind(at::prim::Constant); - } - - Scalar(const at::Scalar& value, Shape shape); - Scalar(const at::Scalar& value, c10::ScalarType type); - - std::string ToString() const override; - - const at::Scalar& value() const { - return value_; - } - - private: - at::Scalar value_; -}; - -TORCH_API hash_t ScalarHash(const at::Scalar& s); - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ts_node.cpp b/torch/csrc/lazy/ts_backend/ts_node.cpp index 51df344640551..a9ac4f8bfe263 100644 --- a/torch/csrc/lazy/ts_backend/ts_node.cpp +++ b/torch/csrc/lazy/ts_backend/ts_node.cpp @@ -17,29 +17,29 @@ namespace torch { namespace lazy { -hash_t OperandHashes(const OpList& operands, const hash_t& seed, bool bakeInSizes) { +hash_t OperandHashes(const OpList& operands, + const c10::ArrayRef& shapes, + const hash_t& seed, bool bakeInSizes) { hash_t hash = seed; for (auto& operand : operands) { if (!operand) { hash = HashCombine(hash, static_cast(kNullOpt)); continue; } - auto operand_hash = operand.hash(); + auto operand_hash = bakeInSizes ? operand.shapeHash() : operand.hash(); hash = HashCombine(hash, operand_hash); } + for (auto& shape : shapes) { + hash = HashCombine(hash, shape.hash(bakeInSizes)); + } return hash; } -hash_t GetOpHash(OpKind op, const Shape& shape, hash_t hash_seed, bool bakeInSizes) { - hash_t h = HashCombine(op.hash(), shape.hash(bakeInSizes)); - return HashCombine(h, hash_seed); -} - TsNode::TsNode(OpKind op, OpList operands, std::vector&& shapes, size_t num_outputs, hash_t hash_seed) : Node(op, operands, std::move(shapes), num_outputs) { hash_seed = HashCombine(op.hash(), hash_seed); - shape_hash_ = OperandHashes(operands, hash_seed, true); - dag_hash_ = (enableDynamicShape() ? OperandHashes(operands, hash_seed, false) : shape_hash_); + shape_hash_ = OperandHashes(operands, this->shapes(), hash_seed, true); + dag_hash_ = (enableDynamicShape() ? OperandHashes(operands, this->shapes(), hash_seed, false) : shape_hash_); } @@ -53,11 +53,7 @@ TsNode::TsNode(OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed) : TsNode(op, operands, std::vector{}, num_outputs, hash_seed) {} TsNode::TsNode(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed) - : Node(op, num_outputs), - shape_hash_(GetOpHash(op, shape, hash_seed, true)), - dag_hash_(enableDynamicShape() ? GetOpHash(op, shape, hash_seed, false) : shape_hash_) { - shapes_.push_back(std::move(shape)); -} + : TsNode(op, {}, {std::move(shape)}, num_outputs, hash_seed) {} hash_t TsNode::hash() const { return dag_hash_; } @@ -81,8 +77,8 @@ TensorList::TensorList(OpList values) : TsNode(/*op=*/ClassOpKind(), /*operands=*/values, /*shapes=*/std::vector(), - /*num_outputs=*/1, - /*hash_seed=*/OperandHashes(values, /*seed=*/kHashSeed, enableDynamicShape())) {} + /*num_outputs=*/1, + /*hash_seed=*/kHashSeed) {} TSOpVector TensorList::Lower(std::shared_ptr function, TSLoweringContext* loctx) const { @@ -97,5 +93,7 @@ TSOpVector TensorList::Lower(std::shared_ptr function return {listnode->output()}; } + + } // namespace lazy } // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ts_node_lowering.cpp b/torch/csrc/lazy/ts_backend/ts_node_lowering.cpp index a43f78da6492e..57335d1a0a824 100644 --- a/torch/csrc/lazy/ts_backend/ts_node_lowering.cpp +++ b/torch/csrc/lazy/ts_backend/ts_node_lowering.cpp @@ -146,9 +146,9 @@ class TSNodeLowering : public TSNodeLoweringInterface { TSOpVector LowerAsStrided(const torch::lazy::AsStrided* node) { std::vector arguments; arguments.emplace_back(loctx()->GetOutputOp(node->operand(0))); - arguments.emplace_back(node->size()); - arguments.emplace_back(node->stride()); - arguments.emplace_back(node->storage_offset()); + arguments.emplace_back(node->size); + arguments.emplace_back(node->stride); + arguments.emplace_back(node->storage_offset); TSOpVector as_strided_out = LowerBuiltin(node, arguments); CHECK_EQ(as_strided_out.size(), 1); return {GenerateClone(as_strided_out.front())}; @@ -165,8 +165,9 @@ class TSNodeLowering : public TSNodeLoweringInterface { dest_arguments.emplace_back(destination); dest_arguments.emplace_back( std::vector(input_dimensions.begin(), input_dimensions.end())); - dest_arguments.emplace_back(node->stride()); - dest_arguments.emplace_back(node->storage_offset()); + dest_arguments.emplace_back(node->stride); + dest_arguments.emplace_back(node->storage_offset + ); TSOpVector as_strided_out = LowerBuiltin(at::aten::as_strided, dest_arguments); CHECK_EQ(as_strided_out.size(), 1); @@ -209,16 +210,16 @@ class TSNodeLowering : public TSNodeLoweringInterface { TSOpVector LowerCast(const torch::lazy::Cast* node) { std::vector arguments; arguments.emplace_back(loctx()->GetOutputOp(node->operand(0))); - arguments.emplace_back(node->dtype()); + arguments.emplace_back(node->dtype); return LowerBuiltin(at::aten::to, arguments); } TSOpVector LowerExpand(const torch::lazy::Expand* node) { std::vector arguments; arguments.emplace_back(loctx()->GetOutputOp(node->operand(0))); - arguments.emplace_back(node->size()); + arguments.emplace_back(node->size); auto expand_out = LowerBuiltin(node, arguments); - if (node->is_scalar_expand()) { + if (node->is_scalar_expand) { // The aten::expand operations sets all strides to 0 when the original is // of rank 0. This leads to false positives when checking for internal // memory overlap, because at::has_internal_overlap returns @@ -232,8 +233,8 @@ class TSNodeLowering : public TSNodeLoweringInterface { TSOpVector LowerNarrow(const torch::lazy::Narrow* node) { const torch::lazy::Output& input = node->operand(0); torch::jit::Value* base = loctx()->GetOutputOp(input); - const auto& base_indices = node->base_indices(); - const auto& sizes = node->sizes(); + const auto& base_indices = node->base_indices; + const auto& sizes = node->sizes; const torch::lazy::Shape& input_shape = input.shape(); CHECK_EQ(sizes.size(), base_indices.size()); CHECK_EQ(input_shape.dim(), base_indices.size()); @@ -248,12 +249,12 @@ class TSNodeLowering : public TSNodeLoweringInterface { TSOpVector LowerPermute(const torch::lazy::Permute* node) { std::vector arguments; arguments.emplace_back(loctx()->GetOutputOp(node->operand(0))); - arguments.emplace_back(node->dims()); + arguments.emplace_back(node->dims); return LowerBuiltin(node, arguments); } TSOpVector LowerScalar(const torch::lazy::Scalar* node) { - const at::Scalar& value = node->value(); + const at::Scalar& value = node->value; const torch::lazy::Shape& shape = node->shape(); auto options = at::TensorOptions() @@ -264,19 +265,19 @@ class TSNodeLowering : public TSNodeLoweringInterface { } TSOpVector LowerSelect(const torch::lazy::Select* node) { - int64_t step = torch::lazy::Select::GetStride(node->start(), node->end(), - node->stride()); + int64_t step = torch::lazy::GetStride(node->start, node->end, + node->stride); torch::jit::Value* base = loctx()->GetOutputOp(node->operand(0)); - return {GenerateSlice(/*base=*/base, /*dim=*/node->dim(), - /*start=*/node->start(), /*end=*/node->end(), + return {GenerateSlice(/*base=*/base, /*dim=*/node->dim, + /*start=*/node->start, /*end=*/node->end, /*step=*/step)}; } TSOpVector LowerSqueeze(const Squeeze* node) { std::vector arguments; arguments.emplace_back(loctx()->GetOutputOp(node->operand(0))); - if (node->dim() != -1) { - arguments.emplace_back(node->dim()); + if (node->dim != -1) { + arguments.emplace_back(node->dim); } return LowerBuiltin(node, arguments); } @@ -284,11 +285,11 @@ class TSNodeLowering : public TSNodeLoweringInterface { TSOpVector LowerSelectViewUpdate(const torch::lazy::SelectViewUpdate* node) { torch::jit::Value* dest = GenerateClone(loctx()->GetOutputOp(node->operand(0))); - int64_t step = torch::lazy::Select::GetStride(node->start(), node->end(), - node->stride()); + int64_t step = torch::lazy::GetStride(node->start, node->end, + node->stride); torch::jit::Value* selected = GenerateSlice( - /*base=*/dest, /*dim=*/node->dim(), /*start=*/node->start(), - /*end=*/node->end(), /*step=*/step); + /*base=*/dest, /*dim=*/node->dim, /*start=*/node->start, + /*end=*/node->end, /*step=*/step); GenerateCopy(selected, loctx()->GetOutputOp(node->operand(1))); return {dest}; } @@ -296,7 +297,7 @@ class TSNodeLowering : public TSNodeLoweringInterface { TSOpVector LowerNarrowViewUpdate(const torch::lazy::NarrowViewUpdate* node) { torch::jit::Value* dest = GenerateClone(loctx()->GetOutputOp(node->operand(0))); - const auto& base_indices = node->base_indices(); + const auto& base_indices = node->base_indices; const torch::lazy::Output& source_argument = node->operand(1); const torch::lazy::Shape& source_shape = source_argument.shape(); CHECK_EQ(source_shape.dim(), base_indices.size()); @@ -314,23 +315,23 @@ class TSNodeLowering : public TSNodeLoweringInterface { TSOpVector LowerUnsqueeze(const Unsqueeze* node) { std::vector arguments; arguments.emplace_back(loctx()->GetOutputOp(node->operand(0))); - arguments.emplace_back(node->dim()); + arguments.emplace_back(node->dim); return LowerBuiltin(node, arguments); } TSOpVector LowerView(const torch::lazy::View* node) { std::vector arguments; arguments.emplace_back(loctx()->GetOutputOp(node->operand(0))); - arguments.emplace_back(node->output_size()); + arguments.emplace_back(node->output_size); return LowerBuiltin(at::aten::reshape, arguments); } TSOpVector LowerDiagonal(const Diagonal* node) { std::vector arguments; arguments.emplace_back(loctx()->GetOutputOp(node->operand(0))); - arguments.emplace_back(node->offset()); - arguments.emplace_back(node->dim1()); - arguments.emplace_back(node->dim2()); + arguments.emplace_back(node->offset); + arguments.emplace_back(node->dim1); + arguments.emplace_back(node->dim2); return LowerBuiltin(node, arguments); } @@ -346,9 +347,9 @@ class TSNodeLowering : public TSNodeLoweringInterface { // Replay the diagonal. std::vector arguments; arguments.emplace_back(destination); - arguments.emplace_back(node->offset()); - arguments.emplace_back(node->dim1()); - arguments.emplace_back(node->dim2()); + arguments.emplace_back(node->offset); + arguments.emplace_back(node->dim1); + arguments.emplace_back(node->dim2); auto diag = LowerBuiltin(at::aten::diagonal, arguments); // Update the replayed diagonal view with the input. diff --git a/torch/csrc/lazy/ts_backend/view_ops/as_strided.cpp b/torch/csrc/lazy/ts_backend/view_ops/as_strided.cpp deleted file mode 100644 index eaabd23f9a2fe..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/as_strided.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include - -#include - -#include -#include - -namespace torch { -namespace lazy { - -AsStrided::AsStrided( - const Value& input, - std::vector size, - std::vector stride, - int64_t storage_offset) - : TsNode( - ClassOpKind(), - {input}, - [&]() { - return Shape(input.shape().scalar_type(), size); - }, - /*num_outputs=*/1, - MHash(size, stride, storage_offset)), - size_(std::move(size)), - stride_(std::move(stride)), - storage_offset_(storage_offset) {} - -std::string AsStrided::ToString() const { - std::stringstream ss; - ss << TsNode::ToString() << ", size=(" << c10::Join(", ", size_) - << "), stride=(" << c10::Join(", ", stride_) - << "), storage_offset=" << storage_offset_; - return ss.str(); -} - -bool AsStrided::StrideIsSupported(c10::ArrayRef stride) { - std::vector sorted_stride(stride.begin(), stride.end()); - std::sort(sorted_stride.begin(), sorted_stride.end()); - return stride.empty() || sorted_stride.front() == 1; -} - -std::vector AsStrided::GetArrayStridePermutation( - c10::ArrayRef stride) { - std::vector permutation = Iota(stride.size()); - std::sort(permutation.begin(), permutation.end(), [&](int64_t a, int64_t b) { - return stride[a] > stride[b]; - }); - return permutation; -} - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/as_strided.h b/torch/csrc/lazy/ts_backend/view_ops/as_strided.h deleted file mode 100644 index 0c5cd49bc98e8..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/as_strided.h +++ /dev/null @@ -1,48 +0,0 @@ -#pragma once - -#include - -#include - -namespace torch { -namespace lazy { - -class TORCH_API AsStrided : public TsNode { - public: - static OpKind ClassOpKind() { - return OpKind(at::aten::as_strided); - } - - AsStrided( - const Value& input, - std::vector size, - std::vector stride, - int64_t storage_offset); - - std::string ToString() const override; - - const std::vector& size() const { - return size_; - } - - const std::vector& stride() const { - return stride_; - } - - int64_t storage_offset() const { - return storage_offset_; - } - - static bool StrideIsSupported(c10::ArrayRef stride); - - static std::vector GetArrayStridePermutation( - c10::ArrayRef stride); - - private: - std::vector size_; - std::vector stride_; - int64_t storage_offset_; -}; - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.cpp b/torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.cpp deleted file mode 100644 index 49b9e36b944c6..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include - -#include -#include -#include - -namespace torch { -namespace lazy { - -AsStridedViewUpdate::AsStridedViewUpdate( - const Value& target, - const Value& input, - std::vector size, - std::vector stride, - int64_t storage_offset) - : TsNode( - ltc_as_strided_view_update, - {target, input}, - [&]() { - return Shape(target.shape().scalar_type(), size); - }, - /*num_outputs=*/1, - MHash(size, stride, storage_offset)), - size_(std::move(size)), - stride_(std::move(stride)), - storage_offset_(storage_offset) {} - -std::string AsStridedViewUpdate::ToString() const { - std::stringstream ss; - ss << TsNode::ToString() << ", size=(" << c10::Join(", ", size_) - << "), stride=(" << c10::Join(", ", stride_) - << "), storage_offset=" << storage_offset_; - return ss.str(); -} - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.h b/torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.h deleted file mode 100644 index 02ecf31827890..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.h +++ /dev/null @@ -1,45 +0,0 @@ -#pragma once - -#include - -#include -#include - -namespace torch { -namespace lazy { - -class TORCH_API AsStridedViewUpdate : public TsNode { - public: - static OpKind ClassOpKind() { - return ltc_as_strided_view_update; - } - - AsStridedViewUpdate( - const Value& target, - const Value& input, - std::vector size, - std::vector stride, - int64_t storage_offset); - - std::string ToString() const override; - - const std::vector& size() const { - return size_; - } - - const std::vector& stride() const { - return stride_; - } - - int64_t storage_offset() const { - return storage_offset_; - } - - private: - std::vector size_; - std::vector stride_; - int64_t storage_offset_; -}; - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/diagonal.cpp b/torch/csrc/lazy/ts_backend/view_ops/diagonal.cpp deleted file mode 100644 index 565796deafb1a..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/diagonal.cpp +++ /dev/null @@ -1,35 +0,0 @@ -#include -#include -#include - -#include - -namespace torch { -namespace lazy { - -Diagonal::Diagonal( - const Value& input, - int64_t offset, - int64_t dim1, - int64_t dim2) - : TsNode( - ClassOpKind(), - {input}, - [&]() { - return MakeDiagonalShape(input.shape(), offset, dim1, dim2); - }, - /*num_outputs=*/1, - MHash(offset, dim1, dim2)), - offset_(offset), - dim1_(dim1), - dim2_(dim2) {} - -std::string Diagonal::ToString() const { - std::stringstream ss; - ss << TsNode::ToString() << ", offset=" << offset_ << ", dim1=" << dim1_ - << ", dim2=" << dim2_; - return ss.str(); -} - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/diagonal.h b/torch/csrc/lazy/ts_backend/view_ops/diagonal.h deleted file mode 100644 index 4ef669461b056..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/diagonal.h +++ /dev/null @@ -1,37 +0,0 @@ -#pragma once - -#include - -namespace torch { -namespace lazy { - -class TORCH_API Diagonal : public TsNode { - public: - static OpKind ClassOpKind() { - return OpKind(at::aten::diagonal); - } - - Diagonal(const Value& input, int64_t offset, int64_t dim1, int64_t dim2); - - std::string ToString() const override; - - int64_t offset() const { - return offset_; - } - - int64_t dim1() const { - return dim1_; - } - - int64_t dim2() const { - return dim2_; - } - - private: - int64_t offset_; - int64_t dim1_; - int64_t dim2_; -}; - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/diagonal_view_update.cpp b/torch/csrc/lazy/ts_backend/view_ops/diagonal_view_update.cpp deleted file mode 100644 index 2d87832afa597..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/diagonal_view_update.cpp +++ /dev/null @@ -1,32 +0,0 @@ -#include - -#include - -namespace torch { -namespace lazy { - -DiagonalViewUpdate::DiagonalViewUpdate( - const Value& target, - const Value& input, - int64_t offset, - int64_t dim1, - int64_t dim2) - : TsNode( - ltc_diagonal_view_update, - {target, input}, - {target.shape()}, - /*num_outputs=*/1, - MHash(offset, dim1, dim2)), - offset_(offset), - dim1_(dim1), - dim2_(dim2) {} - -std::string DiagonalViewUpdate::ToString() const { - std::stringstream ss; - ss << TsNode::ToString() << ", offset=" << offset_ << ", dim1=" << dim1_ - << ", dim2=" << dim2_; - return ss.str(); -} - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/diagonal_view_update.h b/torch/csrc/lazy/ts_backend/view_ops/diagonal_view_update.h deleted file mode 100644 index 853abc8e6b2da..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/diagonal_view_update.h +++ /dev/null @@ -1,43 +0,0 @@ -#pragma once - -#include -#include - -namespace torch { -namespace lazy { - -class TORCH_API DiagonalViewUpdate : public TsNode { - public: - static OpKind ClassOpKind() { - return ltc_diagonal_view_update; - } - - DiagonalViewUpdate( - const Value& target, - const Value& input, - int64_t offset, - int64_t dim1, - int64_t dim2); - - std::string ToString() const override; - - int64_t offset() const { - return offset_; - } - - int64_t dim1() const { - return dim1_; - } - - int64_t dim2() const { - return dim2_; - } - - private: - int64_t offset_; - int64_t dim1_; - int64_t dim2_; -}; - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/narrow.cpp b/torch/csrc/lazy/ts_backend/view_ops/narrow.cpp deleted file mode 100644 index c5d26a4261ea1..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/narrow.cpp +++ /dev/null @@ -1,33 +0,0 @@ -#include - -#include - -namespace torch { -namespace lazy { - -Narrow::Narrow( - const Value& input, - c10::ArrayRef base_indices, - c10::ArrayRef sizes) - : TsNode( - ClassOpKind(), - {input}, - /*num_outputs=*/1, - MHash(base_indices, sizes)), - base_indices_(base_indices.begin(), base_indices.end()), - sizes_(sizes.begin(), sizes.end()) { - addComputedShape([&]() { - return Shape(operand(0).shape().scalar_type(), sizes); - }); -} - -std::string Narrow::ToString() const { - std::stringstream ss; - ss << TsNode::ToString() << ", base_indices=(" - << c10::Join(", ", base_indices_) << "), sizes=(" - << c10::Join(", ", sizes_) << ")"; - return ss.str(); -} - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/narrow.h b/torch/csrc/lazy/ts_backend/view_ops/narrow.h deleted file mode 100644 index b82557558a18b..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/narrow.h +++ /dev/null @@ -1,35 +0,0 @@ -#pragma once - -#include - -namespace torch { -namespace lazy { - -class TORCH_API Narrow : public TsNode { - public: - static OpKind ClassOpKind() { - return OpKind(at::aten::narrow); - } - - Narrow( - const Value& input, - c10::ArrayRef base_indices, - c10::ArrayRef sizes); - - std::string ToString() const override; - - const std::vector& base_indices() const { - return base_indices_; - } - - const std::vector& sizes() const { - return sizes_; - } - - private: - std::vector base_indices_; - std::vector sizes_; -}; - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/narrow_view_update.cpp b/torch/csrc/lazy/ts_backend/view_ops/narrow_view_update.cpp deleted file mode 100644 index db8c6340fa82a..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/narrow_view_update.cpp +++ /dev/null @@ -1,29 +0,0 @@ -#include - -#include - -namespace torch { -namespace lazy { - -NarrowViewUpdate::NarrowViewUpdate( - const Value& input, - const Value& source, - c10::ArrayRef base_indices) - : TsNode( - ltc_narrow_view_update, - {input, source}, - /*num_outputs=*/1, - MHash(base_indices)), - base_indices_(base_indices.begin(), base_indices.end()) { - addComputedShape([&]() { return operand(0).shape(); }); -} - -std::string NarrowViewUpdate::ToString() const { - std::stringstream ss; - ss << TsNode::ToString() << ", base_indices=(" - << c10::Join(", ", base_indices_) << ")"; - return ss.str(); -} - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/narrow_view_update.h b/torch/csrc/lazy/ts_backend/view_ops/narrow_view_update.h deleted file mode 100644 index c1c0189409b8b..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/narrow_view_update.h +++ /dev/null @@ -1,31 +0,0 @@ -#pragma once - -#include -#include - -namespace torch { -namespace lazy { - -class TORCH_API NarrowViewUpdate : public TsNode { - public: - static OpKind ClassOpKind() { - return ltc_narrow_view_update; - } - - NarrowViewUpdate( - const Value& input, - const Value& source, - c10::ArrayRef base_indices); - - std::string ToString() const override; - - const std::vector& base_indices() const { - return base_indices_; - } - - private: - std::vector base_indices_; -}; - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/permute.cpp b/torch/csrc/lazy/ts_backend/view_ops/permute.cpp deleted file mode 100644 index ee9d932bc9c21..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/permute.cpp +++ /dev/null @@ -1,28 +0,0 @@ -#include -#include - -#include - -namespace torch { -namespace lazy { - -Permute::Permute(const Value& input, std::vector dims) - : TsNode( - ClassOpKind(), - {input}, - /*num_outputs=*/1, - MHash(dims)), - dims_(std::move(dims)) { - addComputedShape([&]() { - return MakePermuteShape(operand(0).shape(), dims_); - }); -} - -std::string Permute::ToString() const { - std::stringstream ss; - ss << TsNode::ToString() << ", dims=(" << c10::Join(", ", dims_) << ")"; - return ss.str(); -} - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/permute.h b/torch/csrc/lazy/ts_backend/view_ops/permute.h deleted file mode 100644 index 23a07cd5d647f..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/permute.h +++ /dev/null @@ -1,28 +0,0 @@ -#pragma once - -#include - -namespace torch { -namespace lazy { - -class TORCH_API Permute : public TsNode { - public: - static OpKind ClassOpKind() { - return OpKind(at::aten::permute); - } - - Permute(const Value& input, std::vector dims); - - std::string ToString() const override; - - const std::vector& dims() const { - return dims_; - } - - private: - // The permutation of dimensions. - std::vector dims_; -}; - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/resize.cpp b/torch/csrc/lazy/ts_backend/view_ops/resize.cpp deleted file mode 100644 index 1a4e945c74f07..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/resize.cpp +++ /dev/null @@ -1,29 +0,0 @@ -#include - -namespace torch { -namespace lazy { - -namespace { -Shape NodeOutputShape(const Value& input, c10::ArrayRef size) { - return Shape(input.shape().scalar_type(), size); -} - -} // namespace - -Resize::Resize(const Value& input, std::vector size) - : TsNode( - ClassOpKind(), - {input}, - [&]() { return NodeOutputShape(input, size); }, - /*num_outputs=*/1, - MHash(size)), - size_(std::move(size)) {} - -std::string Resize::ToString() const { - std::stringstream ss; - ss << TsNode::ToString() << ", size=(" << c10::Join(", ", size_) << ")"; - return ss.str(); -} - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/resize.h b/torch/csrc/lazy/ts_backend/view_ops/resize.h deleted file mode 100644 index e906c60f5e8e0..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/resize.h +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#include - -namespace torch { -namespace lazy { - -class TORCH_API Resize : public TsNode { - public: - static OpKind ClassOpKind() { - return OpKind(at::aten::resize); - } - - Resize(const Value& input, std::vector size); - - std::string ToString() const override; - - const std::vector& size() const { - return size_; - } - - private: - std::vector size_; -}; - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/select.cpp b/torch/csrc/lazy/ts_backend/view_ops/select.cpp deleted file mode 100644 index f27e1df652371..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/select.cpp +++ /dev/null @@ -1,44 +0,0 @@ -#include -#include - -#include - -namespace torch { -namespace lazy { - -Select::Select( - const Value& input, - int64_t dim, - int64_t start, - int64_t end, - int64_t stride) - : TsNode( - ClassOpKind(), - {input}, - [&]() { - return MakeSelectShape(input.shape(), dim, start, end, stride); - }, - /*num_outputs=*/1, - MHash(dim, start, end, stride)), - dim_(dim), - start_(start), - end_(end), - stride_(stride) {} - -std::string Select::ToString() const { - std::stringstream ss; - ss << TsNode::ToString() << ", dim=" << dim_ << ", start=" << start_ - << ", end=" << end_ << ", stride=" << stride_; - return ss.str(); -} - -int64_t Select::GetStride(int64_t start, int64_t end, int64_t stride) { - if (stride == 0) { - CHECK_EQ(start, end); - stride = 1; - } - return stride; -} - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/select.h b/torch/csrc/lazy/ts_backend/view_ops/select.h deleted file mode 100644 index f6bd0c81777f4..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/select.h +++ /dev/null @@ -1,49 +0,0 @@ -#pragma once - -#include - -namespace torch { -namespace lazy { - -class TORCH_API Select : public TsNode { - public: - static OpKind ClassOpKind() { - return OpKind(at::aten::select); - } - - Select( - const Value& input, - int64_t dim, - int64_t start, - int64_t end, - int64_t stride); - - std::string ToString() const override; - - int64_t dim() const { - return dim_; - } - - int64_t start() const { - return start_; - } - - int64_t end() const { - return end_; - } - - int64_t stride() const { - return stride_; - } - - static int64_t GetStride(int64_t start, int64_t end, int64_t stride); - - private: - int64_t dim_; - int64_t start_; - int64_t end_; - int64_t stride_; -}; - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/select_view_update.cpp b/torch/csrc/lazy/ts_backend/view_ops/select_view_update.cpp deleted file mode 100644 index 20291f948ea46..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/select_view_update.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include - -#include -#include -#include - -namespace torch { -namespace lazy { - -SelectViewUpdate::SelectViewUpdate( - const Value& target, - const Value& source, - int64_t dim, - int64_t start, - int64_t end, - int64_t stride) - : TsNode( - ltc_select_view_update, - {target, source}, - {target.shape()}, - /*num_outputs=*/1, - MHash(dim, start, end, stride)), - dim_(dim), - start_(start), - end_(end), - stride_(stride) {} - -std::string SelectViewUpdate::ToString() const { - std::stringstream ss; - ss << TsNode::ToString() << ", dim=" << dim_ << ", start=" << start_ - << ", end=" << end_ << ", stride=" << stride_; - return ss.str(); -} - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/select_view_update.h b/torch/csrc/lazy/ts_backend/view_ops/select_view_update.h deleted file mode 100644 index 9b62fbd9c3d43..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/select_view_update.h +++ /dev/null @@ -1,49 +0,0 @@ -#pragma once - -#include -#include - -namespace torch { -namespace lazy { - -class TORCH_API SelectViewUpdate : public TsNode { - public: - static OpKind ClassOpKind() { - return ltc_select_view_update; - } - - SelectViewUpdate( - const Value& target, - const Value& source, - int64_t dim, - int64_t start, - int64_t end, - int64_t stride); - - std::string ToString() const override; - - int64_t dim() const { - return dim_; - } - - int64_t start() const { - return start_; - } - - int64_t end() const { - return end_; - } - - int64_t stride() const { - return stride_; - } - - private: - int64_t dim_; - int64_t start_; - int64_t end_; - int64_t stride_; -}; - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/squeeze.cpp b/torch/csrc/lazy/ts_backend/view_ops/squeeze.cpp deleted file mode 100644 index 6a40211732395..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/squeeze.cpp +++ /dev/null @@ -1,28 +0,0 @@ -#include -#include -#include -#include - -namespace torch { -namespace lazy { - -Squeeze::Squeeze(const torch::lazy::Value& input, int dim) - : torch::lazy::TsNode(ClassOpKind(), {input}, - /*num_outputs=*/1, torch::lazy::MHash(dim)), - dim_(dim) { - addComputedShape( - [&]() { - const auto& input_shape = input.shape(); - return torch::lazy::Shape(input_shape.scalar_type(), - BuildSqueezedDimensions(input_shape.sizes(), dim)); - }); -} - -std::string Squeeze::ToString() const { - std::stringstream ss; - ss << torch::lazy::TsNode::ToString() << ", dim=" << dim_; - return ss.str(); -} - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/squeeze.h b/torch/csrc/lazy/ts_backend/view_ops/squeeze.h deleted file mode 100644 index 2b0c7fb2211d4..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/squeeze.h +++ /dev/null @@ -1,26 +0,0 @@ -#pragma once - -#include - -namespace torch { -namespace lazy { - -class TORCH_API Squeeze : public TsNode { - public: - static OpKind ClassOpKind() { - return OpKind(at::aten::squeeze); - } - - // Squeeze out the specified dimension index, -1 for all trivial dimensions. - Squeeze(const torch::lazy::Value& input, int dim); - - std::string ToString() const override; - - int dim() const { return dim_; } - - private: - int dim_; -}; - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/unsqueeze.cpp b/torch/csrc/lazy/ts_backend/view_ops/unsqueeze.cpp deleted file mode 100644 index 3b68810cf7f98..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/unsqueeze.cpp +++ /dev/null @@ -1,30 +0,0 @@ -#include -#include -#include - -namespace torch { -namespace lazy { - -Unsqueeze::Unsqueeze(const torch::lazy::Value& input, int dim) - : torch::lazy::TsNode( - ClassOpKind(), - {input}, - /*num_outputs=*/1, - torch::lazy::MHash(dim)), - dim_(dim) { - addComputedShape([&]() { - const auto& input_shape = input.shape(); - return torch::lazy::Shape( - input_shape.scalar_type(), - BuildUnsqueezedDimensions(input_shape.sizes(), dim)); - }); -} - -std::string Unsqueeze::ToString() const { - std::stringstream ss; - ss << torch::lazy::TsNode::ToString() << ", dim=" << dim_; - return ss.str(); -} - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/unsqueeze.h b/torch/csrc/lazy/ts_backend/view_ops/unsqueeze.h deleted file mode 100644 index 9b561fff5e21c..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/unsqueeze.h +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#include - -namespace torch { -namespace lazy { - -class TORCH_API Unsqueeze : public TsNode { - public: - static OpKind ClassOpKind() { - return OpKind(at::aten::unsqueeze); - } - - Unsqueeze(const torch::lazy::Value& input, int dim); - - std::string ToString() const override; - - int dim() const { - return dim_; - } - - private: - int dim_; -}; - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/view.cpp b/torch/csrc/lazy/ts_backend/view_ops/view.cpp deleted file mode 100644 index 2048593ed7372..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/view.cpp +++ /dev/null @@ -1,35 +0,0 @@ -#include - -#include - -namespace torch { -namespace lazy { - -namespace { -Shape NodeOutputShape(const Value& input, c10::ArrayRef output_sizes) { - const Shape& input_shape = input.shape(); - const auto complete_output_sizes = - at::infer_size(output_sizes, input_shape.numel()); - return Shape(input_shape.scalar_type(), complete_output_sizes); -} - -} // namespace - -View::View(const Value& input, std::vector output_size) - : TsNode( - ClassOpKind(), - {input}, - {NodeOutputShape(input, output_size)}, - /*num_outputs=*/1, - MHash(output_size)), - output_size_(std::move(output_size)) {} - -std::string View::ToString() const { - std::stringstream ss; - ss << TsNode::ToString() << ", output_size=(" << c10::Join(", ", output_size_) - << ")"; - return ss.str(); -} - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/view_ops/view.h b/torch/csrc/lazy/ts_backend/view_ops/view.h deleted file mode 100644 index 5da9465c5d2b0..0000000000000 --- a/torch/csrc/lazy/ts_backend/view_ops/view.h +++ /dev/null @@ -1,29 +0,0 @@ -#pragma once - -#include - -#include - -namespace torch { -namespace lazy { - -class TORCH_API View : public TsNode { - public: - static OpKind ClassOpKind() { - return OpKind(at::aten::view); - } - - View(const Value& input, std::vector output_size); - - std::string ToString() const override; - - const std::vector& output_size() const { - return output_size_; - } - - private: - std::vector output_size_; -}; - -} // namespace lazy -} // namespace torch diff --git a/torchgen/api/lazy.py b/torchgen/api/lazy.py index ff74f4ab34bd8..d424ae02ecb4e 100644 --- a/torchgen/api/lazy.py +++ b/torchgen/api/lazy.py @@ -1,4 +1,5 @@ -from typing import List, Union, Tuple, Optional +from typing import Any, Dict, List, Union, Tuple, Optional + from torchgen.model import ( Type, BaseTy, @@ -56,7 +57,7 @@ def setValueT(val: BaseCppType) -> None: def process_ir_type( - typ: Type, + typ: Type, properties: "LazyIrProperties" ) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]: """ This function takes a type from NativeFunctions and converts it for use with @@ -77,6 +78,8 @@ def process_ir_type( if typ.name == BaseTy.Tensor: return BaseCType(getValueT()) elif typ.name == BaseTy.Scalar: + if properties.TreatScalarsAsConstants: + return BaseCType(scalarT) # at::scalar has special handling, # and is wrapped in an lazy::Value just like at::tensor return BaseCType(getValueT()) @@ -101,7 +104,7 @@ def process_ir_type( else: raise AssertionError(f"TODO add support for type {repr(typ)}") elif isinstance(typ, OptionalType): - return OptionalCType(process_ir_type(typ.elem)) + return OptionalCType(process_ir_type(typ.elem, properties)) elif isinstance(typ, ListType): if str(typ.elem) == "Tensor?": # TODO(whc) is this actually correct? or should it use a Vector like above @@ -110,12 +113,12 @@ def process_ir_type( # this is a TensorList which comes in from GetTensorList as a Value return BaseCType(tensorListValueT) else: - return VectorCType(process_ir_type(typ.elem)) + return VectorCType(process_ir_type(typ.elem, properties)) else: raise AssertionError(f"unrecognized type {repr(typ)}") -def isValueType(typ: CType) -> bool: +def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) -> bool: """ Given a type, determine if it is a Value-like type. This is equivalent to being Tensor-like, but assumes the type has already been transformed. @@ -123,9 +126,14 @@ def isValueType(typ: CType) -> bool: if isinstance(typ, BaseCType): # I am regretting my naming conventions, but now we are wrapping at::scalar in # lazy value, while preserving other 'scalar' types as scalars in the IR - return typ.type == getValueT() or typ.type == scalarT or typ.type == SymIntT + treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants + return ( + typ.type == getValueT() + or (typ.type == scalarT and not treat_scalars_as_constants) + or typ.type == SymIntT + ) elif isinstance(typ, (OptionalCType, ListCType, VectorCType)): - return isValueType(typ.elem) + return isValueType(typ.elem, properties) return False @@ -167,7 +175,7 @@ class LazyArgument: # true if this argument is or contains a lazy IR value is_lazy_value: bool - def __init__(self, arg: Argument): + def __init__(self, arg: Argument, properties: "LazyIrProperties"): self.name = arg.name self.orig_type = arg.type self.is_optional = isinstance(arg.type, OptionalType) @@ -181,11 +189,13 @@ def __init__(self, arg: Argument): # its null and safe to exclude from lazy IR self.lazy_type_ = None else: - self.lazy_type_ = process_ir_type(arg.type) + self.lazy_type_ = process_ir_type(arg.type, properties) self.is_wrapped_scalar = isWrappedScalarType(arg.type) self.is_symint_or_list = isSymIntType(arg.type) - self.is_lazy_value = not self.is_generator and isValueType(self.lazy_type) + self.is_lazy_value = not self.is_generator and isValueType( + self.lazy_type, properties + ) @property def lazy_type(self) -> CType: @@ -195,6 +205,64 @@ def lazy_type(self) -> CType: return self.lazy_type_ +class LazyIrProperties: + """Collection of properties for an IR node + + The property groups are listed below. Each group is mutually + exclusive, meaning that only one property from each group can be True + at any one time. The properties can be accessed as if they were normal + attributes. The mutual exclusivity is automatically handled. + """ + + Properties: Tuple[Tuple[str, ...], ...] = ( + ( + "ShapePrecompute", # Assume shape has been precomputed + "ShapeCompute", # Need to compute the shape on construction + "ShapeCache", # Utilize the shape cache to defer computation + ), + ( + "Lower", # Codegen full lower function + "LowerDeclOnly", # Codegen only lower function declaration + ), + ( + "CanBeReused", # Codegen full reuse function + "CanBeReusedDeclOnly", # Codegen only reuse function declaration + ), + ( + "CreateFn", # Codegen full create function + "CreateFnDeclOnly", # Codegen only create function declaration + ), + ( + "TreatScalarsAsConstants", # Treat Scalars as constants instead of handling like values + ), + ) + + def __init__(self, *default_properties: str): + properties: Dict[Tuple[str, ...], Optional[str]] = { + p: None for p in LazyIrProperties.Properties + } + self.__dict__["properties"] = properties + for p in default_properties: + setattr(self, p, True) + + def __getattr__(self, key: str) -> Any: + properties = self.__dict__["properties"] + for values in LazyIrProperties.Properties: + if key in values: + return properties[values] == key + + return self.__getattribute__(key) + + def __setattr__(self, key: str, value: Any) -> Any: + properties = self.__dict__["properties"] + for values in LazyIrProperties.Properties: + if key in values: + properties[values] = key if value else None + return value + + raise KeyError(f"Invalid property: {key}") + + # Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node. # Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML), # but carries type information from a native FunctionSchema modified for use with IR nodes, @@ -213,20 +281,33 @@ class LazyIrSchema: # build a LazyArgument since lazy IR doesn't support it generator_arg: Optional[NamedCType] = None - def __init__(self, func: FunctionSchema): - - positional_args = [] + properties: LazyIrProperties = LazyIrProperties( + # default properties + "ShapePrecompute", + "Lower", + "CanBeReused", + ) + opkind: Optional[str] = None + + def __init__( + self, func: FunctionSchema, properties: Optional[LazyIrProperties] = None + ): + if properties: + self.properties = properties + + positional_args: List[LazyArgument] = [] for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]: if arg_field == "self_arg" and func.arguments.self_arg is not None: arg = getattr(func.arguments, "self_arg").argument - positional_args.append(LazyArgument(arg)) + positional_args.append(LazyArgument(arg, self.properties)) elif getattr(func.arguments, arg_field) is not None: positional_args.extend( - [LazyArgument(arg) for arg in getattr(func.arguments, arg_field)] + LazyArgument(arg, self.properties) + for arg in getattr(func.arguments, arg_field) ) self.positional_args = tuple(positional_args) - keyword_args = [] + keyword_args: List[LazyArgument] = [] for arg_field in [ "pre_tensor_options_kwarg_only", "tensor_options", @@ -243,7 +324,9 @@ def __init__(self, func: FunctionSchema): self.generator_arg is None ), "We expect there is only one generator arg" self.generator_arg = NamedCType(arg.name, arg.type) - keyword_args.extend([LazyArgument(arg) for arg in curr_args]) + keyword_args.extend( + LazyArgument(arg, self.properties) for arg in curr_args + ) self.keyword_args = tuple(keyword_args) self.name = func.name self.returns = func.returns @@ -262,7 +345,7 @@ def node_name(self) -> str: @property def aten_name(self) -> str: - return f"{self.name.name}" + return str(self.name.name) @property def base_name(self) -> str: diff --git a/torchgen/dest/__init__.py b/torchgen/dest/__init__.py index 2ac52939f63ae..498c437a88a34 100644 --- a/torchgen/dest/__init__.py +++ b/torchgen/dest/__init__.py @@ -1,6 +1,9 @@ from .lazy_ir import GenLazyIR as GenLazyIR from .lazy_ir import GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition from .lazy_ir import GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition +from .lazy_ir import ( + generate_non_native_lazy_ir_nodes as generate_non_native_lazy_ir_nodes, +) from .register_dispatch_key import ( RegisterDispatchKey as RegisterDispatchKey, gen_registration_helpers as gen_registration_helpers, diff --git a/torchgen/dest/lazy_ir.py b/torchgen/dest/lazy_ir.py index 66c9e3d749eb5..16f4750632a08 100644 --- a/torchgen/dest/lazy_ir.py +++ b/torchgen/dest/lazy_ir.py @@ -1,8 +1,13 @@ from abc import ABC -from typing import List, Optional, Union from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union from torchgen.context import method_with_native_function -from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup +from torchgen.model import ( + BackendIndex, + NativeFunction, + NativeFunctionsGroup, + FunctionSchema, +) from torchgen.api.types import ( BaseCType, OptionalCType, @@ -12,6 +17,7 @@ ) import torchgen.api.dispatcher as dispatcher from torchgen.api.lazy import ( + LazyIrProperties, LazyIrSchema, LazyArgument, getValueT, @@ -108,36 +114,44 @@ def aten_symbol(schema: LazyIrSchema) -> str: } if schema.aten_name in missing_interned_strings: return f'c10::Symbol::fromQualString("aten::{schema.aten_name}")' - return f"at::aten::{schema.aten_name}" + + if not schema.aten_name.startswith("at::"): + return f"at::aten::{schema.aten_name}" + else: + return schema.aten_name @dataclass(frozen=True) class GenLazyIR(ABC): backend_index: BackendIndex + backend_name: str node_base: str @method_with_native_function def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]: func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func - return self.gen(f) + schema = LazyIrSchema(func) + return self.gen(schema) # there is no lowering functionality generated unless this IR base class is subclassed and # implemented as a backend-specific node - def lowering_function(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> str: + def lowering_function(self, schema: LazyIrSchema) -> str: return "" - def can_be_reused_function( - self, f: Union[NativeFunctionsGroup, NativeFunction], node_ctor_args: str - ) -> str: + def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str: + return "" + + def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str: return f"""bool CanBeReused({node_ctor_args}) const {{ return false; }}""" def node_base_ctor_call(self, schema: LazyIrSchema) -> str: + value_args = schema.filtered_args(values=True, scalars=False) # backends can customize the way the node base class constructor is called, # as long as all of its arguments can be generated from information available from the schema base_ctor_value_args_list = [] - for arg in schema.filtered_args(values=True, scalars=False): + for arg in value_args: if isinstance(arg.lazy_type, BaseCType) or isinstance( arg.lazy_type, VectorCType ): @@ -151,29 +165,51 @@ def node_base_ctor_call(self, schema: LazyIrSchema) -> str: base_ctor_value_args = ", ".join(base_ctor_value_args_list) scalar_args = schema.filtered_args(values=False, scalars=True) - scalar_hashes = ", ".join([f"{a.name}" for a in scalar_args]) - return f"""{self.node_base}(torch::lazy::OpKind({aten_symbol(schema)}), - {{{base_ctor_value_args}}}, std::move(shapes), + # Shape constuction. + # Conditionally build shape depending on specified shape property + if schema.properties.ShapePrecompute: + shape_ctor_arg = "std::move(shapes)," + elif schema.properties.ShapeCompute: + shape_args = [a.name for a in value_args] + shape_args.extend(a.name for a in scalar_args) + shape_ctor_arg = f"compute_shape_{schema.name}({', '.join(shape_args)})," + elif schema.properties.ShapeCache: + shape_args = [f"operand({i})" for i in range(len(value_args))] + shape_args.extend(a.name for a in scalar_args) + shape_ctor_arg = f"[&](){{ return compute_shape_{schema.name}({', '.join(shape_args)})[0]; }}," + else: + shape_ctor_arg = "" + + scalar_hashes = ", ".join(f"{a.name}" for a in scalar_args) + + return f"""{self.node_base}( + {schema.node_name}::ClassOpKind(), + OpList{{{base_ctor_value_args}}}, + {shape_ctor_arg} /* num_outputs */ {len(schema.returns)}, torch::lazy::MHash({scalar_hashes}))""" - def gen(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]: + def gen(self, schema: LazyIrSchema) -> List[str]: + opkind = schema.opkind or aten_symbol(schema) + # for now, we just want one IR class decl and soon after also the method defs # and we use the functional version not out/inplace. - func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func - schema = LazyIrSchema(func) all_args = schema.filtered_args() value_args = schema.filtered_args(values=True, scalars=False) scalar_args = schema.filtered_args(values=False, scalars=True) - node_ctor_args = ", ".join( - [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args] - ) + ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args] + reuse_ctor_args = ", ".join(ctor_args) + if schema.properties.ShapePrecompute: + ctor_args.append("std::vector&& shapes") + node_ctor_args = ", ".join(ctor_args) + scalar_initializers = ",\n ".join( - [f"{a.name}({a.name})" for a in scalar_args] + f"{a.name}({a.name})" for a in scalar_args ) - comma_if_scalar_initializers = ",\n" if len(scalar_initializers) else "" + if len(scalar_initializers): + scalar_initializers = f",\n {scalar_initializers}" scalar_decls = "\n ".join( [ f"std::string {a.name};" @@ -212,14 +248,11 @@ def gen(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]: class {schema.node_name} : public {self.node_base} {{ public: static torch::lazy::OpKind ClassOpKind() {{ - return torch::lazy::OpKind({aten_symbol(schema)}); + return torch::lazy::OpKind({opkind}); }} - {schema.node_name}({node_ctor_args}, std::vector&& shapes) - - : {self.node_base_ctor_call(schema)}{comma_if_scalar_initializers} - {scalar_initializers} - + {schema.node_name}({node_ctor_args}) + : {self.node_base_ctor_call(schema)}{scalar_initializers} {{ {has_optional_defs} }} @@ -231,9 +264,11 @@ class {schema.node_name} : public {self.node_base} {{ return ss.str(); }} - {self.can_be_reused_function(f, node_ctor_args)} + {self.create_function(schema, reuse_ctor_args)} + + {self.can_be_reused_function(schema, reuse_ctor_args)} - {self.lowering_function(f)} + {self.lowering_function(schema)} {scalar_decls} {has_optional_decls} @@ -246,37 +281,57 @@ class {schema.node_name} : public {self.node_base} {{ @dataclass(frozen=True) class GenTSLazyIR(GenLazyIR): - def lowering_function(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> str: - return f"""torch::lazy::TSOpVector Lower(std::shared_ptr function, - torch::lazy::TSLoweringContext* loctx) const override {{ - {ts_lowering_body(f)} + def lowering_function(self, schema: LazyIrSchema) -> str: + signature = """ + torch::lazy::TSOpVector Lower( + std::shared_ptr function, + torch::lazy::TSLoweringContext* loctx) const override""" + + if schema.properties.LowerDeclOnly: + return f"{signature};" + elif schema.properties.Lower: + return f"""{signature} {{ + {ts_lowering_body(schema)} + }} + """ + else: + return "" + + def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str: + signature = f"static NodePtr Create({node_ctor_args})" + if schema.properties.CreateFnDeclOnly: + return f"{signature};" + elif not schema.properties.CreateFn: + return "" + return f"""{signature} {{ + return ReuseOrMakeNode<{schema.node_name}>(data); }}""" - def can_be_reused_function( - self, f: Union[NativeFunctionsGroup, NativeFunction], node_ctor_args: str - ) -> str: - func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func - schema = LazyIrSchema(func) - - value_comparsion = [] + def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str: + signature = f"bool CanBeReused({node_ctor_args}) const" + if schema.properties.CanBeReusedDeclOnly: + return f"{signature};" + elif not schema.properties.CanBeReused: + return "" + value_comparison = [] for arg in schema.positional_values: if isinstance(arg.lazy_type, OptionalCType): - value_comparsion.append( + value_comparison.append( f"operand(i++) == {arg.name}.value_or(kNullValue)" ) else: - value_comparsion.append(f"operand(i++) == {arg.name}") + value_comparison.append(f"operand(i++) == {arg.name}") for arg in schema.positional_scalars: - value_comparsion.append(f"this->{arg.name} == {arg.name}") + value_comparison.append(f"this->{arg.name} == {arg.name}") for arg in schema.keyword_values: - value_comparsion.append(f"operand(i++) == {arg.name}") + value_comparison.append(f"operand(i++) == {arg.name}") for arg in schema.keyword_scalars: - value_comparsion.append(f"this->{arg.name} == {arg.name}") - value_comparsion_str = " &&\n ".join(value_comparsion) + value_comparison.append(f"this->{arg.name} == {arg.name}") + value_comparison_str = " &&\n ".join(value_comparison) - return f"""bool CanBeReused({node_ctor_args}) const {{ + return f"""{signature} {{ size_t i = 0; - return ({value_comparsion_str}); + return ({value_comparison_str}); }}""" @@ -399,7 +454,7 @@ def this_shape(i: int) -> str: shape_str += f""" if(torch::lazy::symbolicShapeEnabled()){{ std::vector inputs = {{ {', '.join(str(a.name) for a in all_args)} }}; - char* schema_str = "{func_schema_str}"; + const char* schema_str = "{func_schema_str}"; applySymbolicShapesOnLT(schema_str, inputs, shapes); }} """ @@ -523,3 +578,21 @@ def __call__(self, f: NativeFunction) -> List[str]: return ["\n".join([f"{shape_sig.shape_decl};"])] else: return [] + + +def generate_non_native_lazy_ir_nodes( + non_native: List[Dict[str, Any]], gen_lazy_ir: GenLazyIR +) -> List[str]: + """Generate the non-native lazy IR node classes""" + nodes = [] + for op in non_native: + # Set default properties for Non-Native IRs + properties = LazyIrProperties("ShapeCache") + for p in op.get("properties", []): + setattr(properties, p, True) + + schema = LazyIrSchema(FunctionSchema.parse(op["func"]), properties) + schema.opkind = op.get("opkind") + nodes.append(gen_lazy_ir.gen(schema)[0]) + + return nodes diff --git a/torchgen/dest/lazy_ts_lowering.py b/torchgen/dest/lazy_ts_lowering.py index 34470d776f66b..b84c625836bf4 100644 --- a/torchgen/dest/lazy_ts_lowering.py +++ b/torchgen/dest/lazy_ts_lowering.py @@ -1,15 +1,10 @@ -from typing import Union -from torchgen.model import NativeFunction, NativeFunctionsGroup from torchgen.api.lazy import LazyIrSchema from torchgen.api.types import OptionalCType -def ts_lowering_body(f: Union[NativeFunctionsGroup, NativeFunction]) -> str: +def ts_lowering_body(schema: LazyIrSchema) -> str: # for now, we just want one IR class decl and soon after also the method defs # and we use the functional version not out/inplace. - func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func - schema = LazyIrSchema(func) - emplace_arguments = [] for arg in schema.positional_args: if arg.is_lazy_value: @@ -47,7 +42,7 @@ def ts_lowering_body(f: Union[NativeFunctionsGroup, NativeFunction]) -> str: {emplace_arguments_str} {emplace_kwarguments} torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments); - CHECK_EQ({schema.aten_name}_out.size(), {len(func.returns)}); + CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)}); return {schema.aten_name}_out; """ diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py index a84d42aec5ac9..5a7e90b203e9b 100644 --- a/torchgen/gen_backend_stubs.py +++ b/torchgen/gen_backend_stubs.py @@ -61,6 +61,7 @@ def parse_backend_yaml( "supported", "autograd", "full_codegen", + "non_native", ] backend = yaml_values.pop("backend", None) @@ -98,6 +99,9 @@ def parse_backend_yaml( full_codegen = yaml_values.pop("full_codegen", []) supported.extend(full_codegen) + # non_native is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py + non_native = yaml_values.pop("non_native", {}) + assert ( len(yaml_values.keys()) == 0 ), f'{backend_yaml_path} contains unexpected keys: {", ".join(yaml_values.keys())}. \ diff --git a/torchgen/gen_lazy_tensor.py b/torchgen/gen_lazy_tensor.py index 6b2f2e5aaceb7..4da8a80f24c05 100644 --- a/torchgen/gen_lazy_tensor.py +++ b/torchgen/gen_lazy_tensor.py @@ -5,8 +5,10 @@ import yaml from collections import namedtuple, Counter from typing import ( + Any, List, Dict, + Tuple, Union, Sequence, Optional, @@ -106,10 +108,10 @@ ) -def parse_full_codegen_ops( +def parse_native_functions_keys( backend_yaml_path: str, grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], -) -> List[OperatorName]: +) -> Tuple[List[OperatorName], List[Any]]: native_functions_map: Dict[OperatorName, NativeFunction] = { f.func.name: f @@ -124,12 +126,10 @@ def parse_full_codegen_ops( assert isinstance(yaml_values, dict) full_codegen = yaml_values.pop("full_codegen", []) - assert isinstance( - full_codegen, list - ), f'expected "full_codegen" to be a list, but got: {full_codegen}' - full_codegen = [OperatorName.parse(name) for name in full_codegen] - - return full_codegen + non_native = yaml_values.pop("non_native", []) + assert isinstance(full_codegen, list) + assert isinstance(non_native, list) + return [OperatorName.parse(name) for name in full_codegen], non_native def validate_shape_inference_header( @@ -150,13 +150,16 @@ def validate_shape_inference_header( ) # TODO(whc) add a check for shape inference functions that have meta kernels implement and should be retired. - for decl in expected_shape_infr_decls: - assert ( - decl in shape_infr_decl_lines - ), f"""Missing shape inference function.\n + missing_decls = [ + decl for decl in expected_shape_infr_decls if decl not in shape_infr_decl_lines + ] + if missing_decls: + raise Exception( + f"""Missing shape inference function.\n Please add declare this function in {shape_inference_hdr}:\n and implement it in the the corresponding shape_inference.cpp file.\n -{decl}""" +{os.linesep.join(missing_decls)}""" + ) class default_args: @@ -324,7 +327,9 @@ def sort_native_function(f: Union[NativeFunctionsGroup, NativeFunction]) -> str: autograd_key = parsed_backend_yaml.autograd_key cpp_namespace = parsed_backend_yaml.cpp_namespace backend_indices = parsed_backend_yaml.backend_indices - full_codegen = parse_full_codegen_ops(source_yaml, grouped_native_functions) + full_codegen, non_native = parse_native_functions_keys( + source_yaml, grouped_native_functions + ) def concat_map_codegen( func: Callable[[NativeFunction], Sequence[str]], @@ -476,6 +481,10 @@ def concat_map_codegen( }, ) # Generate IR node classes + lazy_ir_obj = lazy_ir_generator( + backend_indices[backend_key], backend_name, node_base + ) + fm.write_with_template( "LazyIr.h", "LazyIr.h", @@ -492,16 +501,35 @@ def concat_map_codegen( "vector", ] ], - "lazy_ir_inc": [ - f'#include "{path}"' - for path in [node_base_hdr if node_base_hdr is not None else None] - if path is not None - ], + "lazy_ir_inc": [f'#include "{node_base_hdr}"'] + if node_base_hdr is not None + else [], "ir_declarations": list( - concat_map_codegen( - lazy_ir_generator(backend_indices[backend_key], node_base), - grouped_native_functions, - ) + concat_map_codegen(lazy_ir_obj, grouped_native_functions) + ), + "namespace_prologue": ns_helper.prologue, + "namespace_epilogue": ns_helper.epilogue, + }, + ) + + # Generate Non Native IR Node classes + fm.write_with_template( + "LazyNonNativeIr.h", + "LazyNonNativeIr.h", + lambda: { + "lazy_non_native_ir_inc": [ + f"#include <{path}>" + for path in [ + "torch/csrc/lazy/core/ir.h", + "torch/csrc/lazy/core/ir_builder.h", + "torch/csrc/lazy/core/internal_ops/ltc_ops.h", + "torch/csrc/lazy/core/shape_inference.h", + ] + + ([node_base_hdr] if node_base_hdr else []) + if path + ], + "non_native_ir_nodes": dest.generate_non_native_lazy_ir_nodes( + non_native, lazy_ir_obj ), "namespace_prologue": ns_helper.prologue, "namespace_epilogue": ns_helper.epilogue, diff --git a/torchgen/model.py b/torchgen/model.py index 53448a82a31d4..9efe7425a51ed 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -1072,18 +1072,14 @@ def schema_order_arguments(self) -> Iterator["Argument"]: self.arguments.out, ) + decl_re = re.compile(r"(?P[^\(]+)\((?P.*)\) -> (?P.*)") + @staticmethod def parse(func: str) -> "FunctionSchema": # We should probably get a proper parser here - assert ( - " -> " in func - ), "function schema missing return type (spaces are mandatory)" - last_index = func.rfind(" -> ") - func_decl = func[:last_index] - return_decl = func[last_index + len(" -> ") :] - ops, args = func_decl.split("(", 1) - assert args[-1] == ")", "Expecting closing )" - args = args[:-1] + decls = FunctionSchema.decl_re.findall(func) + assert len(decls) == 1, f"Invalid function schema: {func}" + ops, args, return_decl = decls[0] name = OperatorName.parse(ops) arguments = Arguments.parse(args) returns = parse_returns(return_decl)