Skip to content

Commit 0691c75

Browse files
authored
[arm64] JIT: Fold "A * B + C" to MADD/MSUB (#61037)
1 parent 5c8ea45 commit 0691c75

File tree

5 files changed

+98
-3
lines changed

5 files changed

+98
-3
lines changed

src/coreclr/jit/codegenlinear.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1608,9 +1608,14 @@ void CodeGen::genConsumeRegs(GenTree* tree)
16081608
}
16091609
#endif // FEATURE_HW_INTRINSICS
16101610
#endif // TARGET_XARCH
1611-
else if (tree->OperIs(GT_BITCAST))
1611+
else if (tree->OperIs(GT_BITCAST, GT_NEG))
16121612
{
1613-
genConsumeReg(tree->gtGetOp1());
1613+
genConsumeRegs(tree->gtGetOp1());
1614+
}
1615+
else if (tree->OperIs(GT_MUL))
1616+
{
1617+
genConsumeRegs(tree->gtGetOp1());
1618+
genConsumeRegs(tree->gtGetOp2());
16141619
}
16151620
else
16161621
{

src/coreclr/jit/emitarm64.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13597,6 +13597,44 @@ regNumber emitter::emitInsTernary(instruction ins, emitAttr attr, GenTree* dst,
1359713597
// src2 can only be a reg
1359813598
assert(!src2->isContained());
1359913599
}
13600+
else if ((src1->OperIs(GT_MUL) && src1->isContained()) || (src2->OperIs(GT_MUL) && src2->isContained()))
13601+
{
13602+
assert(ins == INS_add);
13603+
13604+
GenTree* mul;
13605+
GenTree* c;
13606+
if (src1->OperIs(GT_MUL))
13607+
{
13608+
mul = src1;
13609+
c = src2;
13610+
}
13611+
else
13612+
{
13613+
mul = src2;
13614+
c = src1;
13615+
}
13616+
13617+
GenTree* a = mul->gtGetOp1();
13618+
GenTree* b = mul->gtGetOp2();
13619+
13620+
assert(varTypeIsIntegral(mul) && !mul->gtOverflow());
13621+
13622+
bool msub = false;
13623+
if (a->OperIs(GT_NEG) && a->isContained())
13624+
{
13625+
a = a->gtGetOp1();
13626+
msub = true;
13627+
}
13628+
if (b->OperIs(GT_NEG) && b->isContained())
13629+
{
13630+
b = b->gtGetOp1();
13631+
msub = !msub; // it's either "a * -b" or "-a * -b" which is the same as "a * b"
13632+
}
13633+
13634+
emitIns_R_R_R_R(msub ? INS_msub : INS_madd, attr, dst->GetRegNum(), a->GetRegNum(), b->GetRegNum(),
13635+
c->GetRegNum());
13636+
return dst->GetRegNum();
13637+
}
1360013638
else // not floating point
1360113639
{
1360213640
// src2 can be immed or reg

src/coreclr/jit/lowerarmarch.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1568,6 +1568,47 @@ void Lowering::ContainCheckBinary(GenTreeOp* node)
15681568
{
15691569
// Check and make op2 contained (if it is a containable immediate)
15701570
CheckImmedAndMakeContained(node, node->gtOp2);
1571+
1572+
#ifdef TARGET_ARM64
1573+
// Find "a * b + c" or "c + a * b" in order to emit MADD/MSUB
1574+
if (comp->opts.OptimizationEnabled() && varTypeIsIntegral(node) && !node->isContained() && node->OperIs(GT_ADD) &&
1575+
!node->gtOverflow() && (node->gtGetOp1()->OperIs(GT_MUL) || node->gtGetOp2()->OperIs(GT_MUL)))
1576+
{
1577+
GenTree* mul;
1578+
GenTree* c;
1579+
if (node->gtGetOp1()->OperIs(GT_MUL))
1580+
{
1581+
mul = node->gtGetOp1();
1582+
c = node->gtGetOp2();
1583+
}
1584+
else
1585+
{
1586+
mul = node->gtGetOp2();
1587+
c = node->gtGetOp1();
1588+
}
1589+
1590+
GenTree* a = mul->gtGetOp1();
1591+
GenTree* b = mul->gtGetOp2();
1592+
1593+
if (!mul->isContained() && !mul->gtOverflow() && !a->isContained() && !b->isContained() && !c->isContained() &&
1594+
varTypeIsIntegral(mul))
1595+
{
1596+
if (a->OperIs(GT_NEG) && !a->gtGetOp1()->isContained() && !a->gtGetOp1()->IsRegOptional())
1597+
{
1598+
// "-a * b + c" to MSUB
1599+
MakeSrcContained(mul, a);
1600+
}
1601+
if (b->OperIs(GT_NEG) && !b->gtGetOp1()->isContained())
1602+
{
1603+
// "a * -b + c" to MSUB
1604+
MakeSrcContained(mul, b);
1605+
}
1606+
// If both 'a' and 'b' are GT_NEG - MADD will be emitted.
1607+
1608+
MakeSrcContained(node, mul);
1609+
}
1610+
}
1611+
#endif
15711612
}
15721613

15731614
//------------------------------------------------------------------------

src/coreclr/jit/lsraarm64.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,6 @@ int LinearScan::BuildNode(GenTree* tree)
265265
// everything is made explicit by adding casts.
266266
assert(tree->gtGetOp1()->TypeGet() == tree->gtGetOp2()->TypeGet());
267267
}
268-
269268
FALLTHROUGH;
270269

271270
case GT_AND:

src/coreclr/jit/lsrabuild.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3071,6 +3071,18 @@ int LinearScan::BuildOperandUses(GenTree* node, regMaskTP candidates)
30713071
return 1;
30723072
}
30733073
#endif // FEATURE_HW_INTRINSICS
3074+
#ifdef TARGET_ARM64
3075+
if (node->OperIs(GT_MUL))
3076+
{
3077+
// Can be contained for MultiplyAdd on arm64
3078+
return BuildBinaryUses(node->AsOp(), candidates);
3079+
}
3080+
if (node->OperIs(GT_NEG))
3081+
{
3082+
// Can be contained for MultiplyAdd on arm64
3083+
return BuildOperandUses(node->gtGetOp1(), candidates);
3084+
}
3085+
#endif
30743086

30753087
return 0;
30763088
}

0 commit comments

Comments
 (0)