Skip to content
This repository was archived by the owner on Jan 23, 2023. It is now read-only.

Commit a8d8aa3

Browse files
committed
Updating Sse.Shuffle to have a compiler fallback for indirect calls
1 parent e578b5d commit a8d8aa3

File tree

5 files changed

+123
-15
lines changed

5 files changed

+123
-15
lines changed

src/jit/hwintrinsiccodegenxarch.cpp

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,6 @@ void CodeGen::genSSEIntrinsic(GenTreeHWIntrinsic* node)
517517
case NI_SSE_Shuffle:
518518
{
519519
GenTreeArgList* argList;
520-
ssize_t ival;
521520

522521
// Shuffle takes 3 operands, so op1 should be an arg list with two
523522
// additional node in the chain.
@@ -540,10 +539,67 @@ void CodeGen::genSSEIntrinsic(GenTreeHWIntrinsic* node)
540539

541540
argList = argList->Rest();
542541
op3 = argList->Current();
543-
ival = op3->AsIntConCommon()->IconValue();
544542
genConsumeRegs(op3);
545543

546-
emit->emitIns_SIMD_R_R_R_I(INS_shufps, targetReg, op1Reg, op2Reg, (int)ival, TYP_SIMD16);
544+
if (op3->IsCnsIntOrI())
545+
{
546+
ssize_t ival = op3->AsIntConCommon()->IconValue();
547+
emit->emitIns_SIMD_R_R_R_I(INS_shufps, targetReg, op1Reg, op2Reg, (int)ival, TYP_SIMD16);
548+
}
549+
else
550+
{
551+
// We emit a fallback case for the scenario when op3 is not a constant. This should normally
552+
// happen when the intrinsic is called indirectly, such as via Reflection. However, it can
553+
// also occur if the consumer calls it directly and just doesn't pass a constant value.
554+
555+
const unsigned jmpCount = 256;
556+
BasicBlock* jmpTable[jmpCount];
557+
558+
unsigned jmpTableBase = emit->emitBBTableDataGenBeg(jmpCount, true);
559+
unsigned jmpTableOffs = 0;
560+
561+
// Emit the jump table
562+
563+
JITDUMP("\n J_M%03u_DS%02u LABEL DWORD\n", Compiler::s_compMethodsCount, jmpTableBase);
564+
565+
for (unsigned i = 0; i < jmpCount; i++)
566+
{
567+
jmpTable[i] = genCreateTempLabel();
568+
JITDUMP(" DD L_M%03u_BB%02u\n", Compiler::s_compMethodsCount, jmpTable[i]->bbNum);
569+
emit->emitDataGenData(i, jmpTable[i]);
570+
}
571+
572+
emit->emitDataGenEnd();
573+
574+
// Compute and jump to the appropriate offset in the switch table
575+
576+
regNumber baseReg = node->ExtractTempReg(); // the start of the switch table
577+
regNumber offsReg = node->GetSingleTempReg(); // the offset into the switch table
578+
579+
emit->emitIns_R_C(INS_lea, emitTypeSize(TYP_I_IMPL), offsReg, compiler->eeFindJitDataOffs(jmpTableBase),
580+
0);
581+
582+
emit->emitIns_R_ARX(INS_mov, EA_4BYTE, offsReg, offsReg, op3->gtRegNum, 4, 0);
583+
emit->emitIns_R_L(INS_lea, EA_PTR_DSP_RELOC, compiler->fgFirstBB, baseReg);
584+
emit->emitIns_R_R(INS_add, EA_PTRSIZE, offsReg, baseReg);
585+
emit->emitIns_R(INS_i_jmp, emitTypeSize(TYP_I_IMPL), offsReg);
586+
587+
// Emit the switch table entries
588+
589+
BasicBlock* switchTableBeg = genCreateTempLabel();
590+
BasicBlock* switchTableEnd = genCreateTempLabel();
591+
592+
genDefineTempLabel(switchTableBeg);
593+
594+
for (unsigned i = 0; i < jmpCount; i++)
595+
{
596+
genDefineTempLabel(jmpTable[i]);
597+
emit->emitIns_SIMD_R_R_R_I(INS_shufps, targetReg, op1Reg, op2Reg, i, TYP_SIMD16);
598+
emit->emitIns_J(INS_jmp, switchTableEnd);
599+
}
600+
601+
genDefineTempLabel(switchTableEnd);
602+
}
547603
break;
548604
}
549605

src/jit/hwintrinsicxarch.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ GenTree* Compiler::impSSEIntrinsic(NamedIntrinsic intrinsic,
489489
GenTree* left = gtNewSimdHWIntrinsicNode(TYP_SIMD16, op4, op3, NI_SSE_UnpackLow, TYP_FLOAT, 16);
490490
GenTree* right = gtNewSimdHWIntrinsicNode(TYP_SIMD16, op2, op1, NI_SSE_UnpackLow, TYP_FLOAT, 16);
491491
GenTree* control = gtNewIconNode(68, TYP_UBYTE);
492-
492+
493493
retNode = gtNewSimdHWIntrinsicNode(TYP_SIMD16, left, right, control, NI_SSE_Shuffle, TYP_FLOAT, 16);
494494
break;
495495
}
@@ -498,10 +498,25 @@ GenTree* Compiler::impSSEIntrinsic(NamedIntrinsic intrinsic,
498498
{
499499
assert(sig->numArgs == 3);
500500
assert(getBaseTypeOfSIMDType(sig->retTypeSigClass) == TYP_FLOAT);
501-
op3 = impPopStack().val;
502-
op2 = impSIMDPopStack(TYP_SIMD16);
503-
op1 = impSIMDPopStack(TYP_SIMD16);
504-
retNode = gtNewSimdHWIntrinsicNode(TYP_SIMD16, op1, op2, op3, intrinsic, TYP_FLOAT, 16);
501+
502+
op3 = impStackTop().val;
503+
504+
if (op3->IsCnsIntOrI() || mustExpand)
505+
{
506+
impPopStack(); // Pop the value we peeked at
507+
op2 = impSIMDPopStack(TYP_SIMD16);
508+
op1 = impSIMDPopStack(TYP_SIMD16);
509+
retNode = gtNewSimdHWIntrinsicNode(TYP_SIMD16, op1, op2, op3, intrinsic, TYP_FLOAT, 16);
510+
}
511+
else
512+
{
513+
// When op3 is not a constant and we are not being forced to expand, we need to
514+
// return nullptr so a GT_CALL to the intrinsic method is emitted instead. The
515+
// intrinsic method is recursive and will be forced to expand, at which point
516+
// we emit some less efficient fallback code.
517+
518+
return nullptr;
519+
}
505520
break;
506521
}
507522

@@ -586,7 +601,8 @@ GenTree* Compiler::impSSEIntrinsic(NamedIntrinsic intrinsic,
586601
CORINFO_CLASS_HANDLE argClass;
587602

588603
CORINFO_ARG_LIST_HANDLE argLst = info.compCompHnd->getArgNext(sig->args);
589-
CorInfoType corType = strip(info.compCompHnd->getArgType(sig, argLst, &argClass)); // type of the second argument
604+
CorInfoType corType =
605+
strip(info.compCompHnd->getArgType(sig, argLst, &argClass)); // type of the second argument
590606

591607
if (varTypeIsLong(corType))
592608
{

src/jit/lowerxarch.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2307,10 +2307,16 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
23072307
switch (node->gtHWIntrinsicId)
23082308
{
23092309
case NI_SSE_Shuffle:
2310-
// third operand is an integer constant and marked as contained.
2311-
assert(op1->AsArgList()->Rest()->Rest()->Current()->IsCnsIntOrI());
2312-
MakeSrcContained(node, op1->AsArgList()->Rest()->Rest()->Current());
2310+
{
2311+
assert(op1->OperIsList());
2312+
GenTree* op3 = op1->AsArgList()->Rest()->Rest()->Current();
2313+
2314+
if (op3->IsCnsIntOrI())
2315+
{
2316+
MakeSrcContained(node, op3);
2317+
}
23132318
break;
2319+
}
23142320

23152321
default:
23162322
assert((intrinsicID > NI_HW_INTRINSIC_START) && (intrinsicID < NI_HW_INTRINSIC_END));

src/jit/lsraxarch.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2550,10 +2550,24 @@ void LinearScan::TreeNodeInfoInitHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree,
25502550
break;
25512551

25522552
case NI_SSE_Shuffle:
2553-
// Third operand should be a contained integer constant
2554-
assert(info->srcCount == 2);
2555-
assert(op1->AsArgList()->Rest()->Rest()->Current()->isContainedIntOrIImmed());
2553+
{
2554+
assert(op1->OperIsList());
2555+
GenTree* op3 = op1->AsArgList()->Rest()->Rest()->Current();
2556+
2557+
if (!op3->isContainedIntOrIImmed())
2558+
{
2559+
assert(!op3->IsCnsIntOrI());
2560+
2561+
// We need two extra reg when op3 isn't a constant so
2562+
// the offset into the jump table for the fallback path
2563+
// can be computed.
2564+
2565+
info->internalIntCount = 2;
2566+
info->setInternalCandidates(this, allRegs(TYP_INT));
2567+
break;
2568+
}
25562569
break;
2570+
}
25572571

25582572
case NI_SSE_ConvertToSingle:
25592573
case NI_SSE_StaticCast:

tests/src/JIT/HardwareIntrinsics/X86/Sse/Shuffle.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,22 @@ static unsafe int Main(string[] args)
9191
Console.WriteLine();
9292
testResult = Fail;
9393
}
94+
95+
// XYZW
96+
vf3 = (Vector128<float>)typeof(Sse).GetMethod(nameof(Sse.Shuffle), new Type[] { vf1.GetType(), vf2.GetType(), typeof(byte) }).Invoke(null, new object[] { vf1, vf2, (byte)(27) });
97+
Unsafe.Write(floatTable.outArrayPtr, vf3);
98+
99+
if (!floatTable.CheckResult((x, y, z) => (z[0] == x[3]) && (z[1] == x[2]) &&
100+
(z[2] == y[1]) && (z[3] == y[0])))
101+
{
102+
Console.WriteLine("SSE Shuffle failed on float:");
103+
foreach (var item in floatTable.outArray)
104+
{
105+
Console.Write(item + ", ");
106+
}
107+
Console.WriteLine();
108+
testResult = Fail;
109+
}
94110
}
95111
}
96112

0 commit comments

Comments
 (0)