Address PR comments
antoniojkim committed Sep 29, 2022
1 parent a1d7108 commit 3e4ec62
Showing 8 changed files with 43 additions and 121 deletions.
3 changes: 0 additions & 3 deletions lib/Dialect/Torch/IR/TorchDialect.cpp
@@ -131,9 +131,6 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
                "argument of !torch.tensor/!torch.vtensor type";
     return success();
   }
-  else if (namedAttr.getName().getValue() == "torch.parameter") {
-    return success();
-  }
 
   return op->emitError() << "unknown region arg attribute '"
                          << namedAttr.getName().getValue() << "'";
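Note: with the torch.parameter branch removed, a region argument carrying that attribute now falls through to the unknown-attribute error. A minimal sketch of the resulting verifier control flow, assuming (as in the surrounding upstream code) that torch.type_bound is the only region-arg attribute still recognized; the type-bound checks themselves are elided:

    // Sketch only; signature follows MLIR's Dialect::verifyRegionArgAttribute hook.
    LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
                                                         unsigned regionIndex,
                                                         unsigned argIndex,
                                                         NamedAttribute namedAttr) {
      if (namedAttr.getName().getValue() == "torch.type_bound") {
        // ... existing !torch.tensor/!torch.vtensor type-bound validation ...
        return success();
      }
      // "torch.parameter" is no longer special-cased, so it is rejected here.
      return op->emitError() << "unknown region arg attribute '"
                             << namedAttr.getName().getValue() << "'";
    }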
48 changes: 0 additions & 48 deletions python/test/lazy_backend/parameter_name.py

This file was deleted.

21 changes: 10 additions & 11 deletions python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp
@@ -27,19 +27,24 @@ namespace lazy {
 
 TorchMlirBackendData::TorchMlirBackendData(BackendDevice device, Shape shape)
     : BackendData(device, shape),
-      info_(std::make_unique<TorchMlirBackendData::Info>()) {
+      info_(std::make_shared<TorchMlirBackendData::Info>()) {
   PRINT_FUNCTION();
 }
+TorchMlirBackendData::TorchMlirBackendData(
+    BackendDevice device, Shape shape, std::shared_ptr<BackendData::Info> info)
+    : BackendData(device, shape), info_(info) {
+  PRINT_FUNCTION();
+}
 TorchMlirBackendData::TorchMlirBackendData(
     const at::Scalar& scalar, BackendDevice device)
     : BackendData(device, Shape(scalar.type(), {})),
-      info_(std::make_unique<TorchMlirBackendData::Info>(scalar)) {
+      info_(std::make_shared<TorchMlirBackendData::Info>(scalar)) {
   PRINT_FUNCTION();
 }
 TorchMlirBackendData::TorchMlirBackendData(
     const at::Tensor& tensor, BackendDevice device, Shape shape)
     : BackendData(device, shape),
-      info_(std::make_unique<TorchMlirBackendData::Info>(tensor)) {
+      info_(std::make_shared<TorchMlirBackendData::Info>(tensor)) {
   PRINT_FUNCTION();
 }

@@ -54,18 +59,12 @@ void TorchMlirBackendData::Assign(const BackendData& data) {
       torch_mlir_data,
       "Invalid Backend Data Pointer. Expected TorchMlirBackendData.");
 
-  TorchMlirBackendData::Info* info =
-      dynamic_cast<TorchMlirBackendData::Info*>(torch_mlir_data->mlir_info());
-  TORCH_CHECK(
-      info,
-      "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");
-
-  info_ = std::make_unique<TorchMlirBackendData::Info>(*info);
+  info_ = torch_mlir_data->info_;
 }
 
 bool TorchMlirBackendData::HasValue() const { return bool(info_); }
 
-TorchMlirBackendData::Info* TorchMlirBackendData::mlir_info() const {
+BackendData::Info* TorchMlirBackendData::mlir_info() const {
   return info_.get();
 }

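Note: switching info_ from std::unique_ptr to std::shared_ptr turns Assign from a deep copy into shared ownership, so after an assignment both TorchMlirBackendData objects observe later mutations of the one shared Info (for example, the name written by DeviceData::propagate_name below). A minimal standalone sketch of that aliasing behavior, using simplified stand-in types rather than the real classes:

    #include <cassert>
    #include <memory>
    #include <string>

    // Simplified stand-ins for BackendData::Info and TorchMlirBackendData.
    struct Info { std::string name; };

    struct Data {
      std::shared_ptr<Info> info_ = std::make_shared<Info>();
      void Assign(const Data& other) { info_ = other.info_; }  // share, don't copy
    };

    int main() {
      Data a, b;
      b.Assign(a);
      a.info_->name = "input_0";           // later mutation through `a`...
      assert(b.info_->name == "input_0");  // ...is visible through `b`
    }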
8 changes: 5 additions & 3 deletions python/torch_mlir/csrc/base_lazy_backend/backend_impl.h
@@ -15,6 +15,7 @@
 
 #pragma once
 
+#include <memory>
 #include <sstream>
 
 #include <torch/csrc/lazy/backend/backend_data.h>
@@ -49,6 +50,7 @@ class TORCH_API TorchMlirBackendData : public BackendData {
   };
 
   TorchMlirBackendData(BackendDevice device, Shape shape);
+  TorchMlirBackendData(BackendDevice device, Shape shape, std::shared_ptr<BackendData::Info> info);
   TorchMlirBackendData(const at::Scalar& scalar, BackendDevice device);
   TorchMlirBackendData(
       const at::Tensor& tensor, BackendDevice device, Shape shape);
@@ -59,10 +61,10 @@ class TORCH_API TorchMlirBackendData : public BackendData {
 
   virtual bool HasValue() const override;
 
-  TorchMlirBackendData::Info* mlir_info() const;
+  BackendData::Info* mlir_info() const;
 
-private:
-  std::unique_ptr<TorchMlirBackendData::Info> info_;
+protected:
+  std::shared_ptr<BackendData::Info> info_;
 };
 
 class TORCH_API TorchMlirBackendImpl : public BackendImplInterface {
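Note: since mlir_info() now returns the base BackendData::Info*, callers that need the derived fields (scalar, tensor, name) must recover the concrete type with a checked dynamic_cast, which is exactly what the updated call sites in this commit do. A self-contained sketch of the pattern with simplified mirror types:

    #include <cassert>
    #include <memory>
    #include <optional>
    #include <string>

    // Simplified mirrors of BackendData::Info / TorchMlirBackendData::Info.
    struct BaseInfo { virtual ~BaseInfo() = default; };
    struct DerivedInfo : BaseInfo {
      std::optional<double> scalar;
      std::string name;
    };

    struct Data {
      std::shared_ptr<BaseInfo> info_ = std::make_shared<DerivedInfo>();
      BaseInfo* mlir_info() const { return info_.get(); }  // base pointer out
    };

    int main() {
      Data data;
      // Call-site pattern from this commit: downcast, then validate.
      auto* info = dynamic_cast<DerivedInfo*>(data.mlir_info());
      assert(info && "Expected DerivedInfo");
      info->name = "input_0";
    }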
57 changes: 15 additions & 42 deletions python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp
@@ -15,9 +15,6 @@
 #include <torch/csrc/jit/api/compilation_unit.h>
 #include <torch/csrc/jit/passes/refine_tuple_types.h>
 #include <torch/csrc/lazy/core/lazy_graph_executor.h>
-#include "mlir-c/BuiltinAttributes.h"
-#include "mlir-c/BuiltinTypes.h"
-#include "mlir-c/IR.h"
 #include "torch-mlir-c/Registration.h"
 
 #include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h"
@@ -29,22 +26,6 @@
 #include "utils/string_utils.h"
 #include "utils/sys_utils.h"
 
-namespace {
-
-static MlirStringRef toMlirStringRef(const std::string &s) {
-  return mlirStringRefCreate(s.data(), s.size());
-}
-
-
-inline MlirNamedAttribute toMlirNamedAttribute(const char *s,
-                                               MlirAttribute attr) {
-  MlirContext context = mlirAttributeGetContext(attr);
-  MlirIdentifier ident = mlirIdentifierGet(context, toMlirStringRef(s));
-  return mlirNamedAttributeGet(ident, attr);
-}
-
-}
-
 namespace torch {
 namespace lazy {

@@ -125,8 +106,6 @@ bool TorchMlirLoweringContext::CheckResultShape(
   return false;
 }
 
-
-
 size_t TorchMlirLoweringContext::AddResult(const Output& output) {
   PRINT_FUNCTION();

@@ -160,19 +139,7 @@ ComputationPtr TorchMlirLoweringContext::Build() {
   MlirOperation func_op = torch_mlir::importJitFunctionAsFuncOp(
       /*context=*/mlir_context_,
       /*function=*/generate_jit_fn().get(),
-      /*getArgAttribute=*/[&](int index) -> MlirAttribute {
-        if (parameter_names_.count(index) == 0) {
-          return {nullptr};
-        }
-        MlirNamedAttribute attr = toMlirNamedAttribute(
-            "torch.parameter",
-            mlirStringAttrGet(
-                mlir_context_,
-                toMlirStringRef(parameter_names_[index])
-            )
-        );
-        return mlirDictionaryAttrGet(mlir_context_, 1, &attr);
-      },
+      /*getArgAttribute=*/[](int) -> MlirAttribute { return {nullptr}; },
       /*importOptions=*/{/*assumeTensorsHaveValueSemantics=*/true});
 
   return std::make_shared<TorchMlirComputation>(
@@ -229,7 +196,8 @@ torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) {
   torch::jit::Value* param =
       graph_->addInput(c10::str("p", parameters_.size()));
 
-  auto info = mlir_data->mlir_info();
+  auto* info = dynamic_cast<TorchMlirBackendData::Info*>(mlir_data->mlir_info());
+  TORCH_CHECK(info, "Expected TorchMlirBackendData::Info");
   if (info->scalar.has_value()) {
     auto& scalar = info->scalar.value();
     if (scalar.isFloatingPoint()) {
@@ -325,15 +293,16 @@ void TorchMlirLoweringContext::RegisterMlirDialects() {
 TorchMlirComputation::TorchMlirComputation(
     MlirOperation func_op, MlirContext mlir_context,
     const std::shared_ptr<torch::jit::Graph>& graph,
-    std::unordered_map<int, std::string> parameter_names,
+    std::unordered_map<int, std::string> parameters_map,
     InputOutputAliases input_output_aliases)
     : func_op_(std::move(func_op)), mlir_context_(std::move(mlir_context)),
-      graph_(graph), input_output_aliases_(input_output_aliases) {
+      graph_(graph), input_output_aliases_(input_output_aliases),
+      parameters_map_(parameters_map) {
 
   num_parameters_ = graph_->inputs().size();
 
-  parameter_names_.reserve(parameter_names.size());
-  for (auto kv : parameter_names) {
+  parameter_names_.reserve(parameters_map_.size());
+  for (auto kv : parameters_map_) {
     parameter_names_.emplace_back(kv.second);
   }
 }
@@ -353,6 +322,10 @@ const std::vector<std::string>& TorchMlirComputation::parameter_names() const {
   return parameter_names_;
 }
 
+const std::unordered_map<int, std::string>& TorchMlirComputation::parameters_map() const {
+  return parameters_map_;
+}
+
 const torch::lazy::Shape& TorchMlirComputation::result_shape() const {
   throw std::runtime_error(
       "todo(whc) implement ts computation shapes or change interface");
@@ -363,10 +336,10 @@ std::shared_ptr<torch::jit::Graph> TorchMlirComputation::graph() const {
   return graph_;
 }
 
-MlirOperation& TorchMlirComputation::func_op() { return func_op_; }
+MlirOperation TorchMlirComputation::func_op() const { return func_op_; }
 
-MlirContext* TorchMlirComputation::mlir_context() {
-  return &mlir_context_;
+MlirContext TorchMlirComputation::mlir_context() const {
+  return mlir_context_;
 }
 
 const std::string TorchMlirComputation::debug_string() const {
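Note: returning MlirOperation and MlirContext by value is cheap because the MLIR C API handles are plain structs wrapping a single pointer, which also lets both accessors be const. A usage sketch under that assumption (include paths illustrative):

    #include "mlir-c/IR.h"
    #include "mlir_lowering_context.h"

    // Callers can now take the handle by value from a const computation.
    void dump_computation(const torch::lazy::TorchMlirComputation& computation) {
      MlirOperation op = computation.func_op();  // trivially cheap one-pointer copy
      mlirOperationDump(op);                     // C-API calls take handles by value
    }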
11 changes: 6 additions & 5 deletions python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h
@@ -12,9 +12,7 @@
 
 #pragma once
 
-#include <cstddef>
 #include <unordered_map>
-#include <unordered_set>
 #include <vector>
 
 #include <torch/csrc/api/include/torch/jit.h>
@@ -125,7 +123,7 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
   TorchMlirComputation(
       MlirOperation func_op, MlirContext mlir_context,
       const std::shared_ptr<torch::jit::Graph>& graph,
-      std::unordered_map<int, std::string> parameter_names,
+      std::unordered_map<int, std::string> parameters_map,
       InputOutputAliases input_output_aliases);
 
   int parameters_size() const override;
@@ -134,20 +132,23 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
 
   const std::vector<std::string>& parameter_names() const override;
 
+  const std::unordered_map<int, std::string>& parameters_map() const;
+
   const torch::lazy::Shape& result_shape() const override;
 
   std::shared_ptr<torch::jit::Graph> graph() const;
 
-  MlirOperation& func_op();
+  MlirOperation func_op() const;
 
-  MlirContext* mlir_context();
+  MlirContext mlir_context() const;
 
   virtual const std::string debug_string() const;
 
   virtual const std::string to_string() const override;
 
 protected:
   size_t num_parameters_;
+  std::unordered_map<int, std::string> parameters_map_;
   std::vector<std::string> parameter_names_;
   std::vector<Shape> parameter_shapes_;
   Shape result_shape_;
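Note: TorchMlirComputation now keeps the full index-to-name map alongside the flat parameter_names_ list, so a caller can tell which positional input a name belongs to. A sketch of a hypothetical caller of the new parameters_map() accessor (helper name and include path illustrative):

    #include <iostream>

    #include "mlir_lowering_context.h"

    // Hypothetical helper: print each named parameter with its input index.
    void print_named_parameters(const torch::lazy::TorchMlirComputation& computation) {
      for (const auto& [index, name] : computation.parameters_map()) {
        std::cout << "p" << index << " -> " << name << std::endl;
      }
    }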
2 changes: 1 addition & 1 deletion python/torch_mlir/csrc/base_lazy_backend/ops/device_data.cpp
@@ -23,7 +23,7 @@ void DeviceData::propagate_name() {
   // Add device data name to backend data
   TorchMlirBackendData* mlir_data = dynamic_cast<TorchMlirBackendData*>(data_.get());
   TORCH_CHECK(mlir_data);
-  TorchMlirBackendData::Info* info = mlir_data->mlir_info();
+  auto* info = dynamic_cast<TorchMlirBackendData::Info*>(mlir_data->mlir_info());
   TORCH_CHECK(info);
   info->name = name_;
 }
14 changes: 6 additions & 8 deletions python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp
@@ -19,7 +19,6 @@
 #include <torch_mlir/csrc/base_lazy_backend/utils/debug.h>
 #include <torch_mlir/csrc/base_lazy_backend/utils/exception.h>
 #include <torch_mlir/csrc/base_lazy_backend/utils/string_utils.h>
-#include <torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h>
 
 #include "backend_impl.h"

@@ -28,8 +27,6 @@ using namespace torch::lazy;
 namespace torch {
 namespace lazy {
 
-
-
 struct ReferenceLazyBackendDeviceType : public BackendDeviceType {
   ReferenceLazyBackendDeviceType(c10::DeviceType device_type)
       : device_type_(device_type) {}
@@ -54,7 +51,6 @@ class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl {
     std::cout << "RNG Seed Set to: " << seed << std::endl;
   }
 
-
   /**
    * Lowering, Compilation, Execution
    * */
@@ -112,15 +108,17 @@ class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl {
     for (const auto& argument : arguments) {
       const auto mlir_data =
           std::static_pointer_cast<TorchMlirBackendData>(argument);
-      if (mlir_data->mlir_info()->scalar.has_value()) {
-        stack.emplace_back(mlir_data->mlir_info()->scalar.value());
+      auto* info = dynamic_cast<TorchMlirBackendData::Info*>(mlir_data->mlir_info());
+      TORCH_CHECK(info);
+      if (info->scalar.has_value()) {
+        stack.emplace_back(info->scalar.value());
       } else {
-        at::Tensor tensor = mlir_data->mlir_info()->tensor;
+        at::Tensor tensor = info->tensor;
         stack.emplace_back(tensor);
       }
 
       // count number of inputs
-      auto name = mlir_data->mlir_info()->name;
+      auto name = info->name;
       if (startswith(name, "input_")) {
         // Printing tensor name for testing purposes
         std::cout << "Input tensor: " << name << std::endl;