Propagate parameter names to TorchMlirComputation (llvm#1420)
* Propagate parameter names to MLIR

* Add a TorchMlirNode constructor hook

* Make func_op mutable

  - The purpose of this is to allow a subclass backend to modify func_op.

* Clean up unnecessary changes

* Remove an unnecessary attribute case

* Address PR comments
antoniojkim authored Sep 29, 2022
1 parent 8f608c0 commit fa5a8e2
Showing 12 changed files with 160 additions and 50 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -45,8 +45,10 @@ option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)
 
 if(TORCH_MLIR_ENABLE_LTC)
   set(ENV{TORCH_MLIR_ENABLE_LTC} 1)
+  message(STATUS "LTC Backend build is enabled")
 else()
   set(ENV{TORCH_MLIR_ENABLE_LTC} 0)
+  message(STATUS "LTC Backend build is disabled")
 endif()
 
 torch_mlir_add_llvm_external_project(
1 change: 1 addition & 0 deletions python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt
@@ -71,6 +71,7 @@ add_library(torch_mlir_ltc_backend SHARED
   mlir_node.cpp
   ops/device_data.cpp
   ops/generic.cpp
+  utils/tensor_utils.cpp
 )
 target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17)
21 changes: 10 additions & 11 deletions python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp
@@ -27,19 +27,24 @@ namespace lazy {
 
 TorchMlirBackendData::TorchMlirBackendData(BackendDevice device, Shape shape)
     : BackendData(device, shape),
-      info_(std::make_unique<TorchMlirBackendData::Info>()) {
+      info_(std::make_shared<TorchMlirBackendData::Info>()) {
   PRINT_FUNCTION();
 }
+TorchMlirBackendData::TorchMlirBackendData(
+    BackendDevice device, Shape shape, std::shared_ptr<BackendData::Info> info)
+    : BackendData(device, shape), info_(info) {
+  PRINT_FUNCTION();
+}
 TorchMlirBackendData::TorchMlirBackendData(
     const at::Scalar& scalar, BackendDevice device)
     : BackendData(device, Shape(scalar.type(), {})),
-      info_(std::make_unique<TorchMlirBackendData::Info>(scalar)) {
+      info_(std::make_shared<TorchMlirBackendData::Info>(scalar)) {
   PRINT_FUNCTION();
 }
 TorchMlirBackendData::TorchMlirBackendData(
     const at::Tensor& tensor, BackendDevice device, Shape shape)
     : BackendData(device, shape),
-      info_(std::make_unique<TorchMlirBackendData::Info>(tensor)) {
+      info_(std::make_shared<TorchMlirBackendData::Info>(tensor)) {
   PRINT_FUNCTION();
 }
@@ -54,18 +59,12 @@ void TorchMlirBackendData::Assign(const BackendData& data) {
       torch_mlir_data,
       "Invalid Backend Data Pointer. Expected TorchMlirBackendData.");
 
-  TorchMlirBackendData::Info* info =
-      dynamic_cast<TorchMlirBackendData::Info*>(torch_mlir_data->mlir_info());
-  TORCH_CHECK(
-      info,
-      "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");
-
-  info_ = std::make_unique<TorchMlirBackendData::Info>(*info);
+  info_ = torch_mlir_data->info_;
 }
 
 bool TorchMlirBackendData::HasValue() const { return bool(info_); }
 
-TorchMlirBackendData::Info* TorchMlirBackendData::mlir_info() const {
+BackendData::Info* TorchMlirBackendData::mlir_info() const {
   return info_.get();
 }
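Note that with info_ now a std::shared_ptr, Assign aliases the source's Info rather than deep-copying it, so a later rename is visible through both objects. A minimal sketch of that behavior, assuming backend_impl.h is included and BackendDevice is default-constructible (everything here is illustrative, not part of the commit):

    // Sketch: Assign now shares one Info between two TorchMlirBackendData objects.
    void assign_shares_info_sketch() {
      torch::lazy::BackendDevice device;                    // assumed default device
      torch::lazy::Shape shape(c10::ScalarType::Float, {}); // rank-0 float shape

      torch::lazy::TorchMlirBackendData a(device, shape);
      torch::lazy::TorchMlirBackendData b(device, shape);
      b.Assign(a); // b's info_ now points at the same Info object as a's

      auto* info =
          dynamic_cast<torch::lazy::TorchMlirBackendData::Info*>(b.mlir_info());
      info->name = "renamed";
      // The rename is observable through `a` as well, since the Info is shared.
    }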
24 changes: 13 additions & 11 deletions python/torch_mlir/csrc/base_lazy_backend/backend_impl.h
@@ -15,6 +15,7 @@
 
 #pragma once
 
+#include <memory>
 #include <sstream>
 
 #include <torch/csrc/lazy/backend/backend_data.h>
@@ -33,22 +34,23 @@ class TORCH_API TorchMlirBackendData : public BackendData {
     bool requires_grad;
     std::string name;
 
-    Info() {}
+    Info() {
+      static int i = 0;
+      std::stringstream ss;
+      ss << "placeholder" << i;
+      name = ss.str();
+      ++i;
+    }
     Info(const Info& other)
         : tensor{other.tensor}, scalar{other.scalar},
           requires_grad{other.requires_grad}, name{other.name} {}
     Info(const at::Tensor& tensor)
-        : tensor{tensor}, requires_grad{tensor.requires_grad()} {
-      static int num_tensors = 0;
-      std::ostringstream oss;
-      oss << "tensor" << num_tensors;
-      this->name = oss.str();
-      ++num_tensors;
-    }
+        : tensor{tensor}, requires_grad{tensor.requires_grad()} {}
    Info(const at::Scalar& scalar) : scalar{scalar}, requires_grad(false) {}
   };
 
   TorchMlirBackendData(BackendDevice device, Shape shape);
+  TorchMlirBackendData(BackendDevice device, Shape shape, std::shared_ptr<BackendData::Info> info);
   TorchMlirBackendData(const at::Scalar& scalar, BackendDevice device);
   TorchMlirBackendData(
       const at::Tensor& tensor, BackendDevice device, Shape shape);
@@ -59,10 +61,10 @@ class TORCH_API TorchMlirBackendData : public BackendData {
 
   virtual bool HasValue() const override;
 
-  TorchMlirBackendData::Info* mlir_info() const;
+  BackendData::Info* mlir_info() const;
 
-private:
-  std::unique_ptr<TorchMlirBackendData::Info> info_;
+protected:
+  std::shared_ptr<BackendData::Info> info_;
 };
 
 class TORCH_API TorchMlirBackendImpl : public BackendImplInterface {
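The new three-argument constructor pairs with the now-protected shared_ptr member: a subclass backend can build backend data around an Info it already owns and keep mutating that Info afterwards. A minimal sketch, assuming a default-constructible BackendDevice and an illustrative name and shape:

    // Sketch: attach a pre-named, shared Info to freshly created backend data.
    std::shared_ptr<torch::lazy::BackendData> make_named_data() {
      auto info = std::make_shared<torch::lazy::TorchMlirBackendData::Info>();
      info->name = "model.weight"; // hypothetical parameter name

      torch::lazy::BackendDevice device;                        // assumed default device
      torch::lazy::Shape shape(c10::ScalarType::Float, {2, 3}); // illustrative shape
      // `info` stays shared: renaming it later is visible via mlir_info().
      return std::make_shared<torch::lazy::TorchMlirBackendData>(device, shape, info);
    }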
42 changes: 35 additions & 7 deletions python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp
@@ -15,14 +15,16 @@
 #include <torch/csrc/jit/api/compilation_unit.h>
 #include <torch/csrc/jit/passes/refine_tuple_types.h>
 #include <torch/csrc/lazy/core/lazy_graph_executor.h>
-#include "torch-mlir-c/Registration.h"
 
 #include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h"
 #include "backend_impl.h"
 #include "mlir_lowering_context.h"
 #include "mlir_node.h"
+#include "torch-mlir-c/Registration.h"
 #include "utils/debug.h"
 #include "utils/exception.h"
+#include "utils/string_utils.h"
+#include "utils/sys_utils.h"
 
 namespace torch {
 namespace lazy {
@@ -141,7 +143,7 @@ ComputationPtr TorchMlirLoweringContext::Build() {
       /*importOptions=*/{/*assumeTensorsHaveValueSemantics=*/true});
 
   return std::make_shared<TorchMlirComputation>(
-      func_op, mlir_context_, graph_, input_output_aliases_);
+      func_op, mlir_context_, graph_, parameter_names_, input_output_aliases_);
 }
 
 torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) {
@@ -194,7 +196,8 @@ torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) {
     torch::jit::Value* param =
         graph_->addInput(c10::str("p", parameters_.size()));
 
-    auto info = mlir_data->mlir_info();
+    auto* info = dynamic_cast<TorchMlirBackendData::Info*>(mlir_data->mlir_info());
+    TORCH_CHECK(info, "Expected TorchMlirBackendData::Info");
     if (info->scalar.has_value()) {
       auto& scalar = info->scalar.value();
       if (scalar.isFloatingPoint()) {
@@ -213,6 +216,10 @@ torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) {
         /*sizes=*/c10::VaryingShape<int64_t>(data->shape().sizes()),
         /*strides=*/c10::VaryingShape<int64_t>(),
         /*requires_grad=*/c10::nullopt));
+
+    if (info->name != "" && !startswith(info->name, "input")) {
+      parameter_names_[parameters_.size()] = info->name;
+    }
   }
 
   it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
@@ -286,16 +293,22 @@ void TorchMlirLoweringContext::RegisterMlirDialects() {
 TorchMlirComputation::TorchMlirComputation(
     MlirOperation func_op, MlirContext mlir_context,
     const std::shared_ptr<torch::jit::Graph>& graph,
+    std::unordered_map<int, std::string> parameters_map,
     InputOutputAliases input_output_aliases)
     : func_op_(std::move(func_op)), mlir_context_(std::move(mlir_context)),
-      graph_(graph), input_output_aliases_(input_output_aliases) {
-  for (torch::jit::Value* input : graph_->inputs()) {
-    parameter_names_.push_back(input->debugName());
+      graph_(graph), input_output_aliases_(input_output_aliases),
+      parameters_map_(parameters_map) {
+
+  num_parameters_ = graph_->inputs().size();
+
+  parameter_names_.reserve(parameters_map_.size());
+  for (auto kv : parameters_map_) {
+    parameter_names_.emplace_back(kv.second);
   }
 }
 
 int TorchMlirComputation::parameters_size() const {
-  return parameter_names_.size();
+  return num_parameters_;
 }
 
 const std::vector<torch::lazy::Shape>&
@@ -309,6 +322,10 @@ const std::vector<std::string>& TorchMlirComputation::parameter_names() const {
   return parameter_names_;
 }
 
+const std::unordered_map<int, std::string>& TorchMlirComputation::parameters_map() const {
+  return parameters_map_;
+}
+
 const torch::lazy::Shape& TorchMlirComputation::result_shape() const {
   throw std::runtime_error(
       "todo(whc) implement ts computation shapes or change interface");
@@ -321,6 +338,10 @@ std::shared_ptr<torch::jit::Graph> TorchMlirComputation::graph() const {
 
 MlirOperation TorchMlirComputation::func_op() const { return func_op_; }
 
+MlirContext TorchMlirComputation::mlir_context() const {
+  return mlir_context_;
+}
+
 const std::string TorchMlirComputation::debug_string() const {
   std::stringstream ss;
 
@@ -330,6 +351,13 @@ const std::string TorchMlirComputation::debug_string() const {
   // MLIR
   ss << "MLIR: \n" << to_string() << "\n";
 
+  // Parameter names
+  ss << "Parameter names:\n";
+  for (auto& p : parameter_names_) {
+    ss << " " << p << "\n";
+  }
+  ss << "\n";
+
   // Input/Output Mapping
   ss << "Input/Output Alias Mapping: \n";
   for (InputOutputAlias input_output_alias : input_output_aliases_) {
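Taken together, these hunks carry a parameter's Info name through lowering: GetParameter records any non-empty name that does not start with "input" into parameter_names_, keyed by parameter index, and Build hands that map to the TorchMlirComputation. A hedged sketch of how a consumer might read the result back (the printing helper is illustrative, not part of the commit):

    // Sketch: reading propagated parameter names off a built computation.
    #include <iostream>

    void dump_parameter_names(
        const std::shared_ptr<torch::lazy::TorchMlirComputation>& computation) {
      // Index -> name, for the parameters that carried an explicit name.
      for (const auto& [index, name] : computation->parameters_map()) {
        std::cout << "p" << index << " = " << name << "\n";
      }
    }

Note that parameters_size() now counts all graph inputs while parameter_names() holds only the named subset, and because parameters_map_ is a std::unordered_map, the order of parameter_names_ is unspecified; a consumer that needs names aligned to parameter indices should prefer parameters_map().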
python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h
@@ -12,6 +12,7 @@
 
 #pragma once
 
+#include <unordered_map>
 #include <vector>
 
 #include <torch/csrc/api/include/torch/jit.h>
@@ -109,6 +110,7 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
   std::shared_ptr<torch::jit::GraphFunction> function_;
   MlirContext mlir_context_;
   std::unordered_map<BackendData::Handle, Parameter> parameters_map_;
+  std::unordered_map<int, std::string> parameter_names_;
   std::vector<torch::jit::Value*> root_tuple_;
   OutputMap<torch::jit::Value*> emitted_outputs_;
 };
@@ -121,6 +123,7 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
   TorchMlirComputation(
       MlirOperation func_op, MlirContext mlir_context,
       const std::shared_ptr<torch::jit::Graph>& graph,
+      std::unordered_map<int, std::string> parameters_map,
       InputOutputAliases input_output_aliases);
 
   int parameters_size() const override;
@@ -129,17 +132,23 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
 
   const std::vector<std::string>& parameter_names() const override;
 
+  const std::unordered_map<int, std::string>& parameters_map() const;
+
   const torch::lazy::Shape& result_shape() const override;
 
   std::shared_ptr<torch::jit::Graph> graph() const;
 
   MlirOperation func_op() const;
 
+  MlirContext mlir_context() const;
+
   virtual const std::string debug_string() const;
 
   virtual const std::string to_string() const override;
 
 protected:
+  size_t num_parameters_;
+  std::unordered_map<int, std::string> parameters_map_;
   std::vector<std::string> parameter_names_;
   std::vector<Shape> parameter_shapes_;
   Shape result_shape_;
20 changes: 20 additions & 0 deletions python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp
@@ -38,6 +38,13 @@ hash_t OperandHashes(
 
 } // namespace
 
+
+// Adds a static hook that is run after every single TorchMlirNode is initialized
+static std::vector<std::function<void(TorchMlirNode*)>> constructor_hooks;
+void TorchMlirNode::addConstructorHook(std::function<void(TorchMlirNode*)> f) {
+  constructor_hooks.emplace_back(f);
+}
+
 TorchMlirNode::TorchMlirNode(
     OpKind op, OpList operands, std::vector<Shape>&& shapes, size_t num_outputs,
     hash_t hash_seed)
@@ -48,6 +55,10 @@ TorchMlirNode::TorchMlirNode(
       (enableDynamicShape()
            ? OperandHashes(operands, this->shapes(), hash_seed, false)
            : shape_hash_);
+
+  for (std::function<void(TorchMlirNode*)>& f : constructor_hooks) {
+    f(this);
+  }
 }
 
 TorchMlirNode::TorchMlirNode(
@@ -71,6 +82,15 @@ hash_t TorchMlirNode::hash() const { return dag_hash_; }
 
 hash_t TorchMlirNode::shapeHash() const { return shape_hash_; }
 
+
+TorchMlirNode* TorchMlirNode::mlir_node(int index) {
+  return dynamic_cast<TorchMlirNode*>(operands_.at(index).get());
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// TorchMlirTensorList
+///////////////////////////////////////////////////////////////////////////////
+
 OpKind TorchMlirTensorList::ClassOpKind() {
   // Note: this OpKind is separate from ltc_ops.h since it would be a circular
   // import otherwise
5 changes: 5 additions & 0 deletions python/torch_mlir/csrc/base_lazy_backend/mlir_node.h
@@ -42,12 +42,17 @@ class TORCH_API TorchMlirNode : public torch::lazy::Node {
   TorchMlirNode(
       OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed = kHashSeed);
 
+  // Adds a static hook that is run after every single TorchMlirNode is constructed
+  static void addConstructorHook(std::function<void(TorchMlirNode*)>);
+
   ~TorchMlirNode() override = default;
 
   hash_t hash() const override;
 
   hash_t shapeHash() const override;
 
+  TorchMlirNode* mlir_node(int index);
+
   virtual TorchMlirOpVector
   Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const;
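A hedged sketch of how a downstream backend might use the new hook, for example to log every node as it is traced; the hooks run from the TorchMlirNode constructor shown above, so they see a fully initialized node (the logging itself is illustrative):

    // Sketch: register a constructor hook once, e.g. at backend initialization.
    #include <iostream>

    void install_node_logger() {
      torch::lazy::TorchMlirNode::addConstructorHook(
          [](torch::lazy::TorchMlirNode* node) {
            std::cout << "created node: " << node->op().ToString() << "\n";
          });
    }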
python/torch_mlir/csrc/base_lazy_backend/ops/device_data.cpp
@@ -23,7 +23,7 @@ void DeviceData::propagate_name() {
   // Add device data name to backend data
   TorchMlirBackendData* mlir_data = dynamic_cast<TorchMlirBackendData*>(data_.get());
   TORCH_CHECK(mlir_data);
-  TorchMlirBackendData::Info* info = mlir_data->mlir_info();
+  auto* info = dynamic_cast<TorchMlirBackendData::Info*>(mlir_data->mlir_info());
   TORCH_CHECK(info);
   info->name = name_;
 }
47 changes: 47 additions & 0 deletions python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.cpp
@@ -0,0 +1,47 @@
+#include "tensor_utils.h"
+
+#include "../generated/LazyIr.h"
+#include "../mlir_node.h"
+
+
+namespace torch {
+namespace lazy {
+
+bool is_detach_copy(const torch::lazy::Value& value) {
+  return value->op() == torch::lazy::DetachCopy::ClassOpKind();
+}
+
+torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value) {
+  if (!value) {
+    return nullptr;
+  }
+  torch::lazy::TorchMlirNode* node =
+      dynamic_cast<torch::lazy::TorchMlirNode*>(value.node.get());
+  while (node) {
+    if (node->op() == torch::lazy::DeviceData::ClassOpKind()) {
+      return dynamic_cast<torch::lazy::DeviceData*>(node);
+    } else if (node->op() == torch::lazy::DetachCopy::ClassOpKind()) {
+      node = node->mlir_node(0);
+    } else {
+      break;
+    }
+  }
+  return nullptr;
+}
+
+torch::lazy::DeviceData* device_data_cast(
+    const at::Tensor& tensor,
+    c10::optional<torch::lazy::BackendDevice> device) {
+  if (!device) {
+    device = torch::lazy::GetBackendDevice(tensor);
+  }
+  TORCH_CHECK(device);
+  torch::lazy::LazyTensorPtr lazy_tensor =
+      torch::lazy::GetLtcTensorOrCreateForWrappedNumber(tensor, *device);
+  if (lazy_tensor) {
+    return device_data_cast(lazy_tensor->GetIrValue());
+  }
+  return nullptr;
+}
+
+} // namespace lazy
+} // namespace torch
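A hedged usage sketch: device_data_cast unwraps any chain of DetachCopy nodes via the new TorchMlirNode::mlir_node accessor until it reaches the underlying DeviceData, so a backend can rename the data behind a plain at::Tensor. The data() accessor and the name below are assumptions for illustration:

    // Sketch: rename the backend data behind a lazy tensor so the name
    // later survives into the lowered computation via GetParameter.
    void name_parameter(const at::Tensor& tensor) {
      torch::lazy::DeviceData* device_data =
          torch::lazy::device_data_cast(tensor, c10::nullopt);
      if (!device_data) {
        return; // tensor is not (yet) backed by a DeviceData node
      }
      auto* mlir_data = dynamic_cast<torch::lazy::TorchMlirBackendData*>(
          device_data->data().get()); // data() assumed on torch::lazy::DeviceData
      auto* info = mlir_data
          ? dynamic_cast<torch::lazy::TorchMlirBackendData::Info*>(
                mlir_data->mlir_info())
          : nullptr;
      if (info) {
        info->name = "model.bias"; // hypothetical parameter name
      }
    }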