@@ -166,7 +166,14 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
   // out which operand can supply that runtime-value (tensor.dim).
   // Leaving it as a future TODO.
   if (llvm::any_of(op->getOpOperands(), [](OpOperand &oper) {
-        auto opType = cast<RankedTensorType>(oper.get().getType());
+        // Allow scalar values as these can be broadcasted on the input.
+        if (oper.get().getType().isIntOrFloat())
+          return false;
+        // If any of the operands are not a RankedTensorType, then we should
+        // return early. The pattern has been built with RankedTensors in mind.
+        if (!isa<RankedTensorType>(oper.get().getType()))
+          return true;
+        auto opType = cast<ShapedType>(oper.get().getType());
         return ShapedType::isDynamicShape(opType.getShape());
       }))
     return failure();
@@ -181,10 +188,27 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
   // Walk over each input operand and unfold if it is transposed, broadcast
   // or mix of two via operand's affine-map.
   for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
-    auto &map = newMap[i];
-    auto inputRTType = cast<RankedTensorType>(newInitValues[i].getType());
-    auto elType = inputRTType.getElementType();
+    auto inputType = newInitValues[i].getType();
+    SmallVector<int64_t> inputShape =
+        llvm::TypeSwitch<Type, SmallVector<int64_t>>(inputType)
+            .Case([](RankedTensorType tensor) { return tensor.getShape(); })
+            .Case([](FloatType scalar) { return SmallVector<int64_t>({1}); })
+            .Case([](IntegerType scalar) { return SmallVector<int64_t>({1}); })
+            .Default([](Type) { return SmallVector<int64_t>(); });
+
+    Type elType = llvm::TypeSwitch<Type, Type>(inputType)
+                      .Case([](RankedTensorType tensor) {
+                        return tensor.getElementType();
+                      })
+                      .Case([](FloatType scalar) { return scalar; })
+                      .Case([](IntegerType scalar) { return scalar; })
+                      .Default([](Type) { return Type(); });
+
+    // If we were not able to resolve the information, skip.
+    if (inputShape.empty() || !elType)
+      continue;
 
+    auto &map = newMap[i];
     /// Nothing to do if map is already an identity.
     if (map.isIdentity())
       continue;
@@ -197,7 +221,7 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
     /// rule: dim(result, i) = dim(input, permutation[i])
     SmallVector<int64_t> transposedShape(map.getNumResults());
     for (int64_t i = 0; i < map.getNumResults(); ++i)
-      transposedShape[i] = inputRTType.getShape()[permutation[i]];
+      transposedShape[i] = inputShape[permutation[i]];
 
     Value emptyTensor =
         rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType);
@@ -211,13 +235,23 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
     // Does it require broadcast?
     if (!broadcastedDims.empty()) {
       assert(broadcastedDims.size() && "should have non size broadcast");
-      Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-          loc, outputShape, inputRTType.getElementType());
+      Value emptyTensor =
+          rewriter.create<tensor::EmptyOp>(loc, outputShape, elType);
 
-      auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
-          loc, newInitValues[i], emptyTensor, broadcastedDims);
+      Value source = newInitValues[i];
+      Value result;
+      // If a scalar is being broadcasted we can simply use a fill operation.
+      if (source.getType().isIntOrFloat()) {
+        result = rewriter.create<linalg::FillOp>(loc, source, emptyTensor)
+                     ->getResult(0);
+      } else {
+        result = rewriter
+                     .create<linalg::BroadcastOp>(loc, source, emptyTensor,
+                                                  broadcastedDims)
+                     ->getResult(0);
+      }
 
-      newInitValues[i] = broadcastOp->getResult(0);
+      newInitValues[i] = result;
       isChanged = true;
     }
     newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
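
For context, here is a hand-written sketch (not taken from the patch's tests; the function and value names are made up) of roughly the kind of IR this change admits: a linalg.generic whose second input is a plain f32 scalar carried through an empty, broadcasting affine map. Previously the pattern bailed out on such an operand; with this change the scalar would be materialized via tensor.empty plus linalg.fill rather than linalg.broadcast.

  #id = affine_map<(d0, d1) -> (d0, d1)>
  #scalar = affine_map<(d0, d1) -> ()>

  // %s is the scalar operand that the updated pattern turns into a fill.
  func.func @add_scalar(%arg0: tensor<4x8xf32>, %s: f32,
                        %init: tensor<4x8xf32>) -> tensor<4x8xf32> {
    %0 = linalg.generic
           {indexing_maps = [#id, #scalar, #id],
            iterator_types = ["parallel", "parallel"]}
           ins(%arg0, %s : tensor<4x8xf32>, f32)
           outs(%init : tensor<4x8xf32>) {
    ^bb0(%a: f32, %b: f32, %o: f32):
      %r = arith.addf %a, %b : f32
      linalg.yield %r : f32
    } -> tensor<4x8xf32>
    return %0 : tensor<4x8xf32>
  }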