From fd8e55cb9e32510b0e2bba8b16c6142a1ae7e609 Mon Sep 17 00:00:00 2001
From: "Jae Hoon (Antonio) Kim" <17433012+antoniojkim@users.noreply.github.com>
Date: Thu, 17 Feb 2022 20:15:37 -0500
Subject: [PATCH] Add initial LTC backend (#610)

* Add initial LTC backend skeleton

* Disable CI build and move TorchMLIRPyTorch.cmake
---
 .github/workflows/buildAndTest.yml            |   2 +
 python/CMakeLists.txt                         |   9 +
 .../cmake/modules/TorchMLIRPyTorch.cmake      |   0
 python/torch_mlir/csrc/CMakeLists.txt         |  43 +++++
 python/torch_mlir/csrc/README.md              |  19 +++
 .../torch_mlir/csrc/backend/backend_impl.cc   | 159 ++++++++++++++++++
 python/torch_mlir/csrc/backend/backend_impl.h | 136 +++++++++++++++
 .../csrc/backend/mlir_lowering_context.cc     |  93 ++++++++++
 .../csrc/backend/mlir_lowering_context.h      |  55 ++++++
 python/torch_mlir/csrc/backend/mlir_node.cc   | 124 ++++++++++++++
 python/torch_mlir/csrc/backend/mlir_node.h    |  71 ++++++++
 python/torch_mlir/csrc/utils/exception.h      |  20 +++
 .../torch/importer/jit_ir/CMakeLists.txt      |   2 +-
 13 files changed, 732 insertions(+), 1 deletion(-)
 rename python/torch_mlir/{dialects/torch/importer/jit_ir => }/cmake/modules/TorchMLIRPyTorch.cmake (100%)
 create mode 100644 python/torch_mlir/csrc/CMakeLists.txt
 create mode 100644 python/torch_mlir/csrc/README.md
 create mode 100644 python/torch_mlir/csrc/backend/backend_impl.cc
 create mode 100644 python/torch_mlir/csrc/backend/backend_impl.h
 create mode 100644 python/torch_mlir/csrc/backend/mlir_lowering_context.cc
 create mode 100644 python/torch_mlir/csrc/backend/mlir_lowering_context.h
 create mode 100644 python/torch_mlir/csrc/backend/mlir_node.cc
 create mode 100644 python/torch_mlir/csrc/backend/mlir_node.h
 create mode 100644 python/torch_mlir/csrc/utils/exception.h

diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml
index 6440d370c3a38..e99c8c2477b18 100644
--- a/.github/workflows/buildAndTest.yml
+++ b/.github/workflows/buildAndTest.yml
@@ -5,6 +5,8 @@ on:
     branches:
       - main
   pull_request:
+    branches:
+      - main
   workflow_dispatch:
 
 jobs:
diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt
index c1bd480a3cc9b..fa617b34cc4ab 100644
--- a/python/CMakeLists.txt
+++ b/python/CMakeLists.txt
@@ -56,6 +56,12 @@ declare_mlir_python_extension(TorchMLIRPythonExtensions.Main
     LLVMSupport
 )
 
+################################################################################
+# Lazy Tensor Core
+################################################################################
+
+add_subdirectory(torch_mlir/csrc)
+
 ################################################################################
 # Optionally handle JIT IR importer.
 ################################################################################
@@ -128,5 +134,8 @@ endif()
 # TODO: Add after macOS builds are fixed
 #add_dependencies(TorchMLIRPythonModules torch_mlir_custom_op_example)
 
+# Add the Torch-MLIR LTC backend as a dependency.
+add_dependencies(TorchMLIRPythonModules torch_mlir_ltc_backend)
+
 add_subdirectory(test)
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/cmake/modules/TorchMLIRPyTorch.cmake b/python/torch_mlir/cmake/modules/TorchMLIRPyTorch.cmake
similarity index 100%
rename from python/torch_mlir/dialects/torch/importer/jit_ir/cmake/modules/TorchMLIRPyTorch.cmake
rename to python/torch_mlir/cmake/modules/TorchMLIRPyTorch.cmake
diff --git a/python/torch_mlir/csrc/CMakeLists.txt b/python/torch_mlir/csrc/CMakeLists.txt
new file mode 100644
index 0000000000000..05f34040cf937
--- /dev/null
+++ b/python/torch_mlir/csrc/CMakeLists.txt
@@ -0,0 +1,43 @@
+#-------------------------------------------------------------------------------
+# Setup PyTorch/LTC
+#-------------------------------------------------------------------------------
+
+
+list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/python/torch_mlir/cmake/modules")
+include(TorchMLIRPyTorch)
+TorchMLIRProbeForPyTorchInstall()
+find_package(Torch 1.11 REQUIRED)
+
+TorchMLIRConfigurePyTorch()
+
+include_directories(BEFORE
+  ${TORCH_INCLUDE_DIRS}
+  ${CMAKE_CURRENT_SOURCE_DIR}
+  ${CMAKE_CURRENT_BINARY_DIR}
+  ${Python3_INCLUDE_DIRS}
+)
+link_directories("${TORCH_INSTALL_PREFIX}/lib")
+
+
+add_library(torch_mlir_ltc_backend SHARED
+  backend/backend_impl.cc
+  backend/mlir_lowering_context.cc
+  backend/mlir_node.cc
+)
+
+target_link_libraries(torch_mlir_ltc_backend
+  TorchMLIRAggregateCAPI
+  ${TORCH_LIBRARIES}
+  ${Python3_LIBRARIES}
+  torch_python
+)
+
+message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wpedantic")
+set_target_properties(torch_mlir_ltc_backend PROPERTIES
+  LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/"
+  OUTPUT_NAME _MLIR_LTC
+  PREFIX "${PYTHON_MODULE_PREFIX}"
+  SUFFIX "${PYTHON_MODULE_EXTENSION}"
+  CXX_VISIBILITY_PRESET "hidden"
+  COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wpedantic"
+)
diff --git a/python/torch_mlir/csrc/README.md b/python/torch_mlir/csrc/README.md
new file mode 100644
index 0000000000000..80937c682b296
--- /dev/null
+++ b/python/torch_mlir/csrc/README.md
@@ -0,0 +1,19 @@
+# Torch-MLIR Lazy Tensor Core Backend
+
+This directory contains the components that implement the Torch-MLIR LTC
+backend.
+
+The components are subclasses of the backend API interface classes found under
+[torch/csrc/lazy/backend](https://github.com/pytorch/pytorch/tree/master/torch/csrc/lazy/backend).
+
+Importantly, the subclasses are still abstract classes. Pure virtual methods
+such as `Compile` were purposefully not overridden, as Torch-MLIR does not know
+how to compile the model for the target hardware.
+
+The intent is that vendor-specific hardware plugins will subclass the Torch-MLIR
+backend classes and override the remaining pure virtual functions to complete
+the backend.
+
+The Torch-MLIR LTC backend's job is to perform the lowering from ATen to
+MLIR. A hardware vendor's backend is responsible for the actual compilation
+and execution of the lowered MLIR.
\ No newline at end of file
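To make the README's contract concrete, here is a minimal sketch of the kind of vendor plugin it describes. Everything vendor-side in it is hypothetical: the `example_vendor` namespace, the class name, and the pass-through behaviour are invented for illustration, and the overridden signatures simply mirror the pure-virtual declarations left commented out in `backend_impl.h` below.

```cpp
// Hypothetical vendor plugin completing the abstract Torch-MLIR backend.
// Namespace, class name, and behaviour are illustrative only; the include
// paths assume compilation from within python/torch_mlir/csrc.
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/lowering_context.h>

#include "backend/backend_impl.h"

namespace example_vendor {

class ExampleVendorBackendImpl : public torch::lazy::MlirBackendImpl {
public:
    // Compile the MLIR produced by the Torch-MLIR lowering for the target
    // hardware. This stub passes the computations through unchanged; a real
    // plugin would invoke its own compiler here.
    std::vector<torch::lazy::ComputationPtr> Compile(
        std::vector<torch::lazy::ComputationPtr> instances) const override {
        return instances;
    }

    // Run a compiled computation on the device. This stub merely echoes its
    // inputs back; a real plugin would execute the compiled artifact and
    // return its outputs as backend data handles.
    std::vector<torch::lazy::BackendDataPtr> ExecuteComputation(
        torch::lazy::Computation& computation,
        c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
        const torch::lazy::BackendDevice& device) const override {
        return {arguments.begin(), arguments.end()};
    }
};

} // namespace example_vendor
```

Even this sketch remains abstract: a complete plugin would also supply the remaining configuration and metrics methods (`SetRngSeed`, `GetDefaultDeviceType`, `GetMetrics`, and so on) and register itself with the lazy tensor machinery before lazy tensors could route through it.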
diff --git a/python/torch_mlir/csrc/backend/backend_impl.cc b/python/torch_mlir/csrc/backend/backend_impl.cc
new file mode 100644
index 0000000000000..96ebbcda9bc0b
--- /dev/null
+++ b/python/torch_mlir/csrc/backend/backend_impl.cc
@@ -0,0 +1,159 @@
+#include <torch/csrc/lazy/backend/backend_data.h>
+#include <torch/csrc/lazy/backend/backend_device.h>
+#include <torch/csrc/lazy/backend/backend_interface.h>
+#include <torch/csrc/lazy/backend/lowering_context.h>
+
+#include "backend_impl.h"
+#include "mlir_lowering_context.h"
+#include "../utils/exception.h"
+
+namespace torch {
+namespace lazy {
+
+struct MlirBackendData::Info : public BackendData::Info {
+    at::Tensor tensor;
+    c10::optional<at::Scalar> scalar;
+
+    Info() {}
+    Info(const Info& other) :
+        tensor{other.tensor}, scalar{other.scalar} {}
+    Info(const at::Tensor& tensor) : tensor{tensor} {}
+    Info(const at::Scalar& scalar) : scalar{scalar} {}
+};
+
+MlirBackendData::MlirBackendData(BackendDevice device, Shape shape) :
+    BackendData(device, shape) {
+    auto info = std::make_shared<MlirBackendData::Info>();
+    SetInfo(info);
+}
+MlirBackendData::MlirBackendData(const at::Scalar& scalar, BackendDevice device) :
+    BackendData(device, torch::lazy::Shape(scalar.type(), {})) {
+    auto info = std::make_shared<MlirBackendData::Info>(scalar);
+    SetInfo(info);
+}
+MlirBackendData::MlirBackendData(const at::Tensor& tensor, BackendDevice device, Shape shape) :
+    BackendData(device, shape) {
+    auto info = std::make_shared<MlirBackendData::Info>(tensor);
+    SetInfo(info);
+}
+
+BackendData::Handle MlirBackendData::GetHandle() { return reinterpret_cast<int64_t>(this); }
+
+void MlirBackendData::Assign(const BackendData& data) {
+    MlirBackendData::Info* info =
+        dynamic_cast<MlirBackendData::Info*>(data.info());
+    TORCH_CHECK(
+        info, "Invalid Backend Data Pointer. Expected MlirBackendData::Info."
+    );
+    auto new_info = std::make_shared<MlirBackendData::Info>(*info);
+    SetInfo(new_info);
+}
+
+bool MlirBackendData::HasValue() const {
+    return bool(info());
+}
+
+/**
+ * Initialization/Teardown
+ * */
+void MlirBackendImpl::PrepareToExit() const {}
+
+/**
+ * Data Transfer
+ * */
+
+BackendDataPtr MlirBackendImpl::MakeComputationDataFromTensor(
+    const at::Tensor& tensor,
+    const Shape& shape,
+    const BackendDevice& device
+) const {
+    return std::make_shared<MlirBackendData>(tensor, device, shape);
+}
+
+BackendDataPtr MlirBackendImpl::MakeComputationDataFromScalar(
+    const at::Scalar& scalar,
+    const torch::lazy::BackendDevice& device
+) const {
+    return std::make_shared<MlirBackendData>(scalar, device);
+}
+
+BackendDataPtr MlirBackendImpl::CreateDataPlaceholder(
+    const BackendDevice& device, const Shape& shape
+) const {
+    return std::make_shared<MlirBackendData>(device, shape);
+}
+
+at::Tensor MlirBackendImpl::MakeTensorFromComputationData(
+    const BackendDataPtr data,
+    c10::optional<at::ScalarType> logical_scalar_type
+) const {
+    MlirBackendData::Info* info =
+        dynamic_cast<MlirBackendData::Info*>(data->info());
+    TORCH_CHECK(
+        info, "Invalid Backend Data Pointer. Expected MlirBackendData::Info."
+    );
+    return info->tensor;
+}
+
+/**
+ * Lowering, Compilation, Execution
+ * */
+
+std::unique_ptr<LoweringContext> MlirBackendImpl::CreateLoweringContext(
+    const std::string& name,
+    BackendDevice device,
+    c10::ArrayRef<torch::lazy::Node*> post_order,
+    Util::EmissionMap emit_status
+) const {
+    return std::make_unique<MlirLoweringContext>(
+        name,
+        std::forward<BackendDevice>(device),
+        std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
+        std::forward<Util::EmissionMap>(emit_status)
+    );
+}
+
+std::unique_ptr<LoweringContext> MlirBackendImpl::CreateLoweringContext(
+    const std::string& name, BackendDevice device
+) const {
+    return std::make_unique<MlirLoweringContext>(
+        name, std::forward<BackendDevice>(device)
+    );
+}
+
+/**
+ * Device Configuration
+ * */
+
+// Set or get the default device type.
+// For backends used with virtual c10:: Devices, this configures what real
+// device type the backend should use, and matters if the backend supports
+// more than one type of real device.
+
+// Specify which aten device should be used for eager fallback;
+// may change depending on current 'Default' DeviceType.
+at::DeviceType MlirBackendImpl::EagerFallbackDeviceType() const {
+    return at::DeviceType::CPU;
+}
+
+
+// Query all available backend devices.
+std::vector<BackendDevice> MlirBackendImpl::GetBackendDevices() const {
+    return {
+        GetBackendDevice(c10::Device(c10::kCPU, 0)),
+        GetBackendDevice(c10::Device(c10::kLazy, 0))
+    };
+}
+
+// Map a particular c10:: device to a concrete backend device.
+// Note: c10:: devices may be virtual or concrete. xla:: and lazy:: are
+// virtual devices, meaning they may map to a gpu, tpu, etc. behind the
+// scenes. In the future, non-virtual c10:: devices may also use lazy tensors
+// through a mode, in which case these APIs should still work, but should be
+// identity mappings.
+BackendDevice MlirBackendImpl::GetBackendDevice(c10::Device device) const {
+    return torch::lazy::BackendDevice(GetDefaultDeviceType(), device.index());
+}
+
+} // lazy
+} // torch
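The data-transfer methods above just wrap and unwrap an `MlirBackendData::Info` that carries the original `at::Tensor`. A small hypothetical sketch of that round trip follows; the helper name `round_trip` is invented, and it presupposes some concrete subclass of `MlirBackendImpl` (the class itself stays abstract), such as the vendor sketch after the README.

```cpp
// Hypothetical round-trip sketch over the data-transfer API in
// backend_impl.cc; not part of the patch.
#include <ATen/ATen.h>
#include <torch/csrc/lazy/core/shape.h>

#include "backend/backend_impl.h"

// Wrap a tensor in backend data and immediately read it back.
at::Tensor round_trip(const torch::lazy::MlirBackendImpl& backend,
                      const at::Tensor& t,
                      const torch::lazy::BackendDevice& device) {
    // Wrap: the Info struct simply stores the tensor as-is.
    torch::lazy::BackendDataPtr data = backend.MakeComputationDataFromTensor(
        t, torch::lazy::Shape(t.scalar_type(), t.sizes().vec()), device);
    // Unwrap: MakeTensorFromComputationData returns the stored tensor.
    return backend.MakeTensorFromComputationData(data, t.scalar_type());
}
```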
diff --git a/python/torch_mlir/csrc/backend/backend_impl.h b/python/torch_mlir/csrc/backend/backend_impl.h
new file mode 100644
index 0000000000000..52055f46081b3
--- /dev/null
+++ b/python/torch_mlir/csrc/backend/backend_impl.h
@@ -0,0 +1,136 @@
+#pragma once
+
+#include <torch/csrc/lazy/backend/backend_data.h>
+#include <torch/csrc/lazy/backend/backend_device.h>
+#include <torch/csrc/lazy/backend/backend_interface.h>
+#include <torch/csrc/lazy/backend/lowering_context.h>
+
+namespace torch {
+namespace lazy {
+
+class MlirBackendData : public torch::lazy::BackendData {
+  public:
+    struct Info;
+
+    MlirBackendData(torch::lazy::BackendDevice device, torch::lazy::Shape shape);
+    MlirBackendData(const at::Scalar& scalar, torch::lazy::BackendDevice device);
+    MlirBackendData(const at::Tensor& tensor, torch::lazy::BackendDevice device, torch::lazy::Shape shape);
+
+    virtual torch::lazy::BackendData::Handle GetHandle() override;
+
+    virtual void Assign(const torch::lazy::BackendData& data) override;
+
+    virtual bool HasValue() const override;
+};
+
+class MlirBackendImpl : public torch::lazy::BackendImplInterface {
+public:
+    /**
+     * Initialization/Teardown
+     * */
+    virtual void PrepareToExit() const override;
+
+    /**
+     * Configuration
+     * */
+    // virtual void SetRngSeed(size_t seed) const = 0;
+
+    /**
+     * Data Transfer
+     * */
+
+    virtual torch::lazy::BackendDataPtr MakeComputationDataFromTensor(
+        const at::Tensor& tensor,
+        const torch::lazy::Shape& shape,
+        const torch::lazy::BackendDevice& device
+    ) const override;
+
+    virtual torch::lazy::BackendDataPtr MakeComputationDataFromScalar(
+        const at::Scalar& scalar,
+        const torch::lazy::BackendDevice& device
+    ) const override;
+
+    virtual torch::lazy::BackendDataPtr CreateDataPlaceholder(
+        const torch::lazy::BackendDevice& device, const torch::lazy::Shape& shape
+    ) const override;
+
+    virtual at::Tensor MakeTensorFromComputationData(
+        const torch::lazy::BackendDataPtr data,
+        c10::optional<at::ScalarType> logical_scalar_type
+    ) const override;
+
+    /**
+     * Lowering, Compilation, Execution
+     * */
+
+    virtual std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
+        const std::string& name,
+        torch::lazy::BackendDevice device,
+        c10::ArrayRef<torch::lazy::Node*> post_order,
+        torch::lazy::Util::EmissionMap emit_status
+    ) const override;
+
+    virtual std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
+        const std::string& name, torch::lazy::BackendDevice device
+    ) const override;
+
+    // TODO(whc) need to keep this?
+    // virtual std::vector<std::string> GetCompilationDevices(
+    //     const std::string& device, c10::ArrayRef<std::string> devices
+    // ) const = 0;
+
+    // virtual std::vector<torch::lazy::ComputationPtr> Compile(
+    //     std::vector<torch::lazy::ComputationPtr> instances
+    // ) const = 0;
+
+    // virtual std::vector<torch::lazy::BackendDataPtr> ExecuteComputation(
+    //     torch::lazy::Computation& computation,
+    //     c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
+    //     const torch::lazy::BackendDevice& device
+    // ) const = 0;
+
+    /**
+     * Device Configuration
+     * */
+
+    // Set or get the default device type.
+    // For backends used with virtual c10:: Devices, this configures what real
+    // device type the backend should use, and matters if the backend supports
+    // more than one type of real device.
+
+    // virtual std::shared_ptr<torch::lazy::BackendDeviceType> GetDefaultDeviceType() const = 0;
+    // virtual void SetDefaultDeviceType(std::string device_type) = 0;
+
+    // Specify which aten device should be used for eager fallback;
+    // may change depending on current 'Default' DeviceType.
+    virtual at::DeviceType EagerFallbackDeviceType() const override;
+
+
+    // Query all available backend devices.
+    virtual std::vector<torch::lazy::BackendDevice> GetBackendDevices() const override;
+
+    // Map a particular c10:: device to a concrete backend device.
+    // Note: c10:: devices may be virtual or concrete. xla:: and lazy:: are
+    // virtual devices, meaning they may map to a gpu, tpu, etc. behind the
+    // scenes. In the future, non-virtual c10:: devices may also use lazy tensors
+    // through a mode, in which case these APIs should still work, but should be
+    // identity mappings.
+    virtual torch::lazy::BackendDevice GetBackendDevice(c10::Device device) const override;
+
+
+    /**
+     * Debug/Metrics
+     * */
+
+    // virtual std::map<std::string, torch::lazy::Metrics> GetMetrics() const = 0;
+
+    // virtual MemoryInfo GetMemoryInfo(const std::string& device) = 0;
+
+    // virtual std::string GetComputationBackendText(
+    //     const torch::lazy::ComputationPtr computation
+    // ) const = 0;
+
+};
+
+} // lazy
+} // torch
diff --git a/python/torch_mlir/csrc/backend/mlir_lowering_context.cc b/python/torch_mlir/csrc/backend/mlir_lowering_context.cc
new file mode 100644
index 0000000000000..e122843dec740
--- /dev/null
+++ b/python/torch_mlir/csrc/backend/mlir_lowering_context.cc
@@ -0,0 +1,93 @@
+#include <torch/csrc/lazy/core/ir_util.h>
+
+#include "mlir_lowering_context.h"
+#include "../utils/exception.h"
+
+
+namespace torch {
+namespace lazy {
+
+MlirLoweringContext::MlirLoweringContext(
+    const std::string& name, BackendDevice device
+) : LoweringContext(name, std::forward<BackendDevice>(device)) {}
+
+MlirLoweringContext::MlirLoweringContext(
+    const std::string& name,
+    BackendDevice device,
+    c10::ArrayRef<torch::lazy::Node*> post_order,
+    Util::EmissionMap emit_status
+) : LoweringContext(
+    name,
+    std::forward<BackendDevice>(device),
+    std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
+    std::forward<Util::EmissionMap>(emit_status)
+) {}
+
+int MlirComputation::parameters_size() const {
+    UNIMPLEMENTED_ERROR("MlirComputation::parameters_size");
+}
+
+const std::vector<torch::lazy::Shape>& MlirComputation::parameter_shapes() const {
+    UNIMPLEMENTED_ERROR("MlirComputation::parameter_shapes");
+}
+
+const std::vector<std::string>& MlirComputation::parameter_names() const {
+    UNIMPLEMENTED_ERROR("MlirComputation::parameter_names");
+}
+
+const torch::lazy::Shape& MlirComputation::result_shape() const {
+    UNIMPLEMENTED_ERROR("MlirComputation::result_shape");
+}
+
+
+// Get the shape of the result tuple component, given by index.
+torch::lazy::Shape MlirLoweringContext::GetResultShape(size_t index) const {
+    UNIMPLEMENTED_ERROR("MlirLoweringContext::GetResultShape( " << index << " )");
+}
+
+// Adds the given output as a component of the result tuple and returns its
+// assigned position within the tuple.
+size_t MlirLoweringContext::AddResult(const torch::lazy::Output& output) {
+    const torch::lazy::Node* node;
+    auto it = emitted_outputs_.find(output);
+    if (it == emitted_outputs_.end()) {
+        node = output.node;
+
+        auto post_order = Util::ComputePostOrder(node, &emit_status_);
+        for (auto po_node : post_order) {
+            // TODO: uncomment after lowering is implemented
+            // bool ok = lowering_->Lower(po_node);
+            // TORCH_CHECK(ok, "Failed to lower: ", po_node->ToString());
+        }
+        emitted_outputs_[output] = node;
+    } else {
+        node = it->second;
+    }
+    result_tuple_.emplace_back(node);
+    return result_tuple_.size() - 1;
+}
+
+// Associates the given output with the input parameter of the given index and
+// shape. Only used for the operator-by-operator execution, mostly for
+// debugging purposes.
+void MlirLoweringContext::AddParameter(
+    const torch::lazy::Output& output,
+    size_t index,
+    const torch::lazy::Shape& shape,
+    const std::string& name
+) {
+    UNIMPLEMENTED_ERROR("MlirLoweringContext::AddParameter");
+}
+
+// Build the computation capturing all the operations created with the
+// embedded builder (returned by the builder() API).
+ComputationPtr MlirLoweringContext::Build() {
+    for (const torch::lazy::Node* output : result_tuple_) {
+        // TODO: emit MLIR for each result once lowering is implemented.
+    }
+    return std::make_shared<MlirComputation>();
+}
+
+
+} // namespace lazy
+} // namespace torch
diff --git a/python/torch_mlir/csrc/backend/mlir_lowering_context.h b/python/torch_mlir/csrc/backend/mlir_lowering_context.h
new file mode 100644
index 0000000000000..0fac168ec9772
--- /dev/null
+++ b/python/torch_mlir/csrc/backend/mlir_lowering_context.h
@@ -0,0 +1,55 @@
+#pragma once
+
+#include <vector>
+
+#include <torch/csrc/lazy/backend/lowering_context.h>
+
+namespace torch {
+namespace lazy {
+
+class MlirComputation : public torch::lazy::Computation {
+  public:
+    int parameters_size() const override;
+
+    virtual const std::vector<torch::lazy::Shape>& parameter_shapes() const override;
+
+    virtual const std::vector<std::string>& parameter_names() const override;
+
+    virtual const torch::lazy::Shape& result_shape() const override;
+};
+
+class MlirLoweringContext : public torch::lazy::LoweringContext {
+  public:
+
+    MlirLoweringContext(const std::string& name, torch::lazy::BackendDevice device);
+    MlirLoweringContext(const std::string& name,
+                        torch::lazy::BackendDevice device,
+                        c10::ArrayRef<torch::lazy::Node*> post_order,
+                        torch::lazy::Util::EmissionMap emit_status);
+
+    // Get the shape of the result tuple component, given by index.
+    virtual torch::lazy::Shape GetResultShape(size_t index) const override;
+
+    // Adds the given output as a component of the result tuple and returns its
+    // assigned position within the tuple.
+    virtual size_t AddResult(const torch::lazy::Output& output) override;
+
+    // Associates the given output with the input parameter of the given index
+    // and shape. Only used for the operator-by-operator execution, mostly for
+    // debugging purposes.
+    virtual void AddParameter(const torch::lazy::Output& output,
+                              size_t index,
+                              const torch::lazy::Shape& shape,
+                              const std::string& name) override;
+
+    // Build the computation capturing all the operations created with the
+    // embedded builder (returned by the builder() API).
+    virtual torch::lazy::ComputationPtr Build() override;
+
+  private:
+    std::vector<const torch::lazy::Node*> result_tuple_;
+    torch::lazy::OutputMap<const torch::lazy::Node*> emitted_outputs_;
+};
+
+} // namespace lazy
+} // namespace torch
diff --git a/python/torch_mlir/csrc/backend/mlir_node.cc b/python/torch_mlir/csrc/backend/mlir_node.cc
new file mode 100644
index 0000000000000..0db1b58981562
--- /dev/null
+++ b/python/torch_mlir/csrc/backend/mlir_node.cc
@@ -0,0 +1,124 @@
+#include <torch/csrc/lazy/core/cache.h>
+
+#include "mlir_node.h"
+#include "../utils/exception.h"
+
+
+namespace torch {
+namespace lazy {
+
+namespace {
+
+hash_t OperandHashes(const OpList& operands, const hash_t& seed, const bool bakeInSizes) {
+    hash_t hash = seed;
+    for (auto& operand : operands) {
+        if (!operand) {
+            hash = HashCombine(hash, static_cast<uint64_t>(kNullOpt));
+            continue;
+        }
+        auto operand_hash = bakeInSizes ? operand.hash_with_sizes() : operand.hash_without_sizes();
+        hash = HashCombine(hash, operand_hash);
+    }
+    return hash;
+}
+
+hash_t GetOpHash(OpKind op, const Shape& shape, hash_t hash_seed, const bool bakeInSizes) {
+    hash_t h = HashCombine(op.hash(), shape.hash(bakeInSizes));
+    return HashCombine(h, hash_seed);
+}
+
+} // namespace
+
+
+MlirNode::MlirNode(
+    OpKind op, OpList operands, std::vector<Shape>&& shapes,
+    size_t num_outputs, hash_t hash_seed
+) : Node(
+        op, num_outputs,
+        /* node_hash */ HashCombine(op.hash(), hash_seed),
+        /* dag_hash */
+        [&](bool bakeInSizes) -> hash_t {
+            return OperandHashes(operands, HashCombine(op.hash(), hash_seed), bakeInSizes);
+        }
+    ),
+    shapes_(std::move(shapes)) {
+
+    for (auto& operand : operands) {
+        // Ideally, optional operands should be filtered by the leaf node classes,
+        // but it's just much easier to do it here.
+        if (!operand) {
+            continue;
+        }
+
+        AddOperand(operand.node, operand.index);
+    }
+}
+
+MlirNode::MlirNode(
+    OpKind op, OpList operands,
+    const std::function<Shape()>& shape_fn,
+    size_t num_outputs, hash_t hash_seed
+) : MlirNode(
+        op, operands, std::vector<Shape>{}, num_outputs, hash_seed
+    ) {
+    shapes_.push_back(GetOpShape(shape_fn));
+}
+
+MlirNode::MlirNode(
+    OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed
+) : MlirNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {}
+
+void MlirNode::SetShapeDeferred(
+    const std::function<Shape()>& shape_fn
+) {
+    shapes_.push_back(GetOpShape(shape_fn));
+}
+
+MlirNode::MlirNode(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed)
+    : Node(
+          op, num_outputs,
+          [&](bool bakeInSizes) -> hash_t {
+              return GetOpHash(op, shape, hash_seed, bakeInSizes);
+          }
+      ) {
+    shapes_.push_back(std::move(shape));
+}
+
+
+using ShapeCache = Cache<hash_t, Shape, HashReducer>;
+
+constexpr const int torch_lazy_shape_cache_size = 4096;
+
+ShapeCache* GetShapeCache() {
+    static ShapeCache* cache = new ShapeCache(torch_lazy_shape_cache_size);
+    return cache;
+}
+
+Shape MlirNode::GetOpShape(const std::function<Shape()>& shape_fn) const {
+    ShapeCache* shape_cache = GetShapeCache();
+    auto shape = shape_cache->Get(hash());
+    if (shape == nullptr) {
+        shape = shape_cache->Add(
+            hash(), std::make_shared<Shape>(shape_fn())
+        );
+    }
+    return *shape;
+}
+
+
+const std::vector<Output>& MlirNode::operands() const {
+    return operands_as_outputs_;
+}
+
+const Output& MlirNode::operand(size_t i) const {
+    return operands_as_outputs_.at(i);
+}
+
+void MlirNode::AddOperand(NodePtr node, size_t index) {
+    CHECK_LT(index, node->num_outputs());
+    operands_.push_back(std::move(node));
+    operands_as_outputs_.emplace_back(operands_.back().get(), index);
+}
+
+} // namespace lazy
+} // namespace torch
diff --git a/python/torch_mlir/csrc/backend/mlir_node.h b/python/torch_mlir/csrc/backend/mlir_node.h
new file mode 100644
index 0000000000000..48b70fe28ba0c
--- /dev/null
+++ b/python/torch_mlir/csrc/backend/mlir_node.h
@@ -0,0 +1,71 @@
+#pragma once
+
+#include <ATen/core/interned_strings.h>
+#include <torch/csrc/lazy/core/hash.h>
+#include <torch/csrc/lazy/core/ir.h>
+#include <torch/csrc/lazy/core/shape.h>
+
+#include "mlir_lowering_context.h"
+#include "../utils/exception.h"
+
+namespace torch {
+namespace lazy {
+
+typedef std::vector<NodePtr> MlirOpVector;
+typedef NodePtr MlirFunction;
+
+
+class MlirNode : public torch::lazy::Node {
+
+  public:
+    MlirNode(
+        OpKind op, OpList operands, std::vector<Shape>&& shapes,
+        size_t num_outputs = 1, hash_t hash_seed = kHashSeed
+    );
+
+    // Same as the constructor above, but the shape is generated by a function,
+    // only if needed (shape cache miss).
+    MlirNode(
+        OpKind op, OpList operands,
+        const std::function<Shape()>& shape_fn,
+        size_t num_outputs = 1, hash_t hash_seed = kHashSeed
+    );
+
+    // The shape is set later.
+    MlirNode(
+        OpKind op, OpList operands, size_t num_outputs = 1,
+        hash_t hash_seed = kHashSeed
+    );
+
+    void SetShapeDeferred(const std::function<Shape()>& shape_fn);
+
+    // Constructor used to create leaf nodes.
+    MlirNode(
+        OpKind op, Shape shape, size_t num_outputs = 1, hash_t hash_seed = kHashSeed
+    );
+
+    Shape GetOpShape(const std::function<Shape()>& shape_fn) const;
+
+    const std::vector<Output>& operands() const override;
+
+    const Output& operand(size_t i) const override;
+
+    virtual MlirOpVector Lower(
+        MlirFunction function,
+        MlirLoweringContext* loctx
+    ) const = 0;
+
+  private:
+    // Adds node's index output number as operand.
+    void AddOperand(NodePtr node, size_t index = 0);
+
+    std::vector<Shape> shapes_;
+    // A node holds a real reference to its operands.
+    std::vector<NodePtr> operands_;
+    // Outputs do not hold references on the nodes, and neither do the uses,
+    // since otherwise we get into circular reference counting.
+    std::vector<Output> operands_as_outputs_;
+};
+
+} // namespace lazy
+} // namespace torch
diff --git a/python/torch_mlir/csrc/utils/exception.h b/python/torch_mlir/csrc/utils/exception.h
new file mode 100644
index 0000000000000..a9dafdcfd209e
--- /dev/null
+++ b/python/torch_mlir/csrc/utils/exception.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include <exception>
+#include <sstream>
+#include <stdexcept>
+
+#define UNIMPLEMENTED_ERROR(msg)                    \
+    {                                               \
+        std::ostringstream err;                     \
+        err << "Unimplemented Error: " << msg;      \
+        throw std::runtime_error(err.str());        \
+    }
+
+
+#define UNSUPPORTED_ERROR(msg)                      \
+    {                                               \
+        std::ostringstream err;                     \
+        err << "Unsupported Error: " << msg;        \
+        throw std::runtime_error(err.str());        \
+    }
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/CMakeLists.txt b/python/torch_mlir/dialects/torch/importer/jit_ir/CMakeLists.txt
index ddf27295f25cc..3ec2766f46b50 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/CMakeLists.txt
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/CMakeLists.txt
@@ -2,7 +2,7 @@
 # Setup PyTorch
 #-------------------------------------------------------------------------------
 
-list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
+list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/python/torch_mlir/cmake/modules")
 include(TorchMLIRPyTorch)
 TorchMLIRProbeForPyTorchInstall()
 find_package(Torch 1.8 REQUIRED)
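Finally, `mlir_node.h` expects op implementations to subclass `MlirNode` and supply `Lower`. A hypothetical node definition might look like the sketch below; the `ExampleAdd` class and its namespace are invented for illustration, and the lowering body is deferred exactly as the skeleton's other stubs are.

```cpp
// Hypothetical op node built on the MlirNode base class from this patch;
// not part of the patch itself.
#include <torch/csrc/lazy/core/ir.h>

#include "backend/mlir_node.h"
#include "utils/exception.h"

namespace example_vendor {

class ExampleAdd : public torch::lazy::MlirNode {
public:
    // Construct the node from its two operand values and a precomputed
    // result shape, using MlirNode's vector-of-shapes constructor.
    ExampleAdd(const torch::lazy::Value& lhs, const torch::lazy::Value& rhs,
               torch::lazy::Shape shape)
        : MlirNode(
              torch::lazy::OpKind(c10::Symbol::fromQualString("aten::add")),
              {lhs, rhs},
              std::vector<torch::lazy::Shape>{std::move(shape)},
              /*num_outputs=*/1) {}

    // Emitting the actual MLIR is the subclass's responsibility; this stub
    // defers it just like the rest of the skeleton.
    torch::lazy::MlirOpVector Lower(
        torch::lazy::MlirFunction function,
        torch::lazy::MlirLoweringContext* loctx) const override {
        UNIMPLEMENTED_ERROR("ExampleAdd::Lower");
    }
};

} // namespace example_vendor
```

Once `MlirLoweringContext::AddResult` starts lowering each node in the post-order walk shown above, `Lower` is the hook through which nodes like this would emit their MLIR.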