@@ -1045,7 +1045,6 @@ struct ReorderElementwiseOpsOnBroadcast final
10451045 Type resultElemType = resultType.getElementType ();
10461046
10471047 // Get the type of the first non-constant operand
1048- // Operation *firstBroadcastOrSplat = nullptr;
10491048 Value splatSource;
10501049 for (Value operand : op->getOperands ()) {
10511050 Operation *definingOp = operand.getDefiningOp ();
@@ -1057,26 +1056,24 @@ struct ReorderElementwiseOpsOnBroadcast final
10571056 break ;
10581057 }
10591058 if (!splatSource)
1060- // TODO: why?
10611059 return failure ();
1062-
10631060 Type unbroadcastResultType =
10641061 cloneOrReplace (splatSource.getType (), resultElemType);
10651062
1066- Type lhsBcastOrSplatType = splatSource.getType ();
1067-
10681063 // Make sure that all operands are broadcast from identically-shaped types:
10691064 // * scalar (`vector.broadcast` + `vector.splat`), or
10701065 // * vector (`vector.broadcast`).
10711066 // Otherwise the re-ordering wouldn't be safe.
1072- if (!llvm::all_of (op->getOperands (), [lhsBcastOrSplatType ](Value val) {
1067+ if (!llvm::all_of (op->getOperands (), [splatSource ](Value val) {
10731068 if (auto source = getBroadcastLikeSource (val))
1074- return source.getType () == lhsBcastOrSplatType;
1069+ return haveSameShapeAndScaling (source.getType (),
1070+ splatSource.getType ());
10751071 SplatElementsAttr splatConst;
10761072 return matchPattern (val, m_Constant (&splatConst));
10771073 })) {
10781074 return rewriter.notifyMatchFailure (
1079- op, " not all operands are broadcasts from the sametype" );
1075+ op,
1076+ " not all operands are constants or broadcasts from the same type" );
10801077 }
10811078
10821079 // Collect the source values before broadcasting
0 commit comments