@@ -186,7 +186,6 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
186186  //  SIMD-specific configuration
187187  if  (Subtarget->hasSIMD128 ()) {
188188
189-     //  Combine partial.reduce.add before legalization gets confused.
190189    setTargetDAGCombine (ISD::INTRINSIC_WO_CHAIN);
191190
192191    //  Combine wide-vector muls, with extend inputs, to extmul_half.
@@ -317,6 +316,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
317316      setOperationAction (ISD::SIGN_EXTEND_VECTOR_INREG, T, Custom);
318317      setOperationAction (ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom);
319318    }
319+ 
320+     //  Partial MLA reductions.
321+     for  (auto  Op : {ISD::PARTIAL_REDUCE_SMLA, ISD::PARTIAL_REDUCE_UMLA}) {
322+       setPartialReduceMLAAction (Op, MVT::v4i32, MVT::v16i8, Legal);
323+       setPartialReduceMLAAction (Op, MVT::v4i32, MVT::v8i16, Legal);
324+     }
320325  }
321326
322327  //  As a special case, these operators use the type to mean the type to
@@ -416,41 +421,6 @@ MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL,
416421  return  TargetLowering::getPointerMemTy (DL, AS);
417422}
418423
419- bool  WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic (
420-     const  IntrinsicInst *I) const  {
421-   if  (I->getIntrinsicID () != Intrinsic::vector_partial_reduce_add)
422-     return  true ;
423- 
424-   EVT VT = EVT::getEVT (I->getType ());
425-   if  (VT.getSizeInBits () > 128 )
426-     return  true ;
427- 
428-   auto  Op1 = I->getOperand (1 );
429- 
430-   if  (auto  *InputInst = dyn_cast<Instruction>(Op1)) {
431-     unsigned  Opcode = InstructionOpcodeToISD (InputInst->getOpcode ());
432-     if  (Opcode == ISD::MUL) {
433-       if  (isa<Instruction>(InputInst->getOperand (0 )) &&
434-           isa<Instruction>(InputInst->getOperand (1 ))) {
435-         //  dot only supports signed inputs but also support lowering unsigned.
436-         if  (cast<Instruction>(InputInst->getOperand (0 ))->getOpcode () !=
437-             cast<Instruction>(InputInst->getOperand (1 ))->getOpcode ())
438-           return  true ;
439- 
440-         EVT Op1VT = EVT::getEVT (Op1->getType ());
441-         if  (Op1VT.getVectorElementType () == VT.getVectorElementType () &&
442-             ((VT.getVectorElementCount () * 2  ==
443-               Op1VT.getVectorElementCount ()) ||
444-              (VT.getVectorElementCount () * 4  == Op1VT.getVectorElementCount ())))
445-           return  false ;
446-       }
447-     } else  if  (ISD::isExtOpcode (Opcode)) {
448-       return  false ;
449-     }
450-   }
451-   return  true ;
452- }
453- 
454424TargetLowering::AtomicExpansionKind
455425WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR (AtomicRMWInst *AI) const  {
456426  //  We have wasm instructions for these
@@ -2113,106 +2083,6 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op,
21132083                      MachinePointerInfo (SV));
21142084}
21152085
2116- //  Try to lower partial.reduce.add to a dot or fallback to a sequence with
2117- //  extmul and adds.
2118- SDValue performLowerPartialReduction (SDNode *N, SelectionDAG &DAG) {
2119-   assert (N->getOpcode () == ISD::INTRINSIC_WO_CHAIN);
2120-   if  (N->getConstantOperandVal (0 ) != Intrinsic::vector_partial_reduce_add)
2121-     return  SDValue ();
2122- 
2123-   assert (N->getValueType (0 ) == MVT::v4i32 && " can only support v4i32" 
2124-   SDLoc DL (N);
2125- 
2126-   SDValue Input = N->getOperand (2 );
2127-   if  (Input->getOpcode () == ISD::MUL) {
2128-     SDValue ExtendLHS = Input->getOperand (0 );
2129-     SDValue ExtendRHS = Input->getOperand (1 );
2130-     assert ((ISD::isExtOpcode (ExtendLHS.getOpcode ()) &&
2131-             ISD::isExtOpcode (ExtendRHS.getOpcode ())) &&
2132-            " expected widening mul or add" 
2133-     assert (ExtendLHS.getOpcode () == ExtendRHS.getOpcode () &&
2134-            " expected binop to use the same extend for both operands" 
2135- 
2136-     SDValue ExtendInLHS = ExtendLHS->getOperand (0 );
2137-     SDValue ExtendInRHS = ExtendRHS->getOperand (0 );
2138-     bool  IsSigned = ExtendLHS->getOpcode () == ISD::SIGN_EXTEND;
2139-     unsigned  LowOpc =
2140-         IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
2141-     unsigned  HighOpc = IsSigned ? WebAssemblyISD::EXTEND_HIGH_S
2142-                                 : WebAssemblyISD::EXTEND_HIGH_U;
2143-     SDValue LowLHS;
2144-     SDValue LowRHS;
2145-     SDValue HighLHS;
2146-     SDValue HighRHS;
2147- 
2148-     auto  AssignInputs = [&](MVT VT) {
2149-       LowLHS = DAG.getNode (LowOpc, DL, VT, ExtendInLHS);
2150-       LowRHS = DAG.getNode (LowOpc, DL, VT, ExtendInRHS);
2151-       HighLHS = DAG.getNode (HighOpc, DL, VT, ExtendInLHS);
2152-       HighRHS = DAG.getNode (HighOpc, DL, VT, ExtendInRHS);
2153-     };
2154- 
2155-     if  (ExtendInLHS->getValueType (0 ) == MVT::v8i16) {
2156-       if  (IsSigned) {
2157-         //  i32x4.dot_i16x8_s
2158-         SDValue Dot = DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32,
2159-                                   ExtendInLHS, ExtendInRHS);
2160-         return  DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Dot);
2161-       }
2162- 
2163-       //  (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
2164-       MVT VT = MVT::v4i32;
2165-       AssignInputs (VT);
2166-       SDValue MulLow = DAG.getNode (ISD::MUL, DL, VT, LowLHS, LowRHS);
2167-       SDValue MulHigh = DAG.getNode (ISD::MUL, DL, VT, HighLHS, HighRHS);
2168-       SDValue Add = DAG.getNode (ISD::ADD, DL, VT, MulLow, MulHigh);
2169-       return  DAG.getNode (ISD::ADD, DL, VT, N->getOperand (1 ), Add);
2170-     } else  {
2171-       assert (ExtendInLHS->getValueType (0 ) == MVT::v16i8 &&
2172-              " expected v16i8 input types" 
2173-       AssignInputs (MVT::v8i16);
2174-       //  Lower to a wider tree, using twice the operations compared to above.
2175-       if  (IsSigned) {
2176-         //  Use two dots
2177-         SDValue DotLHS =
2178-             DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
2179-         SDValue DotRHS =
2180-             DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
2181-         SDValue Add = DAG.getNode (ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
2182-         return  DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
2183-       }
2184- 
2185-       SDValue MulLow = DAG.getNode (ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
2186-       SDValue MulHigh = DAG.getNode (ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
2187- 
2188-       SDValue AddLow = DAG.getNode (WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
2189-                                    MVT::v4i32, MulLow);
2190-       SDValue AddHigh = DAG.getNode (WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
2191-                                     MVT::v4i32, MulHigh);
2192-       SDValue Add = DAG.getNode (ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
2193-       return  DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
2194-     }
2195-   } else  {
2196-     //  Accumulate the input using extadd_pairwise.
2197-     assert (ISD::isExtOpcode (Input.getOpcode ()) && " expected extend" 
2198-     bool  IsSigned = Input->getOpcode () == ISD::SIGN_EXTEND;
2199-     unsigned  PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
2200-                                     : WebAssemblyISD::EXT_ADD_PAIRWISE_U;
2201-     SDValue ExtendIn = Input->getOperand (0 );
2202-     if  (ExtendIn->getValueType (0 ) == MVT::v8i16) {
2203-       SDValue Add = DAG.getNode (PairwiseOpc, DL, MVT::v4i32, ExtendIn);
2204-       return  DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
2205-     }
2206- 
2207-     assert (ExtendIn->getValueType (0 ) == MVT::v16i8 &&
2208-            " expected v16i8 input types" 
2209-     SDValue Add =
2210-         DAG.getNode (PairwiseOpc, DL, MVT::v4i32,
2211-                     DAG.getNode (PairwiseOpc, DL, MVT::v8i16, ExtendIn));
2212-     return  DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
2213-   }
2214- }
2215- 
22162086SDValue WebAssemblyTargetLowering::LowerIntrinsic (SDValue Op,
22172087                                                  SelectionDAG &DAG) const  {
22182088  MachineFunction &MF = DAG.getMachineFunction ();
@@ -3683,11 +3553,8 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
36833553    return  performVectorTruncZeroCombine (N, DCI);
36843554  case  ISD::TRUNCATE:
36853555    return  performTruncateCombine (N, DCI);
3686-   case  ISD::INTRINSIC_WO_CHAIN: {
3687-     if  (auto  AnyAllCombine = performAnyAllCombine (N, DCI.DAG ))
3688-       return  AnyAllCombine;
3689-     return  performLowerPartialReduction (N, DCI.DAG );
3690-   }
3556+   case  ISD::INTRINSIC_WO_CHAIN:
3557+     return  performAnyAllCombine (N, DCI.DAG );
36913558  case  ISD::MUL:
36923559    return  performMulCombine (N, DCI);
36933560  }
0 commit comments