
Commit c357983

fix: DecomposeGenericByUnfoldingPermutation pass (#19)
This PR does two things:
1. Extends the decompose pass to also reason about scalar operands, which can appear on linalg.generic ops.
2. Rejects any operands that are not RankedTensors, since the pass was written with that assumption, even though it can naturally be called on linalg ops with memref (or other) operand types.
1 parent: 132b0a8
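For illustration, a minimal before/after sketch of the scalar case this commit enables, with operand names invented for exposition (the new @scalar_broadcast test below exercises this exact pattern through the full specialization pipeline):

// Before: a scalar %s feeds the linalg.generic through a rank-0 indexing map.
#id = affine_map<(d0, d1) -> (d0, d1)>
#scalar = affine_map<(d0, d1) -> ()>
%r = linalg.generic
    { indexing_maps = [#id, #scalar, #id], iterator_types = ["parallel", "parallel"] }
    ins(%t, %s : tensor<4x8xf32>, f32)
    outs(%init : tensor<4x8xf32>) {
  ^bb0(%a: f32, %b: f32, %o: f32):
    %sum = arith.addf %a, %b : f32
    linalg.yield %sum : f32
} -> tensor<4x8xf32>

// After: the scalar is materialized with linalg.fill, and the remaining
// identity-map generic can then be specialized into a named op.
%e = tensor.empty() : tensor<4x8xf32>
%f = linalg.fill ins(%s : f32) outs(%e : tensor<4x8xf32>) -> tensor<4x8xf32>
%r2 = linalg.add ins(%t, %f : tensor<4x8xf32>, tensor<4x8xf32>)
                 outs(%init : tensor<4x8xf32>) -> tensor<4x8xf32>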

2 files changed: +101 −10 lines changed

mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp

Lines changed: 44 additions & 10 deletions
@@ -166,7 +166,14 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
   // out which operand can supply that runtime-value (tensor.dim).
   // Leaving it as a future TODO.
   if (llvm::any_of(op->getOpOperands(), [](OpOperand &oper) {
-        auto opType = cast<RankedTensorType>(oper.get().getType());
+        // Allow scalar values as these can be broadcasted on the input.
+        if (oper.get().getType().isIntOrFloat())
+          return false;
+        // If any of the operands are not a RankedTensorType, then we should
+        // return early. The pattern has been built with RankedTensors in mind.
+        if (!isa<RankedTensorType>(oper.get().getType()))
+          return true;
+        auto opType = cast<ShapedType>(oper.get().getType());
         return ShapedType::isDynamicShape(opType.getShape());
       }))
     return failure();
@@ -181,10 +188,27 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
   // Walk over each input operand and unfold if it is transposed, broadcast
   // or mix of two via operand's affine-map.
   for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
-    auto &map = newMap[i];
-    auto inputRTType = cast<RankedTensorType>(newInitValues[i].getType());
-    auto elType = inputRTType.getElementType();
+    auto inputType = newInitValues[i].getType();
+    SmallVector<int64_t> inputShape =
+        llvm::TypeSwitch<Type, SmallVector<int64_t>>(inputType)
+            .Case([](RankedTensorType tensor) { return tensor.getShape(); })
+            .Case([](FloatType scalar) { return SmallVector<int64_t>({1}); })
+            .Case([](IntegerType scalar) { return SmallVector<int64_t>({1}); })
+            .Default([](Type) { return SmallVector<int64_t>(); });
+
+    Type elType = llvm::TypeSwitch<Type, Type>(inputType)
+                      .Case([](RankedTensorType tensor) {
+                        return tensor.getElementType();
+                      })
+                      .Case([](FloatType scalar) { return scalar; })
+                      .Case([](IntegerType scalar) { return scalar; })
+                      .Default([](Type) { return Type(); });
+
+    // If we were not able to resolve the information, skip.
+    if (inputShape.empty() || !elType)
+      continue;
 
+    auto &map = newMap[i];
     /// Nothing to do if map is already an identity.
     if (map.isIdentity())
       continue;
@@ -197,7 +221,7 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
       /// rule: dim(result, i) = dim(input, permutation[i])
       SmallVector<int64_t> transposedShape(map.getNumResults());
       for (int64_t i = 0; i < map.getNumResults(); ++i)
-        transposedShape[i] = inputRTType.getShape()[permutation[i]];
+        transposedShape[i] = inputShape[permutation[i]];
 
       Value emptyTensor =
           rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType);
@@ -211,13 +235,23 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
     // Does it require broadcast?
     if (!broadcastedDims.empty()) {
       assert(broadcastedDims.size() && "should have non size broadcast");
-      Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-          loc, outputShape, inputRTType.getElementType());
+      Value emptyTensor =
+          rewriter.create<tensor::EmptyOp>(loc, outputShape, elType);
 
-      auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
-          loc, newInitValues[i], emptyTensor, broadcastedDims);
+      Value source = newInitValues[i];
+      Value result;
+      // If a scalar is being broadcasted we can simply use a fill operation.
+      if (source.getType().isIntOrFloat()) {
+        result = rewriter.create<linalg::FillOp>(loc, source, emptyTensor)
+                     ->getResult(0);
+      } else {
+        result = rewriter
+                     .create<linalg::BroadcastOp>(loc, source, emptyTensor,
+                                                  broadcastedDims)
+                     ->getResult(0);
+      }
 
-      newInitValues[i] = broadcastOp->getResult(0);
+      newInitValues[i] = result;
       isChanged = true;
     }
     newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir

Lines changed: 57 additions & 0 deletions
@@ -69,3 +69,60 @@ func.func @broadcast_only(%x : tensor<2x16x32xf32>, %y: tensor<2x32xf32>, %z :
 // CHECK: %[[X_bc:.+]] = linalg.broadcast ins(%[[Y]] : tensor<2x32xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) dimensions = [1]
 // CHECK: {{.*}} = linalg.div ins(%[[X]], %[[X_bc]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%arg2 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
 // CHECK-NOT: linalg.generic
+
+
+// -----
+
+#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#broadcast = affine_map<(d0, d1, d2) -> ()>
+func.func @scalar_broadcast(%x: tensor<1x8x16xf32>, %y: f32) -> tensor<1x8x16xf32> {
+  %empty = tensor.empty() : tensor<1x8x16xf32>
+  %res = linalg.generic
+    { indexing_maps = [#identity, #broadcast, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%x, %y : tensor<1x8x16xf32>, f32)
+    outs(%empty : tensor<1x8x16xf32>) {
+  ^bb0(%in: f32, %in2: f32, %out: f32):
+    %add = arith.addf %in, %in2 : f32
+    linalg.yield %add : f32
+  } -> tensor<1x8x16xf32>
+  return %res : tensor<1x8x16xf32>
+}
+
+// CHECK-LABEL: scalar_broadcast
+// CHECK-SAME: %[[INPUT:.+]]: tensor<1x8x16xf32>
+// CHECK-SAME: %[[SCALAR:.+]]: f32
+// CHECK-DAG: %[[EMPTY_ADD:.+]] = tensor.empty() : tensor<1x8x16xf32>
+// CHECK-DAG: %[[EMPTY_FILL:.+]] = tensor.empty() : tensor<1x8x16xf32>
+// CHECK-DAG: %[[FILL:.+]] = linalg.fill
+// CHECK-SAME: ins(%[[SCALAR]] : f32)
+// CHECK-SAME: outs(%[[EMPTY_FILL]] : tensor<1x8x16xf32>)
+// CHECK: %[[ADD:.+]] = linalg.add
+// CHECK-SAME: ins(%[[INPUT]], %[[FILL]] : tensor<1x8x16xf32>, tensor<1x8x16xf32>)
+// CHECK-SAME: outs(%[[EMPTY_ADD]] : tensor<1x8x16xf32>)
+// CHECK: return %[[ADD]] : tensor<1x8x16xf32>
+
+// -----
+
+#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#broadcast = affine_map<(d0, d1, d2) -> (d2)>
+func.func @ignore_non_ranked_tensor_types(%x: memref<1x8x16xf32>, %y: memref<16xf32>) {
+  %empty = memref.alloc() : memref<1x8x16xf32>
+  linalg.generic
+    { indexing_maps = [#identity, #broadcast, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%x, %y : memref<1x8x16xf32>, memref<16xf32>)
+    outs(%empty : memref<1x8x16xf32>) {
+  ^bb0(%in: f32, %in2: f32, %out: f32):
+    %add = arith.addf %in, %in2 : f32
+    linalg.yield %add : f32
+  }
+  func.return
+}
+
+// CHECK-LABEL: ignore_non_ranked_tensor_types
+// CHECK-SAME: %[[X:.+]]: memref<1x8x16xf32>
+// CHECK-SAME: %[[Y:.+]]: memref<16xf32>
+// CHECK: %[[EMPTY:.+]] = memref.alloc() : memref<1x8x16xf32>
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[X]], %[[Y]] : memref<1x8x16xf32>, memref<16xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : memref<1x8x16xf32>)
+// CHECK: return
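On running these tests: the file's RUN line sits outside this diff, but assuming it follows the upstream convention for the pass that hosts this pattern, it would look roughly like:

// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s

This is a sketch of the expected driver, not a line added by this commit; the CHECK lines above (linalg.fill plus linalg.add rather than a generic) are consistent with the decomposition running inside the generic-op specialization pipeline.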
