Miscellaneous API compat fixes v2 (#63)
* implemented cdef for permute()

* passing PyTorch_IntValue then converting to tuple to cast into PyAnyTorchListOfTorchIntValue

* fixed wrong type casting within lambda block

* implemented a vector_norm overload that accepts an int for the dimension

* allowed sum.dim_IntList func to accept a single int as an argument

* softplus func - default values added (beta = 1, threshold = 20)

* added a hand-written softplus op to TorchOps.cpp and removed Torch_AtenSoftplusOp from autogen (added to SKIP_OPS)
brucekimrokcmu authored Jul 13, 2023
1 parent 06ef686 commit fa1c3a2
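
Taken together, these fixes make the Python bindings accept the call shapes PyTorch users expect: a bare int where an int-list is required, varargs for permute, and PyTorch's default beta/threshold for softplus. Below is a minimal, self-contained Python mock of those behaviors; every mock_* name is an illustrative stand-in, not an API from this repository.

# Pure-Python mock of the call shapes this commit enables; the names here
# only model the behavior of the C++ wrappers in the diff below.

def mock_sum(dim=None, keepdim=False, dtype=None):
    # The new C++ overload wraps a bare int in a one-element list
    # (py::make_tuple(dim) -> PyAnyTorchListOfTorchIntValue).
    if isinstance(dim, int):
        dim = [dim]
    return ("sum", dim, keepdim, dtype)

def mock_softplus(beta=1, threshold=20):
    # Defaults now bound on the Python side: "beta"_a = 1, "threshold__"_a = 20.
    return ("softplus", beta, threshold)

def mock_permute(*dims):
    # permute now takes varargs via py::args, like torch.Tensor.permute(2, 0, 1).
    return ("permute", dims)

assert mock_sum(dim=1) == ("sum", [1], False, None)     # bare int accepted
assert mock_softplus() == ("softplus", 1, 20)           # defaults applied
assert mock_permute(2, 0, 1) == ("permute", (2, 0, 1))  # varargs accepted

The same int-to-list promotion is applied to vector_norm and amax in the hunks below.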
Showing 8 changed files with 91 additions and 27 deletions.
cpp_ext/TorchOps.cpp: 73 additions & 0 deletions
@@ -116,6 +116,25 @@ PyAnyTorchListOfTensorValue chunk(const PyAnyTorchTensorValue &self,
return list;
}

// aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)
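// Hand-written rather than autogenerated: Torch_AtenSoftplusOp is in the
// generator's SKIP_OPS so that the pybind wrapper below can supply PyTorch's
// defaults (beta = 1, threshold = 20).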
PyAnyTorchTensorValue softplus(const PyAnyTorchTensorValue &self,
const PyAnyTorchScalarValue &beta,
const PyAnyTorchScalarValue &threshold__,
PyLocation *loc, PyInsertionPoint *ip) {
std::string operationName = "torch.aten.softplus";
std::vector<PyType> _returnTypes = {
PyAnyTorchTensorType::getWithLeastStaticInformation(
loc->getContext().get())};
std::vector<std::reference_wrapper<const PyType>> returnTypes;
for (const auto &returnType : _returnTypes)
returnTypes.push_back(returnType);
PyOperationRef opRef =
createOperation(operationName, returnTypes, {self, beta, threshold__},
/*attributes=*/{}, loc, ip);
MlirOperation operation = opRef->get();
return {opRef, mlirOperationGetResult(operation, 0)};
}

void populateTorchMLIROps(py::module &m) {
py::register_exception_translator([](std::exception_ptr p) {
try {
@@ -210,6 +229,7 @@ void populateTorchMLIROps(py::module &m) {
"dtype"_a = py::none(), py::kw_only(), "loc"_a = py::none(),
"ip"_a = py::none());

// aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)
m.def(
"vector_norm",
[](const PyAnyTorchTensorValue &self, const PyAnyTorchScalarValue &ord,
@@ -224,6 +244,22 @@ void populateTorchMLIROps(py::module &m) {
"dtype"_a = py::none(), py::kw_only(), "loc"_a = py::none(),
"ip"_a = py::none());

// aten::linalg_vector_norm : (Tensor, Scalar, int, bool, int?) -> (Tensor)
m.def(
"vector_norm",
[](const PyAnyTorchTensorValue &self, const PyAnyTorchScalarValue &ord,
const PyTorch_IntValue &dim, const PyTorch_BoolValue &keepdim,
const PyAnyTorchOptionalIntValue &dtype, DefaultingPyLocation &loc,
const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue {
auto dims = PyAnyTorchOptionalListOfTorchIntValue(py::make_tuple(dim));

return linalg_vector_norm(self, ord, dims, keepdim, dtype, loc.get(),
ip.get());
},
"self"_a, "ord"_a = 2, "dim"_a = py::none(), "keepdim"_a = false,
"dtype"_a = py::none(), py::kw_only(), "loc"_a = py::none(),
"ip"_a = py::none());

// aten::chunk : (Tensor, int, int) -> (Tensor[])
m.def(
"chunk",
@@ -234,6 +270,43 @@ void populateTorchMLIROps(py::module &m) {
},
"self"_a, "chunks"_a, "dim"_a = 0, py::kw_only(), "loc"_a = py::none(),
"ip"_a = py::none());

// aten::amax : (Tensor, int, bool) -> (Tensor)
m.def(
"amax",
[](const PyAnyTorchTensorValue &self, const PyTorch_IntValue &dim,
const PyTorch_BoolValue &keepdim, DefaultingPyLocation &loc,
const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue {
auto dims = PyAnyTorchListOfTorchIntValue(py::make_tuple(dim));
return amax(self, dims, keepdim, loc.get(), ip.get());
},
"self"_a, "dim"_a, "keepdim"_a = false, py::kw_only(),
"loc"_a = py::none(), "ip"_a = py::none());

// aten::sum.dim_IntList : (Tensor, int[]?, bool, int?) -> (Tensor)
m.def(
"sum",
[](const PyAnyTorchTensorValue &self, const PyTorch_IntValue &dim,
const PyTorch_BoolValue &keepdim,
const PyAnyTorchOptionalIntValue &dtype, DefaultingPyLocation &loc,
const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue {
auto dims = PyAnyTorchListOfTorchIntValue(py::make_tuple(dim));
return sum(self, dims, keepdim, dtype, loc.get(), ip.get());
},
"self"_a, "dim"_a = py::none(), "keepdim"_a = false,
"dtype"_a = py::none(), py::kw_only(), "loc"_a = py::none(),
"ip"_a = py::none());

// aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)
m.def(
"softplus",
[](const PyAnyTorchTensorValue &self, const PyAnyTorchScalarValue &beta,
const PyAnyTorchScalarValue &threshold__, DefaultingPyLocation &loc,
const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue {
return softplus(self, beta, threshold__, loc.get(), ip.get());
},
"self"_a, "beta"_a = 1, "threshold__"_a = 20, py::kw_only(),
"loc"_a = py::none(), "ip"_a = py::none());
}

} // namespace mlir::torch
cpp_ext/TorchOps.h: 6 additions & 0 deletions
@@ -70,6 +70,12 @@ PyAnyTorchListOfTensorValue chunk(const PyAnyTorchTensorValue &self,
const PyTorch_IntValue &dim, PyLocation *loc,
PyInsertionPoint *ip);

// aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)
PyAnyTorchTensorValue softplus(const PyAnyTorchTensorValue &self,
const PyAnyTorchScalarValue &beta,
const PyAnyTorchScalarValue &threshold__,
PyLocation *loc, PyInsertionPoint *ip);

void populateTorchMLIROps(py::module &m);

} // namespace mlir::torch
cpp_ext/TorchOps.impls.cpp: 0 additions & 16 deletions
@@ -5646,22 +5646,6 @@ PyAnyTorchTensorValue softmax(const PyAnyTorchTensorValue &self, const PyTorch_I
MlirOperation operation = opRef->get();
return {opRef, mlirOperationGetResult(operation, 0)};
}
// aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)
PyAnyTorchTensorValue softplus(const PyAnyTorchTensorValue &self, const PyAnyTorchScalarValue &beta, const PyAnyTorchScalarValue &threshold__, PyLocation *loc, PyInsertionPoint *ip) {
std::string operationName = "torch.aten.softplus";
std::vector<PyType> _returnTypes = {PyAnyTorchTensorType::getWithLeastStaticInformation(loc->getContext().get())};
std::vector<std::reference_wrapper<const PyType>> returnTypes;
for (const auto& returnType : _returnTypes)
returnTypes.push_back(returnType);
PyOperationRef opRef = createOperation(operationName,
returnTypes,
{self, beta, threshold__},
/*attributes=*/{},
loc,
ip);
MlirOperation operation = opRef->get();
return {opRef, mlirOperationGetResult(operation, 0)};
}
// aten::sort.int : (int[], bool) -> ()
void sort(const PyAnyTorchListOfTorchIntValue &self, const PyTorch_BoolValue &reverse, PyLocation *loc, PyInsertionPoint *ip) {
std::string operationName = "torch.aten.sort.int";
cpp_ext/TorchOps.inc.h: 0 additions & 3 deletions
@@ -1058,9 +1058,6 @@ PyAnyTorchTensorValue slice(const PyAnyTorchTensorValue &self, const PyTorch_Int
// aten::softmax.int : (Tensor, int, int?) -> (Tensor)
PyAnyTorchTensorValue softmax(const PyAnyTorchTensorValue &self, const PyTorch_IntValue &dim, const PyAnyTorchOptionalIntValue &dtype, PyLocation *loc, PyInsertionPoint *ip);

// aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)
PyAnyTorchTensorValue softplus(const PyAnyTorchTensorValue &self, const PyAnyTorchScalarValue &beta, const PyAnyTorchScalarValue &threshold__, PyLocation *loc, PyInsertionPoint *ip);

// aten::sort.int : (int[], bool) -> ()
void sort(const PyAnyTorchListOfTorchIntValue &self, const PyTorch_BoolValue &reverse, PyLocation *loc, PyInsertionPoint *ip);

cpp_ext/TorchOps.pybinds.cpp: 0 additions & 3 deletions
@@ -1058,9 +1058,6 @@ m.def("slice", [](const PyAnyTorchTensorValue &self, const PyTorch_IntValue &dim
// aten::softmax.int : (Tensor, int, int?) -> (Tensor)
m.def("softmax", [](const PyAnyTorchTensorValue &self, const PyTorch_IntValue &dim, const PyAnyTorchOptionalIntValue &dtype, DefaultingPyLocation &loc, const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue { return softmax(self, dim, dtype, loc.get(), ip.get()); }, "self"_a, "dim"_a, "dtype"_a = py::none(), py::kw_only(), "loc"_a = py::none(), "ip"_a = py::none());

// aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)
m.def("softplus", [](const PyAnyTorchTensorValue &self, const PyAnyTorchScalarValue &beta, const PyAnyTorchScalarValue &threshold__, DefaultingPyLocation &loc, const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue { return softplus(self, beta, threshold__, loc.get(), ip.get()); }, "self"_a, "beta"_a = 1, "threshold__"_a, py::kw_only(), "loc"_a = py::none(), "ip"_a = py::none());

// aten::sort.int : (int[], bool) -> ()
m.def("sort", [](const PyAnyTorchListOfTorchIntValue &self, const PyTorch_BoolValue &reverse, DefaultingPyLocation &loc, const DefaultingPyInsertionPoint &ip) -> void { return sort(self, reverse, loc.get(), ip.get()); }, "self"_a, "reverse"_a = false, py::kw_only(), "loc"_a = py::none(), "ip"_a = py::none());

cpp_ext/TorchTensor.cpp: 10 additions & 0 deletions
@@ -405,6 +405,16 @@ void PyAnyTorchTensorValue::bindDerived(ClassTy &c) {
"memory_format"_a = py::none(), py::kw_only(), "loc"_a = py::none(),
"ip"_a = py::none());

c.def(
"permute",
[](const PyAnyTorchTensorValue &self, const py::args &dims,
DefaultingPyLocation &loc,
const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue {
return permute(self, PyAnyTorchListOfTorchIntValue(dims), loc.get(),
ip.get());
},
py::kw_only(), "loc"_a = py::none(), "ip"_a = py::none());

#include "TorchTensor.pybinds.cpp"
}

cpp_ext/TorchTensor.pybinds.cpp: 0 additions & 4 deletions
@@ -1752,10 +1752,6 @@ c.def("ormqr", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs)
// outer(self, vec2: Tensor) -> Tensor
c.def("outer", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs) { throw NotImplementedError("NotImplementedError: outer with signature outer(self, vec2: Tensor) -> Tensor"); });

// @overload permute(self, dims: _size) -> Tensor
// aten::permute : (Tensor, int[]) -> (Tensor)
c.def("permute", [](const PyAnyTorchTensorValue &self, const PyAnyTorchListOfTorchIntValue &dims, DefaultingPyLocation &loc, const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue { return permute(self, dims, loc.get(), ip.get()); }, "dims"_a, py::kw_only(), "loc"_a = py::none(), "ip"_a = py::none());

// pin_memory(self, device: Optional[Union[_device, str, None]]=None) -> Tensor
c.def("pin_memory", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs) { throw NotImplementedError("NotImplementedError: pin_memory with signature pin_memory(self, device: Optional[Union[_device, str, None]]=None) -> Tensor"); });

@@ -53,7 +53,7 @@ def get_clean_name(name):
"AnyTorchType",
"anonymous_430",
}
SKIP_OPS = {"Torch_PrimsSqrtOp", "Torch_AtenChunkOp"}
SKIP_OPS = {"Torch_PrimsSqrtOp", "Torch_AtenChunkOp", "Torch_AtenSoftplusOp",}
SKIP_TENSOR_BINDS = {
"@overload view(self, dtype: _dtype) -> Tensor",
"@overload view(self, size: Sequence[Union[_int, SymInt]]) -> Tensor",
@@ -63,6 +63,7 @@ def get_clean_name(name):
"chunk(self, chunks: _int, dim: _int=0) -> List[Tensor]",
"__getitem__(self, indices: Union[None, _int, slice, Tensor, List, Tuple]) -> Tensor",
"double(self) -> Tensor",
"@overload permute(self, dims: _size) -> Tensor",
}

TORCH_OPS_IMPL_CPP = "TorchOps.impls.cpp"
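
Adding Torch_AtenSoftplusOp to SKIP_OPS (and the permute overload to SKIP_TENSOR_BINDS) keeps the generator from re-emitting bindings that would shadow the hand-written ones above. A hypothetical sketch of how such a filter is typically consulted; the generator's actual emission loop is not shown in this diff:

from typing import Optional

SKIP_OPS = {"Torch_PrimsSqrtOp", "Torch_AtenChunkOp", "Torch_AtenSoftplusOp"}

def emit_op_binding(op_name: str) -> Optional[str]:
    # Ops in the skip set get no autogenerated binding, leaving the
    # hand-written overloads in TorchOps.cpp as the only definitions.
    if op_name in SKIP_OPS:
        return None
    return f"// autogen binding for {op_name}"

assert emit_op_binding("Torch_AtenSoftplusOp") is None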