Skip to content

Commit

Permalink
Add initial LTC backend (#610)
Browse files Browse the repository at this point in the history
* Add initial LTC backend skeleton

* Disable CI build and move TorchMLIRPyTorch.cmake
  • Loading branch information
antoniojkim committed Jul 5, 2022
1 parent be3d14c commit 49fb61b
Show file tree
Hide file tree
Showing 13 changed files with 732 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .github/workflows/buildAndTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ on:
branches:
- main
pull_request:
branches:
- main
workflow_dispatch:

jobs:
Expand Down
9 changes: 9 additions & 0 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ declare_mlir_python_extension(TorchMLIRPythonExtensions.Main
LLVMSupport
)

################################################################################
# Lazy Tensor Core
################################################################################

add_subdirectory(torch_mlir/csrc)

################################################################################
# Optionally handle JIT IR importer.
################################################################################
Expand Down Expand Up @@ -135,5 +141,8 @@ endif()
# TODO: Add after macOS builds are fixed
#add_dependencies(TorchMLIRPythonModules torch_mlir_custom_op_example)

# Add Torch-MLIR LTC backend as dependency
add_dependencies(TorchMLIRPythonModules torch_mlir_ltc_backend)

add_subdirectory(test)

43 changes: 43 additions & 0 deletions python/torch_mlir/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#-------------------------------------------------------------------------------
# Setup PyTorch/LTC
#-------------------------------------------------------------------------------


# Make the shared TorchMLIRPyTorch.cmake helpers available, then locate the
# PyTorch installation (1.11+ is required for the lazy-tensor-core APIs).
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/python/torch_mlir/cmake/modules")
include(TorchMLIRPyTorch)
TorchMLIRProbeForPyTorchInstall()
find_package(Torch 1.11 REQUIRED)

# Derives TORCH_CXXFLAGS (ABI flags etc.) from the detected PyTorch build.
TorchMLIRConfigurePyTorch()

include_directories(BEFORE
  ${TORCH_INCLUDE_DIRS}
  ${CMAKE_CURRENT_SOURCE_DIR}
  ${CMAKE_CURRENT_BINARY_DIR}
  ${Python3_INCLUDE_DIRS}
)
link_directories("${TORCH_INSTALL_PREFIX}/lib")


# Shared library implementing the Torch-MLIR LTC backend (subclasses of the
# torch/csrc/lazy backend interfaces).
add_library(torch_mlir_ltc_backend SHARED
  backend/backend_impl.cc
  backend/mlir_lowering_context.cc
  backend/mlir_node.cc
)

target_link_libraries(torch_mlir_ltc_backend
  TorchMLIRAggregateCAPI
  ${TORCH_LIBRARIES}
  ${Python3_LIBRARIES}
  torch_python
)

message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wpedantic")
# Install next to the Python package as a Python extension module
# (_MLIR_LTC.<ext>) so it can be imported directly.
set_target_properties(torch_mlir_ltc_backend PROPERTIES
  LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/"
  OUTPUT_NAME _MLIR_LTC
  PREFIX "${PYTHON_MODULE_PREFIX}"
  SUFFIX "${PYTHON_MODULE_EXTENSION}"
  CXX_VISIBILITY_PRESET "hidden"
  COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wpedantic"
)
19 changes: 19 additions & 0 deletions python/torch_mlir/csrc/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Torch-MLIR Lazy Tensor Core Backend

Contained within this directory are the components that implement the
Torch-MLIR LTC backend.

The components are subclasses of the backend API interface classes found under
[torch/csrc/lazy/backend](https://github.com/pytorch/pytorch/tree/master/torch/csrc/lazy/backend).

Importantly, the subclasses are still abstract classes. Pure virtual methods
such as `Compile` were purposefully not overridden, as Torch-MLIR does not know
how to compile the model for the target hardware.

The intent is that vendor hardware specific plugins will subclass the Torch-MLIR
backend classes and override the remaining pure virtual functions to complete
the backend.

The Torch-MLIR LTC backend's job is to perform the lowering from ATen to MLIR.
A hardware vendor's backend is responsible for the actual compilation and
execution of the lowered MLIR.
159 changes: 159 additions & 0 deletions python/torch_mlir/csrc/backend/backend_impl.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/shape.h>

#include "backend_impl.h"
#include "mlir_lowering_context.h"
#include "../utils/exception.h"

namespace torch {
namespace lazy {

// Payload attached to a MlirBackendData. Holds at most one of a tensor or a
// scalar; a default-constructed Info (used for placeholders) holds neither.
struct MlirBackendData::Info : public BackendData::Info {
  // Set when the data was created from an at::Tensor; otherwise undefined.
  at::Tensor tensor;
  // Set only when the data was created from an at::Scalar.
  c10::optional<at::Scalar> scalar;

  Info() {}
  Info(const Info& other) :
    tensor{other.tensor}, scalar{other.scalar} {}
  Info(const at::Tensor& tensor) : tensor{tensor} {}
  Info(const at::Scalar& scalar) : scalar{scalar} {}
};

// Placeholder data: a device/shape with no tensor or scalar payload yet.
MlirBackendData::MlirBackendData(BackendDevice device, Shape shape)
    : BackendData(device, shape) {
  SetInfo(std::make_shared<MlirBackendData::Info>());
}

// Data backed by a scalar; the shape is rank-0 with the scalar's dtype.
MlirBackendData::MlirBackendData(const at::Scalar& scalar, BackendDevice device)
    : BackendData(device, torch::lazy::Shape(scalar.type(), {})) {
  SetInfo(std::make_shared<MlirBackendData::Info>(scalar));
}

// Data backed by a concrete eager tensor.
MlirBackendData::MlirBackendData(
    const at::Tensor& tensor, BackendDevice device, Shape shape)
    : BackendData(device, shape) {
  SetInfo(std::make_shared<MlirBackendData::Info>(tensor));
}

// The handle is simply this object's own address, reinterpreted as an integer.
BackendData::Handle MlirBackendData::GetHandle() {
  return reinterpret_cast<int64_t>(this);
}

// Replace this data's payload with a copy of `data`'s payload. The incoming
// data must carry a MlirBackendData::Info payload or this check-fails.
void MlirBackendData::Assign(const BackendData& data) {
  auto* src_info = dynamic_cast<MlirBackendData::Info*>(data.info());
  TORCH_CHECK(
      src_info, "Invalid Backend Data Pointer. Expected MlirBackendData::Info."
  );
  SetInfo(std::make_shared<MlirBackendData::Info>(*src_info));
}

// Any non-null info payload counts as a value (even a payload holding
// neither tensor nor scalar).
bool MlirBackendData::HasValue() const {
  return info() != nullptr;
}

/**
 * Initialization/Teardown
 * */
// Nothing to tear down for the MLIR backend.
void MlirBackendImpl::PrepareToExit() const {}

/**
 * Data Transfer
 * */

// Wrap an eager tensor as backend data for the given device and shape.
BackendDataPtr MlirBackendImpl::MakeComputationDataFromTensor(
    const at::Tensor& tensor,
    const Shape& shape,
    const BackendDevice& device) const {
  auto data = std::make_shared<MlirBackendData>(tensor, device, shape);
  return data;
}

// Wrap a scalar as backend data on the given device (rank-0 shape is derived
// by the MlirBackendData constructor).
BackendDataPtr MlirBackendImpl::MakeComputationDataFromScalar(
    const at::Scalar& scalar,
    const torch::lazy::BackendDevice& device) const {
  auto data = std::make_shared<MlirBackendData>(scalar, device);
  return data;
}

// Create backend data with no payload; a real value is assigned later.
BackendDataPtr MlirBackendImpl::CreateDataPlaceholder(
    const BackendDevice& device, const Shape& shape) const {
  auto placeholder = std::make_shared<MlirBackendData>(device, shape);
  return placeholder;
}

// Extract the eager tensor stored in the backend data's payload.
// `logical_scalar_type` is currently unused.
at::Tensor MlirBackendImpl::MakeTensorFromComputationData(
    const BackendDataPtr data,
    c10::optional<at::ScalarType> logical_scalar_type) const {
  auto* payload = dynamic_cast<MlirBackendData::Info*>(data->info());
  TORCH_CHECK(
      payload, "Invalid Backend Data Pointer. Expected MlirBackendData::Info."
  );
  return payload->tensor;
}

/**
 * Lowering, Compilation, Execution
 * */

// Build a lowering context that translates the given post-order of lazy IR
// nodes into MLIR for device `device`.
std::unique_ptr<LoweringContext> MlirBackendImpl::CreateLoweringContext(
  const std::string& name,
  BackendDevice device,
  c10::ArrayRef<torch::lazy::Node*> post_order,
  Util::EmissionMap emit_status
) const {
  // These are plain by-value parameters (not forwarding references), so
  // std::move — not std::forward — is the correct way to hand them off.
  return std::make_unique<MlirLoweringContext>(
    name,
    std::move(device),
    std::move(post_order),
    std::move(emit_status)
  );
}

std::unique_ptr<LoweringContext> MlirBackendImpl::CreateLoweringContext(
const std::string& name, BackendDevice device
) const {
return std::make_unique<MlirLoweringContext>(
name, std::forward<BackendDevice>(device)
);
}

/**
 * Device Configuration
 * */

// Set or get the default device type.
// For backends used with virtual c10:: Devices, this configures what real
// device type the backend should use, and matters if the backend supports
// more than one type of real device.

// Specify which aten device should be used for eager fallback
// may change depending on current 'Default' DeviceType
// Always fall back to CPU eager execution for this backend.
at::DeviceType MlirBackendImpl::EagerFallbackDeviceType() const {
  return at::DeviceType::CPU;
}


// Query all available backend devices
std::vector<BackendDevice> MlirBackendImpl::GetBackendDevices() const {
return {
GetBackendDevice(c10::Device(c10::kCPU, 0)),
GetBackendDevice(c10::Device(c10::kLazy, 0))
};
}

// Map a particular c10:: device to a concrete backend device
// Note:: c10:: devices may be virtual or concrete. xla:: and lazy:: are
// virtual devices, meaning they may map to a gpu, tpu, etc. behind the
// scenes. In the future, non-virtual c10:: devices may also use lazy tensors
// through a mode, in which case these APIs should still work, but should be
// identity mappings.
BackendDevice MlirBackendImpl::GetBackendDevice(c10::Device device) const {
  // Only the ordinal carries over; the type is always the backend default.
  auto device_type = GetDefaultDeviceType();
  return torch::lazy::BackendDevice(device_type, device.index());
}

} // lazy
} // torch
136 changes: 136 additions & 0 deletions python/torch_mlir/csrc/backend/backend_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#pragma once

#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/shape.h>

namespace torch {
namespace lazy {

// Backend-data implementation for the Torch-MLIR LTC backend. Wraps an
// optional tensor/scalar payload (see MlirBackendData::Info, defined in
// backend_impl.cc).
class MlirBackendData : public torch::lazy::BackendData {
public:
  // Payload type holding the underlying tensor or scalar.
  struct Info;

  // Placeholder data with no payload yet.
  MlirBackendData(torch::lazy::BackendDevice device, torch::lazy::Shape shape);
  // Data backed by a scalar (rank-0 shape derived from the scalar's dtype).
  MlirBackendData(const at::Scalar& scalar, torch::lazy::BackendDevice device);
  // Data backed by a concrete eager tensor.
  MlirBackendData(const at::Tensor& tensor, torch::lazy::BackendDevice device, torch::lazy::Shape shape);

  // Returns this object's address as an opaque integer handle.
  virtual torch::lazy::BackendData::Handle GetHandle() override;

  // Copies `data`'s payload into this object; check-fails if `data` does not
  // carry a MlirBackendData::Info payload.
  virtual void Assign(const torch::lazy::BackendData& data) override;

  // True when an info payload is attached.
  virtual bool HasValue() const override;
};

// Partial implementation of the LTC backend interface for Torch-MLIR. This
// class remains abstract: hardware-specific methods (e.g. Compile,
// ExecuteComputation — see the commented-out declarations below) must be
// overridden by a vendor backend subclass.
class MlirBackendImpl : public torch::lazy::BackendImplInterface {
public:
  /**
   * Initialization/Teardown
   * */
  virtual void PrepareToExit() const override;

  /**
   * Configuration
   * */
  // virtual void SetRngSeed(size_t seed) const = 0;

  /**
   * Data Transfer
   * */

  // Wrap an eager tensor as backend data for the given device/shape.
  virtual torch::lazy::BackendDataPtr MakeComputationDataFromTensor(
      const at::Tensor& tensor,
      const torch::lazy::Shape& shape,
      const torch::lazy::BackendDevice& device
  ) const override;

  // Wrap a scalar as backend data on the given device.
  virtual torch::lazy::BackendDataPtr MakeComputationDataFromScalar(
      const at::Scalar& scalar,
      const torch::lazy::BackendDevice& device
  ) const override;

  // Create backend data with no payload; a value is assigned later.
  virtual torch::lazy::BackendDataPtr CreateDataPlaceholder(
      const torch::lazy::BackendDevice& device, const torch::lazy::Shape& shape
  ) const override;

  // Extract the eager tensor stored in `data`'s payload.
  virtual at::Tensor MakeTensorFromComputationData(
      const torch::lazy::BackendDataPtr data,
      c10::optional<at::ScalarType> logical_scalar_type
  ) const override;

  /**
   * Lowering, Compilation, Execution
   * */

  // Build a lowering context for translating lazy IR nodes into MLIR.
  virtual std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
      const std::string& name,
      torch::lazy::BackendDevice device,
      c10::ArrayRef<torch::lazy::Node*> post_order,
      torch::lazy::Util::EmissionMap emit_status
  ) const override;

  virtual std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
      const std::string& name, torch::lazy::BackendDevice device
  ) const override;

  // TODO(whc) need to keep this?
  // virtual std::vector<std::string> GetCompilationDevices(
  //     const std::string& device, c10::ArrayRef<std::string> devices
  // ) const = 0;

  // virtual std::vector<torch::lazy::ComputationPtr> Compile(
  //     std::vector<torch::lazy::ComputationPtr> instances
  // ) const = 0;

  // virtual std::vector<torch::lazy::BackendDataPtr> ExecuteComputation(
  //     torch::lazy::Computation& computation,
  //     c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
  //     const torch::lazy::BackendDevice& device
  // ) const = 0;

  /**
   * Device Configuration
   * */

  // Set or get the default device type.
  // For backends used with virtual c10:: Devices, this configures what real
  // device type the backend should use, and matters if the backend supports
  // more than one type of real device.

  // virtual std::shared_ptr<torch::lazy::BackendDeviceType> GetDefaultDeviceType() const = 0;
  // virtual void SetDefaultDeviceType(std::string device_type) = 0;

  // Specify which aten device should be used for eager fallback
  // may change depending on current 'Default' DeviceType
  virtual at::DeviceType EagerFallbackDeviceType() const override;


  // Query all available backend devices
  virtual std::vector<torch::lazy::BackendDevice> GetBackendDevices() const override;

  // Map a particular c10:: device to a concrete backend device
  // Note:: c10:: devices may be virtual or concrete. xla:: and lazy:: are
  // virtual devices, meaning they may map to a gpu, tpu, etc. behind the
  // scenes. In the future, non-virtual c10:: devices may also use lazy tensors
  // through a mode, in which case these APIs should still work, but should be
  // identity mappings.
  virtual torch::lazy::BackendDevice GetBackendDevice(c10::Device device) const override;


  /**
   * Debug/Metrics
   * */

  // virtual std::map<std::string, Metric> GetMetrics() const = 0;

  // virtual MemoryInfo GetMemoryInfo(const std::string& device) = 0;

  // virtual std::string GetComputationBackendText(
  //     const torch::lazy::ComputationPtr computation
  // ) const = 0;

};

} // lazy
} // torch
Loading

0 comments on commit 49fb61b

Please sign in to comment.