Skip to content

Commit b7727f5

Browse files
Expose the FusedMultiplyAdd and MultiplyAddEstimate APIs on relevant vector and scalar types (#102181)
* Expose FusedMultiplyAdd and MultiplyAddEstimate on the scalar and vector types
* Adding tests covering FusedMultiplyAdd and MultiplyAddEstimate for the vector types
* Ensure TensorPrimitives uses the xplat APIs on .NET 9+
* Apply formatting patch
* Fix an accidental change to GenericVectorTests
* Ensure Arm64 passes fma operands in the correct order
* Apply formatting patch
* Ensure all the Arm64 code paths are spilling and swapping operands
* Apply formatting patch
* Don't pop the stack value twice
1 parent 2c39e05 commit b7727f5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+1424
-32
lines changed

src/coreclr/jit/compiler.h

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3226,6 +3226,13 @@ class Compiler
32263226
GenTree* gtNewSimdFloorNode(
32273227
var_types type, GenTree* op1, CorInfoType simdBaseJitType, unsigned simdSize);
32283228

3229+
// Builds a SIMD fused multiply-add node from three vector operands.
// Per the definition in gentree.cpp, the operands are taken in the managed API
// order (left, right, addend) and the SIMD base type must be floating-point.
GenTree* gtNewSimdFmaNode(var_types type,
3230+
GenTree* op1,
3231+
GenTree* op2,
3232+
GenTree* op3,
3233+
CorInfoType simdBaseJitType,
3234+
unsigned simdSize);
3235+
32293236
GenTree* gtNewSimdGetElementNode(var_types type,
32303237
GenTree* op1,
32313238
GenTree* op2,

src/coreclr/jit/gentree.cpp

Lines changed: 55 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23390,6 +23390,61 @@ GenTree* Compiler::gtNewSimdFloorNode(var_types type, GenTree* op1, CorInfoType
2339023390
return gtNewSimdHWIntrinsicNode(type, op1, intrinsic, simdBaseJitType, simdSize);
2339123391
}
2339223392

23393+
// gtNewSimdFmaNode: creates a hardware-intrinsic node performing a fused
// multiply-add over three SIMD operands.
//
// Arguments:
//    type            - the SIMD type of the result; must match simdSize
//    op1             - the left multiplicand (asserted non-null and of `type`)
//    op2             - the right multiplicand (asserted non-null and of `type`)
//    op3             - the addend (asserted non-null and of `type`)
//    simdBaseJitType - the element type; must map to a floating-point var_types
//    simdSize        - the vector size used to pick the target intrinsic
//
// Return Value:
//    The node produced by gtNewSimdHWIntrinsicNode for the selected
//    platform-specific intrinsic.
GenTree* Compiler::gtNewSimdFmaNode(
23394+
var_types type, GenTree* op1, GenTree* op2, GenTree* op3, CorInfoType simdBaseJitType, unsigned simdSize)
23395+
{
23396+
// The requested result type has to be a SIMD type consistent with simdSize.
assert(varTypeIsSIMD(type));
23397+
assert(getSIMDTypeForSize(simdSize) == type);
23398+
23399+
// All three operands must be present and already have the result type.
assert(op1 != nullptr);
23400+
assert(op1->TypeIs(type));
23401+
23402+
assert(op2 != nullptr);
23403+
assert(op2->TypeIs(type));
23404+
23405+
assert(op3 != nullptr);
23406+
assert(op3->TypeIs(type));
23407+
23408+
// Only floating-point element types are accepted.
var_types simdBaseType = JitType2PreciseVarType(simdBaseJitType);
23409+
assert(varTypeIsFloating(simdBaseType));
23410+
23411+
// Stays NI_Illegal until one of the platform branches below picks an intrinsic.
NamedIntrinsic intrinsic = NI_Illegal;
23412+
23413+
// xarch: 64-byte vectors use the AVX512F intrinsic, everything else uses FMA.
// The operands are passed through in the incoming (left, right, addend) order.
#if defined(TARGET_XARCH)
23414+
if (simdSize == 64)
23415+
{
23416+
assert(compIsaSupportedDebugOnly(InstructionSet_AVX512F));
23417+
intrinsic = NI_AVX512F_FusedMultiplyAdd;
23418+
}
23419+
else
23420+
{
23421+
assert(compIsaSupportedDebugOnly(InstructionSet_FMA));
23422+
intrinsic = NI_FMA_MultiplyAdd;
23423+
}
23424+
#elif defined(TARGET_ARM64)
23425+
assert(IsBaselineSimdIsaSupportedDebugOnly());
23426+
23427+
// For double elements, simdSize == 8 selects the scalar form and any other
// size selects the Arm64-specific vector form (presumably because an 8-byte
// vector holds a single double -- confirm). Other element types use the
// general AdvSimd intrinsic.
if (simdBaseType == TYP_DOUBLE)
23428+
{
23429+
intrinsic = (simdSize == 8) ? NI_AdvSimd_FusedMultiplyAddScalar : NI_AdvSimd_Arm64_FusedMultiplyAdd;
23430+
}
23431+
else
23432+
{
23433+
intrinsic = NI_AdvSimd_FusedMultiplyAdd;
23434+
}
23435+
23436+
// AdvSimd.FusedMultiplyAdd expects (addend, left, right), while the APIs take (left, right, addend)
23437+
// We expect op1 and op2 to have already been spilled
23438+
23439+
// After this swap the node operands are (op3, op2, op1), i.e. the addend comes first.
std::swap(op1, op3);
23440+
#else
23441+
#error Unsupported platform
23442+
#endif // !TARGET_XARCH && !TARGET_ARM64
23443+
23444+
// Every supported platform branch above must have selected an intrinsic.
assert(intrinsic != NI_Illegal);
23445+
return gtNewSimdHWIntrinsicNode(type, op1, op2, op3, intrinsic, simdBaseJitType, simdSize);
23446+
}
23447+
2339323448
GenTree* Compiler::gtNewSimdGetElementNode(
2339423449
var_types type, GenTree* op1, GenTree* op2, CorInfoType simdBaseJitType, unsigned simdSize)
2339523450
{

src/coreclr/jit/hwintrinsicarm64.cpp

Lines changed: 45 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1358,6 +1358,26 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
13581358
break;
13591359
}
13601360

1361+
// Imports Vector64/Vector128.FusedMultiplyAdd as a SIMD FMA node built by
// gtNewSimdFmaNode. Requires three arguments and a floating-point base type.
case NI_Vector64_FusedMultiplyAdd:
1362+
case NI_Vector128_FusedMultiplyAdd:
1363+
{
1364+
assert(sig->numArgs == 3);
1365+
assert(varTypeIsFloating(simdBaseType));
1366+
1367+
// Spill side effects of the stack entries at depth - 3 (op1) and depth - 2
// (op2) before popping. gtNewSimdFmaNode swaps op1 and op3 on Arm64 and
// documents that it expects op1 and op2 to have been spilled already.
impSpillSideEffect(true, verCurrentState.esStackDepth -
1368+
3 DEBUGARG("Spilling op1 side effects for FusedMultiplyAdd"));
1369+
1370+
impSpillSideEffect(true, verCurrentState.esStackDepth -
1371+
2 DEBUGARG("Spilling op2 side effects for FusedMultiplyAdd"));
1372+
1373+
// Operands are popped in reverse order: op3 first, then op2, then op1.
op3 = impSIMDPopStack();
1374+
op2 = impSIMDPopStack();
1375+
op1 = impSIMDPopStack();
1376+
1377+
retNode = gtNewSimdFmaNode(retType, op1, op2, op3, simdBaseJitType, simdSize);
1378+
break;
1379+
}
1380+
13611381
case NI_Vector64_get_AllBitsSet:
13621382
case NI_Vector128_get_AllBitsSet:
13631383
{
@@ -1702,6 +1722,31 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
17021722
break;
17031723
}
17041724

1725+
// Imports Vector64/Vector128.MultiplyAddEstimate. When it is expanded, it
// produces the same gtNewSimdFmaNode node as the FusedMultiplyAdd cases.
// Requires three arguments and a floating-point base type.
case NI_Vector64_MultiplyAddEstimate:
1726+
case NI_Vector128_MultiplyAddEstimate:
1727+
{
1728+
assert(sig->numArgs == 3);
1729+
assert(varTypeIsFloating(simdBaseType));
1730+
1731+
// Bail out of this case without assigning retNode when non-deterministic
// intrinsics are blocked.
// NOTE(review): presumably the call is then left unexpanded -- confirm
// against the code following the switch.
if (BlockNonDeterministicIntrinsics(mustExpand))
1732+
{
1733+
break;
1734+
}
1735+
1736+
// Spill side effects of the stack entries at depth - 3 (op1) and depth - 2
// (op2) before popping. gtNewSimdFmaNode swaps op1 and op3 on Arm64 and
// documents that it expects op1 and op2 to have been spilled already.
impSpillSideEffect(true, verCurrentState.esStackDepth -
1737+
3 DEBUGARG("Spilling op1 side effects for MultiplyAddEstimate"));
1738+
1739+
impSpillSideEffect(true, verCurrentState.esStackDepth -
1740+
2 DEBUGARG("Spilling op2 side effects for MultiplyAddEstimate"));
1741+
1742+
// Operands are popped in reverse order: op3 first, then op2, then op1.
op3 = impSIMDPopStack();
1743+
op2 = impSIMDPopStack();
1744+
op1 = impSIMDPopStack();
1745+
1746+
retNode = gtNewSimdFmaNode(retType, op1, op2, op3, simdBaseJitType, simdSize);
1747+
break;
1748+
}
1749+
17051750
case NI_Vector64_Narrow:
17061751
case NI_Vector128_Narrow:
17071752
{

src/coreclr/jit/hwintrinsiclistarm64.h

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@ HARDWARE_INTRINSIC(Vector64, EqualsAll,
5656
HARDWARE_INTRINSIC(Vector64, EqualsAny, 8, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId|HW_Flag_BaseTypeFromFirstArg)
5757
HARDWARE_INTRINSIC(Vector64, ExtractMostSignificantBits, 8, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId|HW_Flag_BaseTypeFromFirstArg)
5858
HARDWARE_INTRINSIC(Vector64, Floor, 8, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
59+
// Vector64.FusedMultiplyAdd: helper-category entry with every instruction slot set to INS_invalid.
// NOTE(review): the `8` and `3` are presumably the SIMD size and argument count -- confirm against the macro definition.
HARDWARE_INTRINSIC(Vector64, FusedMultiplyAdd, 8, 3, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
5960
HARDWARE_INTRINSIC(Vector64, get_AllBitsSet, 8, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
6061
HARDWARE_INTRINSIC(Vector64, get_Indices, 8, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
6162
HARDWARE_INTRINSIC(Vector64, get_One, 8, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
@@ -80,6 +81,7 @@ HARDWARE_INTRINSIC(Vector64, LoadUnsafe,
8081
HARDWARE_INTRINSIC(Vector64, Max, 8, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
8182
HARDWARE_INTRINSIC(Vector64, Min, 8, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
8283
HARDWARE_INTRINSIC(Vector64, Multiply, 8, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
84+
// Vector64.MultiplyAddEstimate: helper-category entry with every instruction slot set to INS_invalid.
// NOTE(review): the `8` and `3` are presumably the SIMD size and argument count -- confirm against the macro definition.
HARDWARE_INTRINSIC(Vector64, MultiplyAddEstimate, 8, 3, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
8385
HARDWARE_INTRINSIC(Vector64, Narrow, 8, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
8486
HARDWARE_INTRINSIC(Vector64, Negate, 8, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
8587
HARDWARE_INTRINSIC(Vector64, OnesComplement, 8, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
@@ -169,6 +171,7 @@ HARDWARE_INTRINSIC(Vector128, EqualsAll,
169171
HARDWARE_INTRINSIC(Vector128, EqualsAny, 16, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId|HW_Flag_BaseTypeFromFirstArg)
170172
HARDWARE_INTRINSIC(Vector128, ExtractMostSignificantBits, 16, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId|HW_Flag_BaseTypeFromFirstArg)
171173
HARDWARE_INTRINSIC(Vector128, Floor, 16, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
174+
// Vector128.FusedMultiplyAdd: helper-category entry with every instruction slot set to INS_invalid.
// NOTE(review): the `16` and `3` are presumably the SIMD size and argument count -- confirm against the macro definition.
HARDWARE_INTRINSIC(Vector128, FusedMultiplyAdd, 16, 3, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
172175
HARDWARE_INTRINSIC(Vector128, get_AllBitsSet, 16, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
173176
HARDWARE_INTRINSIC(Vector128, get_Indices, 16, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
174177
HARDWARE_INTRINSIC(Vector128, get_One, 16, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
@@ -195,6 +198,7 @@ HARDWARE_INTRINSIC(Vector128, LoadUnsafe,
195198
HARDWARE_INTRINSIC(Vector128, Max, 16, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
196199
HARDWARE_INTRINSIC(Vector128, Min, 16, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
197200
HARDWARE_INTRINSIC(Vector128, Multiply, 16, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
201+
// Vector128.MultiplyAddEstimate: helper-category entry with every instruction slot set to INS_invalid.
// NOTE(review): the `16` and `3` are presumably the SIMD size and argument count -- confirm against the macro definition.
HARDWARE_INTRINSIC(Vector128, MultiplyAddEstimate, 16, 3, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
198202
HARDWARE_INTRINSIC(Vector128, Narrow, 16, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
199203
HARDWARE_INTRINSIC(Vector128, Negate, 16, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
200204
HARDWARE_INTRINSIC(Vector128, OnesComplement, 16, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)

0 commit comments

Comments
 (0)