Skip to content

Commit

Permalink
[Iterators] Implement new ZipOp. (iree-org#651)
Browse files Browse the repository at this point in the history
* Add tablegen definition of new ZipOp.
* Add lowering to LLVM for new ZipOp.
  • Loading branch information
ingomueller-net authored Apr 6, 2023
1 parent f02c1ad commit 5d99d96
Show file tree
Hide file tree
Showing 6 changed files with 368 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,53 @@ def Iterators_ValueToStreamOp : Iterators_Op<"value_to_stream",
}];
}

// Variadic constraint of the given type that additionally sets `minSize` to 1,
// i.e., that requires at least one value to be present.
class NonemptyVariadic<Type type> : Variadic<type> { let minSize = 1; }

def Iterators_ZipOp : Iterators_Op<"zip",
    // Verify that the element types of the input streams, taken in operand
    // order, are exactly the field types of the tuple that the result stream
    // consists of.
    [AllMatch<[[{::llvm::ArrayRef(::llvm::SmallVector<Type>(
                  ::llvm::map_range($inputs.getTypes(),
                    [](Type t) { return t.cast<StreamType>().getElementType(); }
                  )))}],
               [{$result.getType().cast<StreamType>().getElementType()
                  .cast<::mlir::TupleType>().getTypes()}]],
              "result stream must consist of tuples whose element types match "
              "the element types of the input streams">,
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
  let summary = "Zips several streams into a stream of tuples.";
  // TODO(ingomueller): Add attributes to this op or add other ops that allow to
  //                    zip up until the longest input (by providing fill
  //                    values), plus another one for equal lengths (which can
  //                    assume that all inputs have the same length and thus
  //                    use a cheaper termination test).
  let description = [{
    Reads one or more streams in lock step and produces a stream of tuples where
    each tuple consists of the elements of all input streams at the same
    position of the streams. If the input streams do not have the same length,
    then the result stream is only as long as the shortest of the inputs and
    the remainder of the other input streams is not consumed.

    Example:
    ```mlir
    %zipped = iterators.zip %input1, %input2 :
                  (!iterators.stream<i32>, !iterators.stream<i64>)
                    -> !iterators.stream<tuple<i32, i64>>
    ```
  }];
  // At least one input stream is required (see NonemptyVariadic).
  let arguments = (ins
      NonemptyVariadic<Iterators_Stream>:$inputs
    );
  let results = (outs Iterators_StreamOf<AnyTuple>:$result);
  let assemblyFormat =
    "$inputs attr-dict `:` functional-type($inputs, $result)";
  let extraClassDefinition = [{
    /// Implement OpAsmOpInterface: suggests the name "zipped" for the result.
    void $cppClass::getAsmResultNames(
        llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
      setNameFn(getResult(), "zipped");
    }
  }];
}

//===----------------------------------------------------------------------===//
// Ops related to Iterator bodies.
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,18 @@ StateType StateTypeComputer::operator()(
return StateType::get(context, {hasReturned, valueType});
}

/// ZipOp has no fields of its own in its state: the state is made up of the
/// states of its upstream iterators (i.e., the iterators producing its input
/// streams), in the order in which they are handed in.
template <>
StateType
StateTypeComputer::operator()(ZipOp op,
                              llvm::SmallVector<StateType> upstreamStateTypes) {
  // Re-pack the upstream state types as plain `Type`s, which is the form in
  // which they are handed to `StateType::get`.
  llvm::SmallVector<Type> fieldTypes;
  for (StateType upstreamStateType : upstreamStateTypes)
    fieldTypes.push_back(upstreamStateType);
  return StateType::get(op->getContext(), fieldTypes);
}

/// Build IteratorInfo, assigning new unique names as needed. Takes the
/// `StateType` as a parameter, to ensure proper build order (all uses are
/// visited before any def).
Expand Down Expand Up @@ -173,7 +185,8 @@ mlir::iterators::IteratorAnalysis::IteratorAnalysis(
MapOp,
ReduceOp,
TabularViewToStreamOp,
ValueToStreamOp
ValueToStreamOp,
ZipOp
// clang-format on
>([&](auto op) {
llvm::SmallVector<StateType> upstreamStateTypes;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,172 @@ static Value buildStateCreation(ValueToStreamOp op,
return b.create<CreateStateOp>(stateType, ValueRange{hasReturned, value});
}

//===----------------------------------------------------------------------===//
// ZipOp.
//===----------------------------------------------------------------------===//

/// Builds IR that calls the open function of every upstream iterator, in the
/// order of `upstreamInfos`, and writes each returned upstream state back into
/// the corresponding field of the state of the zip iterator. Returns the fully
/// updated state. Possible output (for one input stream):
///
///   %upstream_state = iterators.extractvalue %initialState[0] :
///                         !iterators.state<!upstream_state_type>
///   %updated_upstream_state =
///       call @iterators.upstream.open.0(%upstream_state) :
///           (!upstream_state_type) -> !upstream_state_type
///   %state = iterators.insertvalue %updated_upstream_state
///                into %initialState[0] :
///                    !iterators.state<!upstream_state_type>
static Value buildOpenBody(ZipOp op, OpBuilder &builder, Value initialState,
                           ArrayRef<IteratorInfo> upstreamInfos) {
  ImplicitLocOpBuilder b(op.getLoc(), builder);

  // The state is threaded through the loop: every iteration reads from and
  // writes to the state produced by the previous one.
  Value state = initialState;
  for (const auto &indexedInfo : llvm::enumerate(upstreamInfos)) {
    const IteratorInfo &info = indexedInfo.value();
    // Field `i` of the zip state holds the state of upstream `i`.
    auto fieldIndex = b.getIndexAttr(indexedInfo.index());
    Type fieldType = info.stateType;

    // Read the current state of this upstream.
    Value upstreamState =
        b.create<iterators::ExtractValueOp>(fieldType, state, fieldIndex);

    // Open the upstream; the call returns its updated state.
    SymbolRefAttr openFunc = info.openFunc;
    auto openCall = b.create<func::CallOp>(openFunc, fieldType, upstreamState);
    Value openedUpstreamState = openCall->getResult(0);

    // Store the updated upstream state back.
    state = b.create<iterators::InsertValueOp>(state, fieldIndex,
                                               openedUpstreamState);
  }

  return state;
}

/// Builds IR that calls next on all upstream iterators and assembles an output
/// tuple from the elements they return. The emitted IR is straight-line code:
/// next is called on every upstream, and the individual "has next" flags are
/// combined with a logical and. Pseudo-code:
///
///   hasNext = true
///   elements = []
///   for each upstream in upstreams:
///       upstreamHasNext, upstreamElement = upstream->Next()
///       hasNext = hasNext && upstreamHasNext
///       elements.append(upstreamElement)
///   return hasNext, tuple(elements)
///
/// Possible output (for one input stream):
///
///   %true = arith.constant true
///   %1 = iterators.extractvalue %initialState[0] :
///            !iterators.state<!upstream_state_type>
///   %2:3 = call @iterators.upstream.next.0(%1) :
///              (!upstream_state_type) -> (!upstream_state_type, i1, i32)
///   %3 = arith.andi %true, %2#1 : i1
///   %state = iterators.insertvalue %2#0 into %initialState[0] :
///                !iterators.state<!upstream_state_type>
///   %tuple = tuple.from_elements %2#2 : tuple<i32>
///   return %state, %3, %tuple :
///       !iterators.state<!upstream_state_type>, i1, tuple<i32>
///
/// Returns the updated state, the combined "has next" flag, and the assembled
/// tuple (in that order).
static llvm::SmallVector<Value, 4>
buildNextBody(ZipOp op, OpBuilder &builder, Value initialState,
              ArrayRef<IteratorInfo> upstreamInfos, Type elementType) {
  ImplicitLocOpBuilder b(op.getLoc(), builder);
  Type flagType = b.getI1Type();

  // Start from "true" so that the first and-operation yields the flag of the
  // first upstream.
  Value state = initialState;
  Value allHaveNext =
      b.create<arith::ConstantIntOp>(/*value=*/1, /*width=*/1);
  SmallVector<Value> elements;
  for (const auto &indexedInfo : llvm::enumerate(upstreamInfos)) {
    const IteratorInfo &info = indexedInfo.value();
    // Field `i` of the zip state holds the state of upstream `i`, which
    // produces operand `i` of the op.
    auto fieldIndex = b.getIndexAttr(indexedInfo.index());
    Type fieldType = info.stateType;
    Type upstreamElementType = op->getOperand(indexedInfo.index())
                                   .getType()
                                   .cast<StreamType>()
                                   .getElementType();

    // Read the current state of this upstream.
    Value upstreamState =
        b.create<iterators::ExtractValueOp>(fieldType, state, fieldIndex);

    // Call next on the upstream; it returns (state, hasNext, element).
    SmallVector<Type> callResultTypes = {fieldType, flagType,
                                         upstreamElementType};
    SymbolRefAttr nextFunc = info.nextFunc;
    auto nextCall =
        b.create<func::CallOp>(nextFunc, callResultTypes, upstreamState);
    Value nextUpstreamState = nextCall->getResult(0);
    Value upstreamHasNext = nextCall->getResult(1);
    Value upstreamElement = nextCall->getResult(2);

    // Fold the flag of this upstream into the overall flag.
    allHaveNext = b.create<arith::AndIOp>(allHaveNext, upstreamHasNext);

    // Collect the element for the result tuple.
    elements.push_back(upstreamElement);

    // Store the updated upstream state back.
    state = b.create<iterators::InsertValueOp>(state, fieldIndex,
                                               nextUpstreamState);
  }

  // Pack the collected upstream elements into the result tuple.
  auto resultTupleType = elementType.cast<TupleType>();
  Value resultTuple =
      b.create<tuple::FromElementsOp>(resultTupleType, elements);

  return {state, allHaveNext, resultTuple};
}

/// Builds IR that calls the close function of every upstream iterator, in the
/// order of `upstreamInfos`, and writes each returned upstream state back into
/// the corresponding field of the state of the zip iterator. Returns the fully
/// updated state. Possible output (for one input stream):
///
///   %upstream_state = iterators.extractvalue %initialState[0] :
///                         !iterators.state<!upstream_state_type>
///   %updated_upstream_state =
///       call @iterators.upstream.close.0(%upstream_state) :
///           (!upstream_state_type) -> !upstream_state_type
///   %state = iterators.insertvalue %updated_upstream_state
///                into %initialState[0] :
///                    !iterators.state<!upstream_state_type>
static Value buildCloseBody(ZipOp op, OpBuilder &builder, Value initialState,
                            ArrayRef<IteratorInfo> upstreamInfos) {
  ImplicitLocOpBuilder b(op.getLoc(), builder);

  // The state is threaded through the loop: every iteration reads from and
  // writes to the state produced by the previous one.
  Value state = initialState;
  for (const auto &indexedInfo : llvm::enumerate(upstreamInfos)) {
    const IteratorInfo &info = indexedInfo.value();
    // Field `i` of the zip state holds the state of upstream `i`.
    auto fieldIndex = b.getIndexAttr(indexedInfo.index());
    Type fieldType = info.stateType;

    // Read the current state of this upstream.
    Value upstreamState =
        b.create<iterators::ExtractValueOp>(fieldType, state, fieldIndex);

    // Close the upstream; the call returns its updated state.
    SymbolRefAttr closeFunc = info.closeFunc;
    auto closeCall =
        b.create<func::CallOp>(closeFunc, fieldType, upstreamState);
    Value closedUpstreamState = closeCall->getResult(0);

    // Store the updated upstream state back.
    state = b.create<iterators::InsertValueOp>(state, fieldIndex,
                                               closedUpstreamState);
  }

  return state;
}

/// Builds IR that creates the initial state of the zip iterator from the
/// states of its upstream iterators, which are taken from the inputs of the
/// adaptor. Possible output (for one input stream):
///
///   %state = iterators.createstate(%upstream_state) :
///                !iterators.state<!upstream_state_type>
static Value buildStateCreation(ZipOp op, ZipOp::Adaptor adaptor,
                                OpBuilder &builder, StateType stateType) {
  ImplicitLocOpBuilder b(op.getLoc(), builder);
  // One state field per input, in operand order.
  return b.create<CreateStateOp>(stateType, adaptor.getInputs());
}

//===----------------------------------------------------------------------===//
// Helpers for creating Open/Next/Close functions and state creation.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1382,7 +1548,8 @@ static Value buildOpenBody(Operation *op, OpBuilder &builder,
MapOp,
ReduceOp,
TabularViewToStreamOp,
ValueToStreamOp
ValueToStreamOp,
ZipOp
// clang-format on
>([&](auto op) {
return buildOpenBody(op, builder, initialState, upstreamInfos);
Expand All @@ -1401,7 +1568,8 @@ buildNextBody(Operation *op, OpBuilder &builder, Value initialState,
MapOp,
ReduceOp,
TabularViewToStreamOp,
ValueToStreamOp
ValueToStreamOp,
ZipOp
// clang-format on
>([&](auto op) {
return buildNextBody(op, builder, initialState, upstreamInfos,
Expand All @@ -1421,7 +1589,8 @@ static Value buildCloseBody(Operation *op, OpBuilder &builder,
MapOp,
ReduceOp,
TabularViewToStreamOp,
ValueToStreamOp
ValueToStreamOp,
ZipOp
// clang-format on
>([&](auto op) {
return buildCloseBody(op, builder, initialState, upstreamInfos);
Expand All @@ -1439,7 +1608,8 @@ static Value buildStateCreation(IteratorOpInterface op, OpBuilder &builder,
MapOp,
ReduceOp,
TabularViewToStreamOp,
ValueToStreamOp
ValueToStreamOp,
ZipOp
// clang-format on
>([&](auto op) {
using OpAdaptor = typename decltype(op)::Adaptor;
Expand Down
77 changes: 77 additions & 0 deletions experimental/iterators/test/Conversion/IteratorsToLLVM/zip.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// RUN: iterators-opt %s -convert-iterators-to-llvm \
// RUN: | FileCheck --enable-var-scope %s

// CHECK-LABEL: func.func private @iterators.zip.close.{{[0-9]+}}(
// CHECK-SAME: %[[arg0:.*]]: !iterators.state<[[lhsUpstreamStateType:!iterators\.state.*]], [[rhsUpstreamStateType:!iterators.state.*]]>) ->
// CHECK-SAME: !iterators.state<[[lhsUpstreamStateType]], [[rhsUpstreamStateType]]> {
// CHECK-NEXT: %[[V0:.*]] = iterators.extractvalue %arg0[0] : !iterators.state<[[lhsUpstreamStateType]], [[rhsUpstreamStateType]]>
// CHECK-NEXT: %[[V1:.*]] = call @iterators.{{.*}}.close.{{[0-9]+}}(%[[V0]]) : ([[lhsUpstreamStateType]]) -> [[lhsUpstreamStateType]]
// CHECK-NEXT: %[[V2:.*]] = iterators.insertvalue %[[V1]] into %arg0[0] : !iterators.state<[[lhsUpstreamStateType]], [[rhsUpstreamStateType]]>
// CHECK-NEXT: %[[V3:.*]] = iterators.extractvalue %[[V2]][1] : !iterators.state<[[lhsUpstreamStateType]], [[rhsUpstreamStateType]]>
// CHECK-NEXT: %[[V4:.*]] = call @iterators.{{.*}}.close.{{[0-9]+}}(%[[V3]]) : ([[rhsUpstreamStateType]]) -> [[rhsUpstreamStateType]]
// CHECK-NEXT: %[[V5:.*]] = iterators.insertvalue %[[V4]] into %[[V2]][1] : !iterators.state<[[lhsUpstreamStateType]], [[rhsUpstreamStateType]]>
// CHECK-NEXT: return %[[V5]] : !iterators.state<[[lhsUpstreamStateType]], [[rhsUpstreamStateType]]>
// CHECK-NEXT: }

// CHECK-LABEL: func.func private @iterators.zip.next.{{[0-9]+}}(
// CHECK-SAME: %[[arg0:.*]]: !iterators.state<[[lhsUpstreamStateType:!iterators\.state.*]], [[rhsUpstreamStateType:!iterators.state.*]]>) ->
// CHECK-SAME: (!iterators.state<[[lhsUpstreamStateType]], [[rhsUpstreamStateType]]>, i1, tuple<i32, i32>) {
// CHECK-NEXT: %[[V0:.*]] = arith.constant true
// CHECK-NEXT: %[[V2:.*]] = iterators.extractvalue %arg0[0] : !iterators.state<[[lhsUpstreamStateType]], [[rhsUpstreamStateType]]>
// CHECK-NEXT: %[[V3:.*]]:3 = call @iterators.{{.*}}.next.{{[0-9]+}}(%[[V2]]) : ([[lhsUpstreamStateType]]) -> ([[lhsUpstreamStateType]], i1, i32)
// CHECK-NEXT: %[[V4:.*]] = arith.andi %[[V0]], %[[V3]]#1 : i1
// CHECK-NEXT: %[[V6:.*]] = iterators.insertvalue %[[V3]]#0 into %arg0[0] : !iterators.state<[[lhsUpstreamStateType]], [[rhsUpstreamStateType]]>
// CHECK-NEXT: %[[V7:.*]] = iterators.extractvalue %[[V6]][1] : !iterators.state<[[lhsUpstreamStateType]], [[rhsUpstreamStateType]]>
// CHECK-NEXT: %[[V8:.*]]:3 = call @iterators.{{.*}}.next.{{[0-9]+}}(%[[V7]]) : ([[rhsUpstreamStateType]]) -> ([[rhsUpstreamStateType]], i1, i32)
// CHECK-NEXT: %[[V9:.*]] = arith.andi %[[V4]], %[[V8]]#1 : i1
// CHECK-NEXT: %[[Vb:.*]] = iterators.insertvalue %[[V8]]#0 into %[[V6]][1] : !iterators.state<[[lhsUpstreamStateType]], [[rhsUpstreamStateType]]>
// CHECK-NEXT: %[[Vc:.*]] = tuple.from_elements %[[V3]]#2, %[[V8]]#2 : tuple<i32, i32>
// CHECK-NEXT: return %[[Vb]], %[[V9]], %[[Vc]] : !iterators.state<[[lhsUpstreamStateType]], [[rhsUpstreamStateType]]>, i1, tuple<i32, i32>
// CHECK-NEXT: }

// CHECK-LABEL: func.func private @iterators.zip.open.{{[0-9]+}}(
// CHECK-SAME: %[[arg0:.*]]: !iterators.state<[[lhsUpstreamStateType:!iterators\.state.*]], [[rhsUpstreamStateType:!iterators.state.*]]>) ->
// CHECK-SAME: !iterators.state<[[lhsUpstreamStateType]], [[rhsUpstreamStateType]]> {
// CHECK-NEXT: %[[V0:.*]] = iterators.extractvalue %arg0[0] : !iterators.state<[[lhsUpstreamStateType]], [[rhsUpstreamStateType]]>
// CHECK-NEXT: %[[V1:.*]] = call @iterators.{{.*}}.open.{{[0-9]+}}(%[[V0]]) : ([[lhsUpstreamStateType]]) -> [[lhsUpstreamStateType]]
// CHECK-NEXT: %[[V2:.*]] = iterators.insertvalue %[[V1]] into %arg0[0] : !iterators.state<[[lhsUpstreamStateType]], [[rhsUpstreamStateType]]>
// CHECK-NEXT: %[[V3:.*]] = iterators.extractvalue %[[V2]][1] : !iterators.state<[[lhsUpstreamStateType]], [[rhsUpstreamStateType]]>
// CHECK-NEXT: %[[V4:.*]] = call @iterators.{{.*}}.open.{{[0-9]+}}(%[[V3]]) : ([[rhsUpstreamStateType]]) -> [[rhsUpstreamStateType]]
// CHECK-NEXT: %[[V5:.*]] = iterators.insertvalue %[[V4]] into %[[V2]][1] : !iterators.state<[[lhsUpstreamStateType]], [[rhsUpstreamStateType]]>
// CHECK-NEXT: return %[[V5]] : !iterators.state<[[lhsUpstreamStateType]], [[rhsUpstreamStateType]]>
// CHECK-NEXT: }

// Helper referenced as `mapFuncRef` by the map ops in @main: takes a
// one-element tuple and returns the i32 it contains.
func.func private @unpack_i32(%input : tuple<i32>) -> i32 {
%i = tuple.to_elements %input : tuple<i32>
return %i : i32
}

// Test input: builds two streams of four i32 values each (a constant stream of
// one-element tuples followed by a map with @unpack_i32) and zips them. The
// directives interleaved below pin the state creation ops that are expected
// for the individual iterators, ending with the state of the zip iterator,
// which is created from the two outer (map) states.
func.func @main() {
// CHECK-LABEL: func.func @main() {
// Left-hand stream of numbers.
%zero_to_three = "iterators.constantstream"()
{ value = [[0 : i32], [1 : i32], [2 : i32], [3 : i32]] }
: () -> (!iterators.stream<tuple<i32>>)
// CHECK: %[[lhsInnerState:.*]] = iterators.createstate
%unpacked_lhs = "iterators.map"(%zero_to_three) {mapFuncRef = @unpack_i32}
: (!iterators.stream<tuple<i32>>) -> (!iterators.stream<i32>)
// CHECK: %[[lhsOuterState:.*]] = iterators.createstate(%[[lhsInnerState]]) : [[lhsStateType:.*]]

// Right-hand stream of numbers.
%four_to_seven = "iterators.constantstream"()
{ value = [[4 : i32], [5 : i32], [6 : i32], [7 : i32]] }
: () -> (!iterators.stream<tuple<i32>>)
// CHECK: %[[rhsInnerState:.*]] = iterators.createstate
%unpacked_rhs = "iterators.map"(%four_to_seven) {mapFuncRef = @unpack_i32}
: (!iterators.stream<tuple<i32>>) -> (!iterators.stream<i32>)
// CHECK: %[[rhsOuterState:.*]] = iterators.createstate(%[[rhsInnerState]]) : [[rhsStateType:.*]]

// Zip.
%zipped = iterators.zip %unpacked_lhs, %unpacked_rhs :
(!iterators.stream<i32>, !iterators.stream<i32>)
-> (!iterators.stream<tuple<i32, i32>>)
// CHECK-NEXT: %[[state:.*]] = iterators.createstate(%[[lhsOuterState]], %[[rhsOuterState]]) : !iterators.state<[[lhsStateType]], [[rhsStateType]]>
return
// CHECK-NEXT: return
}
// CHECK-NEXT: }
15 changes: 15 additions & 0 deletions experimental/iterators/test/Dialect/Iterators/zip.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: iterators-opt %s \
// RUN: | FileCheck %s

// Test input: zips an i32 stream and an i64 stream that are passed in as
// function arguments. The directives below expect the op to be printed in its
// custom assembly form with a result name starting with `zipped`.
func.func @main(%stream_i32 : !iterators.stream<i32>,
%stream_i64 : !iterators.stream<i64>) {
// CHECK-LABEL: func.func @main(
// CHECK-SAME: %[[arg0:.*]]: !iterators.stream<i32>, %[[arg1:.*]]: !iterators.stream<i64>) {
%zipped = iterators.zip %stream_i32, %stream_i64 :
(!iterators.stream<i32>, !iterators.stream<i64>)
-> !iterators.stream<tuple<i32, i64>>
// CHECK-NEXT: %[[V0:zipped.*]] = iterators.zip %[[arg0]], %[[arg1]] : (!iterators.stream<i32>, !iterators.stream<i64>) -> !iterators.stream<tuple<i32, i64>>
return
// CHECK-NEXT: return
}
// CHECK-NEXT: }
Loading

0 comments on commit 5d99d96

Please sign in to comment.