Add pytorch interface to ATen Dialect (#30)
This patch adds a pytorch interface to npcomp. This interface is modeled after pytorch_xla and exposes the MLIR-based flow as a virtual device (similar to a gpu device or the xla backend). Usage is intended to be something like:

    dev = torch_mlir.mlir_device()
    t0 = torch.randn((4,4), device=dev)
    t1 = torch.randn((4,4), device=dev)
    t2 = t0 + t1
    t2_mlir = torch_mlir.get_mlir(t2)
    t2_cpu = t2.to('cpu')

In this case t2_cpu contains the result of the computation, and t2_mlir contains the MLIR description of the computation. Note that this also properly returns backward paths synthesized by pytorch.

There are several parts to this:
1) A tensor type (implemented by tensor.* and tensor_impl.*).
2) The device modeling (aten_mlir_bridge.*, aten_mlir_device.*, aten_mlir_type*).
3) A temporary IR (implemented by ir.cpp).

There is also a reference lowering directly from the ATen dialect to C function calls, consisting of two parts:
1) A driver that uses the IR to generate MLIR, run passes, and compile the result using mlir::ExecutionEngine (implemented by jit.cpp and mlir_gen.cpp).
2) A runtime library (implemented by lib/aten_ops.cpp). Most of the operations are implemented by callbacks into the torch C++ libraries.

Some aspects of this are known to be less than optimal, in particular:
1) Some function definitions do not live in the file corresponding to their declaration.
2) More aspects of this (e.g. the IR) seem like they should be automatically generated.
3) It's unclear to me how much of the 'IR' is actually necessary, or whether the MLIR could be created on the fly.

Note that this code is licensed in a way similar to pytorch, with the intention that it should eventually be pushed there once npcomp reaches some maturity (see frontends/pytorch/LICENSE). The code is also structured much closer to the pytorch coding style than the LLVM coding style.
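As a sketch only, the backward-path claim above could be exercised by extending the snippet; the requires_grad flag and calling get_mlir on a gradient tensor are assumptions for illustration, not documented behavior of this patch:

    import torch
    import torch_mlir

    dev = torch_mlir.mlir_device()
    # requires_grad so that pytorch synthesizes a backward path (assumption:
    # autograd works on the virtual device the same way as on cpu/gpu)
    t0 = torch.randn((4, 4), device=dev, requires_grad=True)
    t1 = torch.randn((4, 4), device=dev)
    loss = (t0 + t1).sum()
    loss.backward()
    # Assumption: t0.grad lands on the same mlir device, so get_mlir can
    # retrieve the MLIR for the synthesized backward computation.
    backward_mlir = torch_mlir.get_mlir(t0.grad)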
1 parent 69cda40 · commit 31b3041 · 67 changed files with 36,106 additions and 2 deletions.
frontends/CMakeLists.txt
@@ -0,0 +1,5 @@
if(${TORCH_FOUND})
  add_subdirectory(pytorch)
else()
  message("Skipping pytorch frontend, because PyTorch not found!")
endif()
Empty file.
frontends/pytorch/CMakeLists.txt
@@ -0,0 +1,3 @@
add_subdirectory(lib)
add_subdirectory(csrc)
add_subdirectory(test)
frontends/pytorch/LICENSE
@@ -0,0 +1,65 @@
In order to facilitate future incorporation in pytorch, the code in this
directory (frontends/pytorch) is provided under the below license.

Copyright (c) 2020 LLVM Foundation.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.

3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
and IDIAP Research Institute nor the names of its contributors may be
used to endorse or promote products derived from this software without
specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.

The design of this code is highly inspired by the design of the xla device for
pytorch (git@github.com:pytorch/xla.git). The license for pytorch/xla is:

Copyright (c) 2018 Google Inc.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.

3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
and IDIAP Research Institute nor the names of its contributors may be
used to endorse or promote products derived from this software without
specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
frontends/pytorch/csrc/CMakeLists.txt
@@ -0,0 +1,32 @@
include_directories(
  ${TORCH_INCLUDE_DIRS}
  ${TORCH_INSTALL_PREFIX}/include/TH
  ${TORCH_INSTALL_PREFIX}/include/THC
  /opt/pytorch/pytorch
  ${CMAKE_CURRENT_SOURCE_DIR}
  ${CMAKE_CURRENT_BINARY_DIR}
  ${PYTHON_INCLUDE_DIRS}
)
link_directories("${TORCH_INSTALL_PREFIX}/lib")
add_library(_torch_mlir SHARED
  aten_mlir_bridge.cpp
  aten_mlir_type.cpp
  aten_mlir_type_default.cpp
  device.cpp
  init_python_bindings.cpp
  ir.cpp
  jit.cpp
  mlir_gen.cpp
  tensor.cpp
  tensor_impl.cpp
  torch_util.cpp
)
set_target_properties(_torch_mlir PROPERTIES PREFIX "")

get_property(mlir_libs GLOBAL PROPERTY MLIR_ALL_LIBS)
target_link_libraries(_torch_mlir
  NPCOMPATenDialect
  ${TORCH_LIBRARIES}
  ${mlir_libs}
  ${PYTHON_LIBRARIES}
  torch_python
)
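Since the target above is a SHARED library with an empty prefix, the artifact is presumably an importable Python extension (_torch_mlir.so). A minimal, hypothetical smoke test, assuming the build output directory is on PYTHONPATH:

    import torch        # assumption: torch's shared libraries must load first
    import _torch_mlir  # the extension target defined by the add_library above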
frontends/pytorch/csrc/aten_mlir_bridge.cpp
@@ -0,0 +1,192 @@
//===- aten_mlir_bridge.cpp -------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//

// Structured similarly to code from git@github.com:pytorch/xla.git

#include "aten_mlir_bridge.h"

#include <cassert> // assert() is used throughout this file
#include <map>     // std::map is used for devices_ordinals_
#include <string>
#include <vector>

#include "device.h"
#include "tensor_impl.h"

namespace torch_mlir {
namespace bridge {
namespace {

class AtenMLIRDeviceMapper {
public:
  static AtenMLIRDeviceMapper *Get();

  size_t GetDeviceOrdinal(const Device &device) const {
    auto it = devices_ordinals_.find(device);
    assert(it != devices_ordinals_.end());
    return it->second;
  }

  const Device &GetDeviceFromOrdinal(size_t ordinal) const {
    return devices_.at(ordinal);
  }

private:
  AtenMLIRDeviceMapper() {
    std::vector<std::string> local_devices{"mlir:0", "mlir:1", "mlir:2"};
    for (auto &device_str : local_devices) {
      devices_.emplace_back(device_str);
      devices_ordinals_[devices_.back()] = devices_.size() - 1;
    }
  }

  std::vector<Device> devices_;
  std::map<Device, size_t> devices_ordinals_;
};

AtenMLIRDeviceMapper *AtenMLIRDeviceMapper::Get() {
  static AtenMLIRDeviceMapper *device_mapper = new AtenMLIRDeviceMapper();
  return device_mapper;
}

} // namespace

c10::optional<MLIRTensor> TryGetMLIRTensor(const at::Tensor &tensor) {
  MLIRTensorImpl *impl =
      dynamic_cast<MLIRTensorImpl *>(tensor.unsafeGetTensorImpl());
  if (impl == nullptr) {
    return c10::nullopt;
  }
  return impl->tensor();
}

MLIRTensor GetMLIRTensor(const at::Tensor &tensor) {
  auto xtensor = TryGetMLIRTensor(tensor);
  assert(xtensor && "Input tensor is not an MLIR tensor");
  return *xtensor;
}

MLIRTensor GetOrCreateMLIRTensor(const at::Tensor &tensor,
                                 const Device &device) {
  if (!tensor.defined()) {
    return MLIRTensor();
  }
  auto xtensor = TryGetMLIRTensor(tensor);
  return xtensor ? *xtensor : MLIRTensor::Create(tensor, device);
}

std::vector<at::Tensor> MLIRCreateTensorList(const at::TensorList &tensors) {

  std::vector<at::Tensor> aten_device_tensors(tensors.size());
  std::vector<MLIRTensor> device_tensors;

  std::vector<bool> to_translate(tensors.size());

  for (size_t i = 0; i < tensors.size(); ++i) {
    const at::Tensor &tensor = tensors[i];
    if (tensor.defined()) {
      auto xtensor = TryGetMLIRTensor(tensor);
      if (xtensor) {
        to_translate[i] = true;
        device_tensors.push_back(*xtensor);
      } else {
        aten_device_tensors[i] = tensor;
      }
    }
  }

  for (size_t i = 0, defined_pos = 0; i < tensors.size(); ++i) {
    if (to_translate[i]) {
      aten_device_tensors[i] =
          std::move(device_tensors[defined_pos++].ToTensor());
    }
  }
  return aten_device_tensors;
}

c10::optional<Device> GetMLIRDevice(const at::TensorList &tensors) {
  for (const auto &tensor : tensors) {
    auto device = GetMLIRDevice(tensor);
    if (device) {
      return device;
    }
  }
  return c10::nullopt;
}

c10::optional<Device> GetMLIRDevice(const at::TensorOptions &tensor_options) {
  if (!tensor_options.has_device()) {
    return c10::nullopt;
  }
  return GetMLIRDevice(tensor_options.device());
}

c10::optional<Device> GetMLIRDevice(const c10::Device &device) {
  if (device.type() != at::kXLA) {
    return c10::nullopt;
  }
  return AtenDeviceToMLIRDevice(device);
}

c10::optional<Device> GetMLIRDevice(const at::Tensor &tensor) {
  auto xtensor = TryGetMLIRTensor(tensor);
  if (!xtensor) {
    return c10::nullopt;
  }
  return xtensor->GetDevice();
}

Device AtenDeviceToMLIRDevice(const c10::Device &device) {
  assert(device.type() == at::kXLA);
  int ordinal = device.has_index() ? device.index() : -1;
  if (ordinal < 0) {
    c10::Device current_device = MLIRTensorImpl::GetCurrentAtenDevice();
    if (current_device.has_index()) {
      ordinal = current_device.index();
    }
  }
  if (ordinal < 0) {
    return *GetDefaultDevice();
  }
  return AtenMLIRDeviceMapper::Get()->GetDeviceFromOrdinal(ordinal);
}

c10::Device MLIRDeviceToAtenDevice(const Device &device) {
  // TODO: define our own device and stop hijacking the xla device.
  return c10::Device(at::kXLA,
                     AtenMLIRDeviceMapper::Get()->GetDeviceOrdinal(device));
}

at::Tensor MLIRToAtenTensor(MLIRTensor device_tensor,
                            const at::TensorOptions &tensor_options) {
  if (tensor_options.has_device()) {
    assert(tensor_options.device().type() != at::kXLA);
  }

  at::Tensor tensor = device_tensor.ToTensor();

  // We need to copy the tensor since it is cached within the MLIRTensor, and
  // returning it directly might expose it to in place changes.
  return tensor.to(tensor_options, /*non_blocking=*/false, /*copy=*/true);
}

at::Tensor AtenFromMLIRTensor(MLIRTensor device_tensor) {
  assert(!device_tensor.is_null());
  at::Tensor ret =
      at::Tensor(c10::make_intrusive<MLIRTensorImpl>(std::move(device_tensor)));
  return ret;
}

at::Tensor CreateMLIRTensor(at::Tensor tensor,
                            const c10::optional<Device> &device) {
  if (tensor.defined() && device) {
    MLIRTensor device_tensor = MLIRTensor::Create(std::move(tensor), *device);
    tensor = AtenFromMLIRTensor(device_tensor);
  }
  return tensor;
}

} // namespace bridge
} // namespace torch_mlir
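For orientation, a hypothetical sketch of how the bridge functions above are reached from the Python usage in the commit message; the mapping in the comments is inferred from this file, not documented behavior:

    import torch
    import torch_mlir

    dev = torch_mlir.mlir_device()   # ordinals "mlir:0".."mlir:2" via AtenMLIRDeviceMapper
    t = torch.randn((4, 4)).to(dev)  # CreateMLIRTensor wraps the tensor in an MLIRTensorImpl
    u = t.to('cpu')                  # MLIRToAtenTensor copies out of the cached MLIRTensor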