Add !torch.tuple<T1, T2> type.
This further eliminates the need for the `basicpy` dependency.

This required adding `torch.prim.TupleConstruct` to replace
`basicpy.build_tuple`.
silvasean committed Jun 15, 2021 · 1 parent ea1dd1c · commit 92ee0fa
Showing 13 changed files with 135 additions and 34 deletions.
2 changes: 1 addition & 1 deletion external/llvm-project
Submodule llvm-project updated 1652 files
13 changes: 8 additions & 5 deletions frontends/pytorch/csrc/builder/ivalue_importer.cpp
@@ -279,13 +279,16 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
}
if (ivalue.isTuple()) {
auto list = ivalue.toTuple()->elements();
-std::vector<MlirValue> elems;
+std::vector<MlirValue> operands;
+std::vector<MlirType> types;
for (const c10::IValue &elem : list) {
-elems.push_back(importIValue(elem));
+MlirValue operand = importIValue(elem);
+operands.push_back(operand);
+types.push_back(mlirValueGetType(operand));
}
-MlirOperation operation =
-createMlirOperationAtEnd(importBlock, "basicpy.build_tuple", loc,
-npcompBasicpyTupleTypeGet(context), elems);
+MlirOperation operation = createMlirOperationAtEnd(
+importBlock, "torch.prim.TupleConstruct", loc,
+npcompTorchTupleTypeGet(context, types.size(), types.data()), operands);
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isTensor()) {
10 changes: 2 additions & 8 deletions frontends/pytorch/csrc/builder/node_importer.cpp
@@ -84,9 +84,11 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
switch (kind) {
case c10::prim::ListUnpack:
case c10::prim::ListConstruct:
case c10::prim::TupleConstruct: {
createAndMapTrivialNode(node,
"torch.prim." + std::string(kind.toUnqualString()));
return;
}
case c10::prim::GetAttr:
case c10::prim::SetAttr: {
createAndMapNodeWithAttribute(
@@ -96,14 +98,6 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
}
}

-// Ops trivially lowered through `basicpy` dialect.
-switch (kind) {
-case c10::prim::TupleConstruct: {
-createAndMapTrivialNode(node, "basicpy.build_tuple");
-return;
-}
-}

if (kind == c10::prim::Constant) {
auto output = node->output();
MlirOperation op;
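With this change, `prim::TupleConstruct` is routed through the same generic path as `ListConstruct` and `ListUnpack`: the MLIR op name is derived mechanically from the JIT node kind. A tiny self-contained sketch of that naming step (the `torchPrimOpName` helper is hypothetical; the real work happens inside `createAndMapTrivialNode`, which is not shown in this diff):

#include <string>
#include <ATen/core/interned_strings.h>

// Illustrative only: derive the MLIR op name from a TorchScript node kind,
// e.g. c10::prim::TupleConstruct -> "torch.prim.TupleConstruct".
static std::string torchPrimOpName(c10::Symbol kind) {
  return "torch.prim." + std::string(kind.toUnqualString());
}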
9 changes: 7 additions & 2 deletions frontends/pytorch/csrc/builder/torch_to_mlir_utils.cpp
@@ -185,8 +185,13 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
loc, torchType->cast<c10::ListType>()->getElementType()));
}
case TypeKind::TupleType: {
-// TODO: Don't lose the element type information.
-return npcompBasicpyTupleTypeGet(context);
+std::vector<MlirType> containedTypes;
+for (const c10::TypePtr &type :
+torchType->cast<c10::TupleType>()->containedTypes()) {
+containedTypes.push_back(mapFromTorchType(loc, type));
+}
+return npcompTorchTupleTypeGet(context, containedTypes.size(),
+containedTypes.data());
}
case TypeKind::StringType: {
return npcompBasicpyBytesTypeGet(context);
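Because the new `TupleType` case recurses through `mapFromTorchType`, nested tuple annotations keep their structure instead of collapsing to an opaque `!basicpy.TupleType`. A small sketch (the helper is hypothetical, not part of the patch) of the kind of c10 type this now handles, with the mapping this patch is expected to produce noted in a comment:

#include <ATen/core/jit_type.h>

// Hypothetical example: the c10 type for Tuple[int, Tuple[Tensor, Tensor]].
static c10::TypePtr makeNestedTupleType() {
  c10::TypePtr inner = c10::TupleType::create(
      {c10::TensorType::get(), c10::TensorType::get()});
  // With this patch, TypeMapper::mapFromTorchType maps this to
  //   !torch.tuple<i64, !torch.tuple<!torch.tensor, !torch.tensor>>
  return c10::TupleType::create({c10::IntType::get(), inner});
}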
4 changes: 2 additions & 2 deletions frontends/pytorch/test/ivalue_import/tuple.py
@@ -20,9 +20,9 @@ def __init__(self):
# CHECK: }
# CHECK: %[[N1:.*]] = basicpy.numeric_constant 1 : i64
# CHECK: %[[N2:.*]] = basicpy.numeric_constant 2 : i64
-# CHECK: %[[TUPLE:.*]] = basicpy.build_tuple %[[N1]], %[[N2]] : (i64, i64) -> !basicpy.TupleType
+# CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[N1]], %[[N2]] : i64, i64
# CHECK: torch.nn_module {
-# CHECK: torch.slot "t", %[[TUPLE]] : !basicpy.TupleType
+# CHECK: torch.slot "t", %[[TUPLE]] : !torch.tuple<i64, i64>
# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">


24 changes: 12 additions & 12 deletions frontends/pytorch/test/node_import/prim.py
@@ -65,8 +65,8 @@ def prim_unchecked_cast(i: typing.Optional[int]):
return i

# CHECK-LABEL: func @__torch__.prim_TupleUnpack(
-# CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 {
-# CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !basicpy.TupleType -> i64, i64
+# CHECK-SAME: %[[ARG:.*]]: !torch.tuple<i64, i64>) -> i64 {
+# CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !torch.tuple<i64, i64> -> i64, i64
# CHECK: return %[[RET]]#0 : i64
@mb.import_function
@torch.jit.script
@@ -75,12 +75,12 @@ def prim_TupleUnpack(tup: typing.Tuple[int, int]):
return val

# CHECK-LABEL: func @__torch__.prim_TupleIndex(
-# CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 {
-# CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[ARG]], %[[IDX:.*]] : !basicpy.TupleType, i64 -> i64
-# CHECK: return %[[RET]] : i64
+# CHECK-SAME: %[[ARG:.*]]: !torch.tuple<!torch.tensor, !torch.tensor>) -> !torch.tensor {
+# CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[ARG]], %[[IDX:.*]] : !torch.tuple<!torch.tensor, !torch.tensor>, i64 -> !torch.tensor
+# CHECK: return %[[RET]] : !torch.tensor
@mb.import_function
@torch.jit.script
-def prim_TupleIndex(tup: typing.Tuple[int, int]):
+def prim_TupleIndex(tup: typing.Tuple[torch.Tensor, torch.Tensor]):
return tup[0]

# CHECK-LABEL: func @__torch__.prim_ListUnpack(
@@ -121,28 +121,28 @@ def prim_device(x):
return x.device

# CHECK-LABEL: func @__torch__.prim_min(
-# CHECK-SAME: %[[ARG:.*]]: i64) -> !basicpy.TupleType {
+# CHECK-SAME: %[[ARG:.*]]: i64) -> !torch.tuple<i64, i64, i64> {
# CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (i64) -> !torch.list<i64>
# CHECK: %[[MIN1:.*]] = torch.prim.min.self_int %[[SINGLETON]] : !torch.list<i64> -> i64
# CHECK: %[[MIN2:.*]] = torch.prim.min.int %[[ARG]], %[[ARG]] : i64, i64 -> i64
# CHECK: %[[ARG_3_TIMES:.*]] = torch.prim.ListConstruct %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !torch.list<i64>
# CHECK: %[[MIN3:.*]] = torch.prim.min.self_int %[[ARG_3_TIMES]] : !torch.list<i64> -> i64
-# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[MIN1]], %[[MIN2]], %[[MIN3]] : (i64, i64, i64) -> !basicpy.TupleType
-# CHECK: return %[[RET]] : !basicpy.TupleType
+# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[MIN1]], %[[MIN2]], %[[MIN3]] : i64, i64, i64
+# CHECK: return %[[RET]] : !torch.tuple<i64, i64, i64>
@mb.import_function
@torch.jit.script
def prim_min(x: int):
return min(x), min(x,x), min(x, x, x)

# CHECK-LABEL: func @__torch__.prim_max(
-# CHECK-SAME: %[[ARG:.*]]: i64) -> !basicpy.TupleType {
+# CHECK-SAME: %[[ARG:.*]]: i64) -> !torch.tuple<i64, i64, i64> {
# CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (i64) -> !torch.list<i64>
# CHECK: %[[MAX1:.*]] = torch.prim.max.self_int %[[SINGLETON]] : !torch.list<i64> -> i64
# CHECK: %[[MAX2:.*]] = torch.prim.max.int %[[ARG]], %[[ARG]] : i64, i64 -> i64
# CHECK: %[[ARG_3_TIMES:.*]] = torch.prim.ListConstruct %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !torch.list<i64>
# CHECK: %[[MAX3:.*]] = torch.prim.max.self_int %[[ARG_3_TIMES]] : !torch.list<i64> -> i64
-# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[MAX1]], %[[MAX2]], %[[MAX3]] : (i64, i64, i64) -> !basicpy.TupleType
-# CHECK: return %[[RET]] : !basicpy.TupleType
+# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[MAX1]], %[[MAX2]], %[[MAX3]] : i64, i64, i64
+# CHECK: return %[[RET]] : !torch.tuple<i64, i64, i64>
@mb.import_function
@torch.jit.script
def prim_max(x: int):
6 changes: 3 additions & 3 deletions frontends/pytorch/test/node_import/tuple.py
@@ -11,9 +11,9 @@

# CHECK-LABEL: func @__torch__.f(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
-# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !basicpy.TupleType {
-# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[T0]], %[[T1]] : (!torch.tensor, !torch.tensor) -> !basicpy.TupleType
-# CHECK: return %[[RET]] : !basicpy.TupleType
+# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.tuple<!torch.tensor, !torch.tensor> {
+# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] : !torch.tensor, !torch.tensor
+# CHECK: return %[[RET]] : !torch.tuple<!torch.tensor, !torch.tensor>

@mb.import_function
@torch.jit.script
12 changes: 12 additions & 0 deletions include/npcomp-c/TorchTypes.h
@@ -37,6 +37,18 @@ bool npcompTypeIsATorchOptional(MlirType t);
/// Gets the !torch.optional<T> type with subtype T.
MlirType npcompTorchOptionalTypeGet(MlirType containedType);

//===----------------------------------------------------------------------===//
// torch.tuple<T1, T2, T3> type.
//===----------------------------------------------------------------------===//

/// Checks whether the given type is a !torch.tuple type
bool npcompTypeIsATorchTuple(MlirType t);

/// Gets the !torch.tuple type with contained types `containedTypes`.
MlirType npcompTorchTupleTypeGet(MlirContext context,
intptr_t numContainedTypes,
MlirType const *containedTypes);

//===----------------------------------------------------------------------===//
// torch.list<T> type.
//===----------------------------------------------------------------------===//
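As a quick illustration (not part of the commit), the new C API entry points can be used as follows to build and check a `!torch.tuple<i64, i64>`; the helper name is made up, and it assumes the Torch dialect is already registered on `context`:

#include <assert.h>

#include "mlir-c/BuiltinTypes.h"
#include "npcomp-c/TorchTypes.h"

// Hypothetical helper: build !torch.tuple<i64, i64> through the C API.
static MlirType makeI64PairTupleType(MlirContext context) {
  MlirType i64 = mlirIntegerTypeGet(context, 64);
  MlirType contained[2] = {i64, i64};
  MlirType tuple = npcompTorchTupleTypeGet(context, 2, contained);
  assert(npcompTypeIsATorchTuple(tuple));
  assert(!npcompTypeIsATorchTuple(i64));
  return tuple;
}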
23 changes: 23 additions & 0 deletions include/npcomp/Dialect/Torch/IR/TorchOps.td
@@ -322,6 +322,29 @@ def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack",
}];
}

def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
NoSideEffect,
TypesMatchWith<"contained types correspond to operand types",
"elements", "result", "Torch::TupleType::get($_ctxt, llvm::to_vector<6>($_self))">
]> {
let summary = "TorchScript prim::TupleConstruct op";
let description = [{
Note: This op does not allow trivial type refinement, because the
operand types and the result types must be in correspondence.
}];

let arguments = (ins
Variadic<AnyTorchType>:$elements
);
let results = (outs
Torch_TupleType:$result
);

let assemblyFormat = [{
$elements attr-dict `:` type($elements)
}];
}

def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [
NoSideEffect,
AllowsTypeRefinement,
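To make the TypesMatchWith constraint concrete: when creating this op from C++, the result type must be the tuple of the operand types. A rough sketch, assuming the default TableGen-generated builder that takes a result type and the operand range (the `buildTuple` helper is illustrative, not part of the patch):

#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::NPCOMP;

// Illustrative helper: the result type mirrors the element types, as the
// constraint requires.
static Value buildTuple(OpBuilder &builder, Location loc, ValueRange elements) {
  auto resultType = Torch::TupleType::get(
      builder.getContext(), llvm::to_vector<6>(elements.getTypes()));
  return builder.create<Torch::PrimTupleConstructOp>(loc, resultType, elements);
}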
12 changes: 11 additions & 1 deletion include/npcomp/Dialect/Torch/IR/TorchTypes.td
@@ -232,6 +232,16 @@ def Torch_ListType : Torch_TypeWithContainedType<"List", "list"> {
}];
}

def Torch_TupleType : Torch_Type<"Tuple", "tuple"> {
let summary = "!torch.tuple<T1, T2, T3>";
let description = [{
Tuple type with 0-N ordered contained types.
}];
let parameters = (ins
ArrayRefParameter<"::mlir::Type", "contained types">:$containedTypes
);
}

def Torch_DeviceType : Torch_Type<"Device", "Device"> {
let summary = "Torch device";
}
@@ -329,7 +339,7 @@ def AnyTorchType : AnyTypeOf<[
AnyTorchBoolType,
AnyTorchScalarType,
AnyTorchTensorType,
-Basicpy_TupleType,
+Torch_TupleType,
Basicpy_BytesType,
Torch_NnModuleType,
Torch_NoneType,
18 changes: 18 additions & 0 deletions lib/CAPI/TorchTypes.cpp
@@ -41,6 +41,24 @@ MlirType npcompTorchOptionalTypeGet(MlirType containedType) {
return wrap(Torch::OptionalType::get(unwrap(containedType)));
}

//===----------------------------------------------------------------------===//
// torch.tuple<T1, T2, T3> type.
//===----------------------------------------------------------------------===//

bool npcompTypeIsATorchTuple(MlirType t) {
return unwrap(t).isa<Torch::TupleType>();
}

MlirType npcompTorchTupleTypeGet(MlirContext context,
intptr_t numContainedTypes,
MlirType const *containedTypes) {
return wrap(Torch::TupleType::get(
unwrap(context),
llvm::to_vector<6>(
llvm::map_range(llvm::makeArrayRef(containedTypes, numContainedTypes),
[](MlirType t) { return unwrap(t); }))));
}

//===----------------------------------------------------------------------===//
// torch.list<T> type.
//===----------------------------------------------------------------------===//
29 changes: 29 additions & 0 deletions lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -10,11 +10,40 @@
#include "mlir/IR/DialectImplementation.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

Type Torch::TupleType::parse(MLIRContext *context, DialectAsmParser &parser) {
if (parser.parseLess())
return Type();
if (!parser.parseOptionalGreater())
return Torch::TupleType::get(context, {});

SmallVector<Type> containedTypes;
do {
Type containedType;
if (parser.parseType(containedType))
return Type();
containedTypes.push_back(containedType);
} while (!parser.parseOptionalComma());
if (parser.parseGreater())
return Type();
return Torch::TupleType::get(context, containedTypes);
}

void Torch::TupleType::print(::mlir::DialectAsmPrinter &printer) const {
printer << "tuple<";
llvm::interleaveComma(getContainedTypes(), printer);
printer << ">";
}

//===----------------------------------------------------------------------===//
// BaseTensorType
//===----------------------------------------------------------------------===//
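The parser above accepts `!torch.tuple<>`, `!torch.tuple<T>`, and longer forms, and the printer emits the same syntax back. A small round-trip sketch; it assumes the dialect class is `Torch::TorchDialect` and that `mlir::parseType` from `mlir/Parser.h` is available in the pinned LLVM, neither of which is shown in this diff:

#include <cassert>

#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "llvm/Support/raw_ostream.h"

// Illustrative round-trip through the parse/print hooks defined above.
static void tupleTypeRoundTrip() {
  mlir::MLIRContext context;
  context.loadDialect<mlir::NPCOMP::Torch::TorchDialect>();
  mlir::Type t = mlir::parseType("!torch.tuple<i64, i64>", &context);
  assert(t && "expected the type to parse");
  t.print(llvm::errs());  // expected to print: !torch.tuple<i64, i64>
  llvm::errs() << "\n";
}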
7 changes: 7 additions & 0 deletions test/Dialect/Torch/ops.mlir
@@ -41,6 +41,13 @@ func private @tensor.some_sizes_known() -> !torch.tensor<[?,2,?,4],unk>
// CHECK: @tensor.fully_determined() -> !torch.vtensor<[1,2,3,4],f32>
func private @tensor.fully_determined() -> !torch.vtensor<[1,2,3,4],f32>

// CHECK: @tuple.empty() -> !torch.tuple<>
func private @tuple.empty() -> !torch.tuple<>
// CHECK: @tuple.one_element() -> !torch.tuple<!torch.tensor>
func private @tuple.one_element() -> !torch.tuple<!torch.tensor>
// CHECK: @tuple.two_elements() -> !torch.tuple<!torch.tensor, !torch.tensor>
func private @tuple.two_elements() -> !torch.tuple<!torch.tensor, !torch.tensor>

// CHECK-LABEL: func @torch.tensor() {
func @torch.tensor() {
// CHECK: torch.tensor(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.vtensor<[3,2],f32>
