Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AArch64: Implement vector shift and rotate evaluators #7017

Merged
merged 1 commit into from
Jun 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 2 additions & 12 deletions compiler/aarch64/codegen/BinaryEvaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,18 +523,8 @@ OMR::ARM64::TreeEvaluator::lsubEvaluator(TR::Node *node, TR::CodeGenerator *cg)
return genericBinaryEvaluator(node, TR::InstOpCode::subx, TR::InstOpCode::subimmx, true, cg);
}

typedef TR::Register *(*binaryEvaluatorHelper)(TR::Node *node, TR::Register *resReg, TR::Register *lhsRes, TR::Register *rhsReg, TR::CodeGenerator *cg);
/**
* @brief Helper functions for generating instruction sequence for masked binary operations
*
* @param[in] node: node
* @param[in] cg: CodeGenerator
* @param[in] op: binary opcode
* @param[in] evaluatorHelper: optional pointer to helper function which generates instruction stream for operation
* @return vector register containing the result
*/
static TR::Register *
inlineVectorBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper = NULL)
TR::Register *
OMR::ARM64::TreeEvaluator::inlineVectorBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper)
{
TR::Node *firstChild = node->getFirstChild();
TR::Node *secondChild = node->getSecondChild();
Expand Down
8 changes: 8 additions & 0 deletions compiler/aarch64/codegen/OMRCodeGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,14 @@ bool OMR::ARM64::CodeGenerator::getSupportsOpCodeForAutoSIMD(TR::CPU *cpu, TR::I
case TR::vmor:
case TR::vmxor:
case TR::vmnot:
case TR::vshl:
case TR::vmshl:
case TR::vshr:
case TR::vmshr:
case TR::vushr:
case TR::vmushr:
case TR::vrol:
case TR::vmrol:
// Float/ Double are not supported
return (et == TR::Int8 || et == TR::Int16 || et == TR::Int32 || et == TR::Int64);
case TR::vload:
Expand Down
273 changes: 253 additions & 20 deletions compiler/aarch64/codegen/OMRTreeEvaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3106,18 +3106,8 @@ OMR::ARM64::TreeEvaluator::vRegStoreEvaluator(TR::Node *node, TR::CodeGenerator
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
}

typedef TR::Register *(*binaryEvaluatorHelper)(TR::Node *node, TR::Register *resReg, TR::Register *lhsRes, TR::Register *rhsReg, TR::CodeGenerator *cg);
/**
* @brief Helper functions for generating instruction sequence for masked binary operations
*
* @param[in] node: node
* @param[in] cg: CodeGenerator
* @param[in] op: binary opcode
* @param[in] evaluatorHelper: optional pointer to helper function which generates instruction stream for operation
* @return vector register containing the result
*/
static TR::Register *
inlineVectorMaskedBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper = NULL)
TR::Register *
OMR::ARM64::TreeEvaluator::inlineVectorMaskedBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper)
{
TR::Node *firstChild = node->getFirstChild();
TR::Node *secondChild = node->getSecondChild();
Expand Down Expand Up @@ -4149,52 +4139,295 @@ OMR::ARM64::TreeEvaluator::vexpandEvaluator(TR::Node *node, TR::CodeGenerator *c
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
}

/**
* @brief Helper function for vector shift with immediate amount
*
* @param[in] node : node
* @param[in] cg : CodeGenerator
* @return the result register. Null is returned if the tree is not vector shift immediate.
*/
static TR::Register *
vectorShiftImmediateHelper(TR::Node *node, TR::CodeGenerator *cg)
{
TR::VectorOperation vectorOp = node->getOpCode().getVectorOperation();
TR::DataType elementType = node->getDataType().getVectorElementType();
const bool isVectorShift = (vectorOp == TR::vshl) || (vectorOp == TR::vshr) || (vectorOp == TR::vushr);
const bool isVectorMaskedShift = (vectorOp == TR::vmshl) || (vectorOp == TR::vmshr) || (vectorOp == TR::vmushr);
TR_ASSERT_FATAL_WITH_NODE(node, isVectorShift || isVectorMaskedShift, "opcode must be vector shift");
TR_ASSERT_FATAL_WITH_NODE(node, (elementType >= TR::Int8) && (elementType <= TR::Int64), "elementType must be integer");

TR::Node *firstChild = node->getFirstChild();
TR::Node *secondChild = node->getSecondChild();
TR::Node *thirdChild = isVectorMaskedShift ? node->getThirdChild() : NULL;

if (!((secondChild->getOpCode().getVectorOperation() == TR::vsplats) &&
(secondChild->getRegister() == NULL) &&
secondChild->getFirstChild()->getOpCode().isLoadConst()))
{
return NULL;
}

TR::Node *constNode = secondChild->getFirstChild();

const bool isLeftShift = (vectorOp == TR::vshl) || (vectorOp == TR::vmshl);
const bool isLogicalRightShift = (vectorOp == TR::vushr) || (vectorOp == TR::vmushr);
const int64_t value = constNode->getConstValue();
const int32_t elementSizeInBits = TR::DataType::getSize(elementType) * 8;
const int64_t start = isLeftShift ? 0 : 1;
const int64_t end = isLeftShift ? (elementSizeInBits - 1) : elementSizeInBits;
if ((value >= start) && (value <= end))
{
TR::Register *targetReg = cg->allocateRegister(TR_VRF);
TR::Register *lhsReg = cg->evaluate(firstChild);
TR::InstOpCode::Mnemonic op = static_cast<TR::InstOpCode::Mnemonic>(
(isLeftShift ? TR::InstOpCode::vshl16b :
isLogicalRightShift ? TR::InstOpCode::vushr16b : TR::InstOpCode::vsshr16b) +
(elementType - TR::Int8));
generateVectorShiftImmediateInstruction(cg, op, node, targetReg, lhsReg, value);
if (isVectorMaskedShift)
{
bool flipMask = false;
TR::Register *maskReg = evaluateMaskNode(thirdChild, flipMask, cg);
/*
* BIT inserts each bit from the first source if the corresponding bit of the second source is 1.
* BIF inserts each bit from the first source if the corresponding bit of the second source is 0.
*/
generateTrg1Src2Instruction(cg, flipMask ? TR::InstOpCode::vbit16b : TR::InstOpCode::vbif16b, node, targetReg, lhsReg, maskReg);

cg->decReferenceCount(thirdChild);
}

node->setRegister(targetReg);
cg->decReferenceCount(firstChild);
cg->recursivelyDecReferenceCount(secondChild);

return targetReg;
}

return NULL;
}

TR::Register*
OMR::ARM64::TreeEvaluator::vshlEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
"Only 128-bit vectors are supported %s", node->getDataType().toString());

TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
if (resultReg != NULL)
{
return resultReg;
}

TR::InstOpCode::Mnemonic shiftOp;
switch(node->getDataType().getVectorElementType())
{
case TR::Int8:
shiftOp = TR::InstOpCode::vsshl16b;
break;
case TR::Int16:
shiftOp = TR::InstOpCode::vsshl8h;
break;
case TR::Int32:
shiftOp = TR::InstOpCode::vsshl4s;
break;
case TR::Int64:
shiftOp = TR::InstOpCode::vsshl2d;
break;
default:
TR_ASSERT(false, "unrecognized vector type %s", node->getDataType().toString());
return NULL;
}
return inlineVectorBinaryOp(node, cg, shiftOp);
}

TR::Register*
OMR::ARM64::TreeEvaluator::vmshlEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
"Only 128-bit vectors are supported %s", node->getDataType().toString());

TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
if (resultReg != NULL)
{
return resultReg;
}

TR::InstOpCode::Mnemonic shiftOp;
switch(node->getDataType().getVectorElementType())
{
case TR::Int8:
shiftOp = TR::InstOpCode::vsshl16b;
break;
case TR::Int16:
shiftOp = TR::InstOpCode::vsshl8h;
break;
case TR::Int32:
shiftOp = TR::InstOpCode::vsshl4s;
break;
case TR::Int64:
shiftOp = TR::InstOpCode::vsshl2d;
break;
default:
TR_ASSERT(false, "unrecognized vector type %s", node->getDataType().toString());
return NULL;
}
return inlineVectorMaskedBinaryOp(node, cg, shiftOp);
}

/**
* @brief Helper function for vector right shift operation
*
* @param[in] node: node
* @param[in] resultReg: the result register
* @param[in] lhsReg: the first argument register
* @param[in] rhsReg: the second argument register
* @param[in] cg: CodeGenerator
* @return the result register
*/
static TR::Register *
vectorRightShiftHelper(TR::Node *node, TR::Register *resultReg, TR::Register *lhsReg, TR::Register *rhsReg, TR::CodeGenerator *cg)
{
TR::VectorOperation vectorOp = node->getOpCode().getVectorOperation();
TR::DataType elementType = node->getDataType().getVectorElementType();
TR_ASSERT_FATAL_WITH_NODE(node, (vectorOp == TR::vshr) || (vectorOp == TR::vushr) || (vectorOp == TR::vmshr) || (vectorOp == TR::vmushr),
"opcode must be vector right shift");
TR_ASSERT_FATAL_WITH_NODE(node, (elementType >= TR::Int8) && (elementType <= TR::Int64), "elementType must be integer");
const bool isLogicalShift = (vectorOp == TR::vushr) || (vectorOp == TR::vmushr);
TR::InstOpCode::Mnemonic negOp = static_cast<TR::InstOpCode::Mnemonic>(TR::InstOpCode::vneg16b + (elementType - TR::Int8));
TR::InstOpCode::Mnemonic shiftOp = static_cast<TR::InstOpCode::Mnemonic>(
(isLogicalShift ? TR::InstOpCode::vushl16b : TR::InstOpCode::vsshl16b) +
(elementType - TR::Int8));
generateTrg1Src1Instruction(cg, negOp, node, resultReg, rhsReg);
generateTrg1Src2Instruction(cg, shiftOp, node, resultReg, lhsReg, resultReg);

return resultReg;
}

TR::Register*
OMR::ARM64::TreeEvaluator::vshrEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
"Only 128-bit vectors are supported %s", node->getDataType().toString());

TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
if (resultReg != NULL)
{
return resultReg;
}

return inlineVectorBinaryOp(node, cg, TR::InstOpCode::bad, vectorRightShiftHelper);
}

TR::Register*
OMR::ARM64::TreeEvaluator::vmshrEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
"Only 128-bit vectors are supported %s", node->getDataType().toString());

TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
if (resultReg != NULL)
{
return resultReg;
}

return inlineVectorMaskedBinaryOp(node, cg, TR::InstOpCode::bad, vectorRightShiftHelper);
}

TR::Register*
OMR::ARM64::TreeEvaluator::vushrEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
"Only 128-bit vectors are supported %s", node->getDataType().toString());

TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
if (resultReg != NULL)
{
return resultReg;
}

return inlineVectorBinaryOp(node, cg, TR::InstOpCode::bad, vectorRightShiftHelper);
}

TR::Register*
OMR::ARM64::TreeEvaluator::vmushrEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
"Only 128-bit vectors are supported %s", node->getDataType().toString());

TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
if (resultReg != NULL)
{
return resultReg;
}

return inlineVectorMaskedBinaryOp(node, cg, TR::InstOpCode::bad, vectorRightShiftHelper);
}

/**
* @brief Helper function for vector rotate operation
*
* @param[in] node: node
* @param[in] resultReg: the result register
* @param[in] lhsReg: the first argument register
* @param[in] rhsReg: the second argument register
* @param[in] cg: CodeGenerator
* @return the result register
*/
static TR::Register *
vectorRotateHelper(TR::Node *node, TR::Register *resultReg, TR::Register *lhsReg, TR::Register *rhsReg, TR::CodeGenerator *cg)
{
TR::DataType elementType = node->getDataType().getVectorElementType();
TR_ASSERT_FATAL_WITH_NODE(node, (elementType >= TR::Int8) && (elementType <= TR::Int64), "elementType must be integer");
TR::Register *tempReg = cg->allocateRegister(TR_VRF);
TR::InstOpCode::Mnemonic negOp = static_cast<TR::InstOpCode::Mnemonic>(TR::InstOpCode::vneg16b + (elementType - TR::Int8));
TR::InstOpCode::Mnemonic shiftOp = static_cast<TR::InstOpCode::Mnemonic>(TR::InstOpCode::vushl16b + (elementType - TR::Int8));
TR::InstOpCode::Mnemonic subOp = static_cast<TR::InstOpCode::Mnemonic>(TR::InstOpCode::vsub16b + (elementType - TR::Int8));

if (elementType == TR::Int64)
{
/*
* AArch64 does not have instructions for loading arbitrary immediate 8bits value into a vector of 64-bit integer elements.
* Loading the value to a vector of 32-bit integer elements and using UXTL to extend elements to 64-bit.
*/
generateTrg1ImmInstruction(cg, TR::InstOpCode::vmovi4s, node, tempReg, 64);
generateVectorUXTLInstruction(cg, TR::Int32, node, tempReg, tempReg, false);
}
else
{
const int32_t sizeInBits = TR::DataType::getSize(elementType) * 8;
TR::InstOpCode::Mnemonic movOp = (elementType == TR::Int8) ? TR::InstOpCode::vmovi16b :
((elementType == TR::Int16) ? TR::InstOpCode::vmovi8h : TR::InstOpCode::vmovi4s);
generateTrg1ImmInstruction(cg, movOp, node, tempReg, sizeInBits);
}

/* (lhs << rhs) || (lhs >>> (sizeInBits - rhs)) */
generateTrg1Src2Instruction(cg, subOp, node, tempReg, rhsReg, tempReg);
generateTrg1Src2Instruction(cg, shiftOp, node, resultReg, lhsReg, rhsReg);
generateTrg1Src2Instruction(cg, shiftOp, node, tempReg, lhsReg, tempReg);
generateTrg1Src2Instruction(cg, TR::InstOpCode::vorr16b, node, resultReg, resultReg, tempReg);

cg->stopUsingRegister(tempReg);

return resultReg;
}

TR::Register*
OMR::ARM64::TreeEvaluator::vrolEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
"Only 128-bit vectors are supported %s", node->getDataType().toString());

return inlineVectorBinaryOp(node, cg, TR::InstOpCode::bad, vectorRotateHelper);
}

TR::Register*
OMR::ARM64::TreeEvaluator::vmrolEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
"Only 128-bit vectors are supported %s", node->getDataType().toString());

return inlineVectorMaskedBinaryOp(node, cg, TR::InstOpCode::bad, vectorRotateHelper);
}

TR::Register*
Expand Down
22 changes: 22 additions & 0 deletions compiler/aarch64/codegen/OMRTreeEvaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,28 @@ class OMR_EXTENSIBLE TreeEvaluator: public OMR::TreeEvaluator
*/
static TR::Register *vmaxInt64Helper(TR::Node *node, TR::Register *resReg, TR::Register *lhsReg, TR::Register *rhsReg, TR::CodeGenerator *cg);

typedef TR::Register *(*binaryEvaluatorHelper)(TR::Node *node, TR::Register *resReg, TR::Register *lhsRes, TR::Register *rhsReg, TR::CodeGenerator *cg);
/**
* @brief Helper function for generating instruction sequence for binary operations
*
* @param[in] node: node
* @param[in] cg: CodeGenerator
* @param[in] op: binary opcode
* @param[in] evaluatorHelper: optional pointer to helper function which generates instruction stream for operation
* @return vector register containing the result
*/
static TR::Register *inlineVectorBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper = NULL);
/**
* @brief Helper function for generating instruction sequence for masked binary operations
*
* @param[in] node: node
* @param[in] cg: CodeGenerator
* @param[in] op: binary opcode
* @param[in] evaluatorHelper: optional pointer to helper function which generates instruction stream for operation
* @return vector register containing the result
*/
static TR::Register *inlineVectorMaskedBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper = NULL);

static TR::Register *f2iuEvaluator(TR::Node *node, TR::CodeGenerator *cg);
static TR::Register *f2luEvaluator(TR::Node *node, TR::CodeGenerator *cg);
static TR::Register *f2buEvaluator(TR::Node *node, TR::CodeGenerator *cg);
Expand Down