Skip to content

Commit

Permalink
Merge pull request #7017 from Akira1Saitoh/aarch64VectorRotate
Browse files Browse the repository at this point in the history
AArch64: Implement vector shift and rotate evaluators
  • Loading branch information
knn-k authored Jun 6, 2023
2 parents 283d187 + a3355e9 commit 044e8ce
Show file tree
Hide file tree
Showing 4 changed files with 285 additions and 32 deletions.
14 changes: 2 additions & 12 deletions compiler/aarch64/codegen/BinaryEvaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,18 +523,8 @@ OMR::ARM64::TreeEvaluator::lsubEvaluator(TR::Node *node, TR::CodeGenerator *cg)
return genericBinaryEvaluator(node, TR::InstOpCode::subx, TR::InstOpCode::subimmx, true, cg);
}

typedef TR::Register *(*binaryEvaluatorHelper)(TR::Node *node, TR::Register *resReg, TR::Register *lhsRes, TR::Register *rhsReg, TR::CodeGenerator *cg);
/**
* @brief Helper functions for generating instruction sequence for masked binary operations
*
* @param[in] node: node
* @param[in] cg: CodeGenerator
* @param[in] op: binary opcode
* @param[in] evaluatorHelper: optional pointer to helper function which generates instruction stream for operation
* @return vector register containing the result
*/
static TR::Register *
inlineVectorBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper = NULL)
TR::Register *
OMR::ARM64::TreeEvaluator::inlineVectorBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper)
{
TR::Node *firstChild = node->getFirstChild();
TR::Node *secondChild = node->getSecondChild();
Expand Down
8 changes: 8 additions & 0 deletions compiler/aarch64/codegen/OMRCodeGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,14 @@ bool OMR::ARM64::CodeGenerator::getSupportsOpCodeForAutoSIMD(TR::CPU *cpu, TR::I
case TR::vmor:
case TR::vmxor:
case TR::vmnot:
case TR::vshl:
case TR::vmshl:
case TR::vshr:
case TR::vmshr:
case TR::vushr:
case TR::vmushr:
case TR::vrol:
case TR::vmrol:
// Float/ Double are not supported
return (et == TR::Int8 || et == TR::Int16 || et == TR::Int32 || et == TR::Int64);
case TR::vload:
Expand Down
273 changes: 253 additions & 20 deletions compiler/aarch64/codegen/OMRTreeEvaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3106,18 +3106,8 @@ OMR::ARM64::TreeEvaluator::vRegStoreEvaluator(TR::Node *node, TR::CodeGenerator
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
}

typedef TR::Register *(*binaryEvaluatorHelper)(TR::Node *node, TR::Register *resReg, TR::Register *lhsRes, TR::Register *rhsReg, TR::CodeGenerator *cg);
/**
* @brief Helper functions for generating instruction sequence for masked binary operations
*
* @param[in] node: node
* @param[in] cg: CodeGenerator
* @param[in] op: binary opcode
* @param[in] evaluatorHelper: optional pointer to helper function which generates instruction stream for operation
* @return vector register containing the result
*/
static TR::Register *
inlineVectorMaskedBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper = NULL)
TR::Register *
OMR::ARM64::TreeEvaluator::inlineVectorMaskedBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper)
{
TR::Node *firstChild = node->getFirstChild();
TR::Node *secondChild = node->getSecondChild();
Expand Down Expand Up @@ -4149,52 +4139,295 @@ OMR::ARM64::TreeEvaluator::vexpandEvaluator(TR::Node *node, TR::CodeGenerator *c
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
}

/**
 * @brief Tries to generate a vector shift using the shift-by-immediate instruction form
 *
 * The immediate form is used only when the second child is a vsplats node that has not
 * been evaluated yet, its child is a load-constant node, and the constant lies in the
 * accepted range for the element size: 0 to (element bits - 1) for left shifts and
 * 1 to (element bits) for right shifts. For the masked opcodes, the shifted value is
 * then merged with the first operand under control of the mask.
 *
 * @param[in] node : vector shift node (vshl, vshr, vushr, vmshl, vmshr or vmushr)
 * @param[in] cg : CodeGenerator
 * @return the result register. NULL is returned if the tree is not a vector shift by an
 *         immediate; in that case nothing has been evaluated and no reference count
 *         has been changed.
 */
static TR::Register *
vectorShiftImmediateHelper(TR::Node *node, TR::CodeGenerator *cg)
   {
   TR::VectorOperation opKind = node->getOpCode().getVectorOperation();
   TR::DataType laneType = node->getDataType().getVectorElementType();

   const bool leftShift = (opKind == TR::vshl) || (opKind == TR::vmshl);
   const bool arithmeticRightShift = (opKind == TR::vshr) || (opKind == TR::vmshr);
   const bool logicalRightShift = (opKind == TR::vushr) || (opKind == TR::vmushr);
   const bool masked = (opKind == TR::vmshl) || (opKind == TR::vmshr) || (opKind == TR::vmushr);

   TR_ASSERT_FATAL_WITH_NODE(node, leftShift || arithmeticRightShift || logicalRightShift, "opcode must be vector shift");
   TR_ASSERT_FATAL_WITH_NODE(node, (laneType >= TR::Int8) && (laneType <= TR::Int64), "elementType must be integer");

   TR::Node *valueChild = node->getFirstChild();
   TR::Node *amountChild = node->getSecondChild();
   TR::Node *maskChild = masked ? node->getThirdChild() : NULL;

   /* The shift amount has to be an unevaluated splat of a constant. */
   if (amountChild->getOpCode().getVectorOperation() != TR::vsplats)
      {
      return NULL;
      }
   if (amountChild->getRegister() != NULL)
      {
      return NULL;
      }
   if (!amountChild->getFirstChild()->getOpCode().isLoadConst())
      {
      return NULL;
      }

   const int64_t shiftAmount = amountChild->getFirstChild()->getConstValue();
   const int32_t laneBits = TR::DataType::getSize(laneType) * 8;
   const int64_t minAmount = leftShift ? 0 : 1;
   const int64_t maxAmount = leftShift ? (laneBits - 1) : laneBits;
   if ((shiftAmount < minAmount) || (shiftAmount > maxAmount))
      {
      return NULL;
      }

   TR::Register *targetReg = cg->allocateRegister(TR_VRF);
   TR::Register *srcReg = cg->evaluate(valueChild);

   /*
    * Pick the 16b variant of the shift and step to the variant for the element type.
    * This relies on the per-element-size mnemonics following the 16b one in Int8..Int64 order.
    */
   TR::InstOpCode::Mnemonic baseOp = TR::InstOpCode::vsshr16b;
   if (leftShift)
      {
      baseOp = TR::InstOpCode::vshl16b;
      }
   else if (logicalRightShift)
      {
      baseOp = TR::InstOpCode::vushr16b;
      }
   TR::InstOpCode::Mnemonic shiftOp = static_cast<TR::InstOpCode::Mnemonic>(baseOp + (laneType - TR::Int8));
   generateVectorShiftImmediateInstruction(cg, shiftOp, node, targetReg, srcReg, shiftAmount);

   if (masked)
      {
      bool flipMask = false;
      TR::Register *maskReg = evaluateMaskNode(maskChild, flipMask, cg);
      /*
       * Put the unshifted source back into the lanes that the mask does not select.
       * BIT inserts each bit from the first source if the corresponding bit of the second source is 1.
       * BIF inserts each bit from the first source if the corresponding bit of the second source is 0.
       */
      TR::InstOpCode::Mnemonic mergeOp = flipMask ? TR::InstOpCode::vbit16b : TR::InstOpCode::vbif16b;
      generateTrg1Src2Instruction(cg, mergeOp, node, targetReg, srcReg, maskReg);

      cg->decReferenceCount(maskChild);
      }

   node->setRegister(targetReg);
   cg->decReferenceCount(valueChild);
   /* The splat and its constant child were never evaluated, so release the whole subtree. */
   cg->recursivelyDecReferenceCount(amountChild);

   return targetReg;
   }

/**
 * @brief Evaluator for vshl (vector shift left)
 *
 * Only 128-bit vectors are accepted. If vectorShiftImmediateHelper handles the node
 * (returns a non-NULL register), that register is the result. Otherwise the vsshl
 * variant matching the element type is generated through inlineVectorBinaryOp.
 *
 * @param[in] node: vshl node
 * @param[in] cg: CodeGenerator
 * @return the result register. NULL is returned after TR_ASSERT for an unrecognized element type.
 */
TR::Register*
OMR::ARM64::TreeEvaluator::vshlEvaluator(TR::Node *node, TR::CodeGenerator *cg)
   {
   TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
                             "Only 128-bit vectors are supported %s", node->getDataType().toString());

   /* Prefer the shift-by-immediate form when the shift amount is a splatted constant. */
   TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
   if (resultReg != NULL)
      {
      return resultReg;
      }

   TR::InstOpCode::Mnemonic shiftOp;
   switch(node->getDataType().getVectorElementType())
      {
      case TR::Int8:
         shiftOp = TR::InstOpCode::vsshl16b;
         break;
      case TR::Int16:
         shiftOp = TR::InstOpCode::vsshl8h;
         break;
      case TR::Int32:
         shiftOp = TR::InstOpCode::vsshl4s;
         break;
      case TR::Int64:
         shiftOp = TR::InstOpCode::vsshl2d;
         break;
      default:
         TR_ASSERT(false, "unrecognized vector type %s", node->getDataType().toString());
         return NULL;
      }
   return inlineVectorBinaryOp(node, cg, shiftOp);
   }

/**
 * @brief Evaluator for vmshl (masked vector shift left)
 *
 * Only 128-bit vectors are accepted. If vectorShiftImmediateHelper handles the node
 * (returns a non-NULL register), that register is the result. Otherwise the vsshl
 * variant matching the element type is generated through inlineVectorMaskedBinaryOp.
 *
 * @param[in] node: vmshl node
 * @param[in] cg: CodeGenerator
 * @return the result register. NULL is returned after TR_ASSERT for an unrecognized element type.
 */
TR::Register*
OMR::ARM64::TreeEvaluator::vmshlEvaluator(TR::Node *node, TR::CodeGenerator *cg)
   {
   TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
                             "Only 128-bit vectors are supported %s", node->getDataType().toString());

   /* Prefer the shift-by-immediate form when the shift amount is a splatted constant. */
   TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
   if (resultReg != NULL)
      {
      return resultReg;
      }

   TR::InstOpCode::Mnemonic shiftOp;
   switch(node->getDataType().getVectorElementType())
      {
      case TR::Int8:
         shiftOp = TR::InstOpCode::vsshl16b;
         break;
      case TR::Int16:
         shiftOp = TR::InstOpCode::vsshl8h;
         break;
      case TR::Int32:
         shiftOp = TR::InstOpCode::vsshl4s;
         break;
      case TR::Int64:
         shiftOp = TR::InstOpCode::vsshl2d;
         break;
      default:
         TR_ASSERT(false, "unrecognized vector type %s", node->getDataType().toString());
         return NULL;
      }
   return inlineVectorMaskedBinaryOp(node, cg, shiftOp);
   }

/**
 * @brief Generates a vector right shift whose per-element shift amounts are in a register
 *
 * The shift amounts are negated into the result register and the first operand is then
 * shifted by those negated amounts with the shift-left-by-register instruction
 * (the unsigned variant for vushr/vmushr, the signed variant for vshr/vmshr).
 * This helper matches the binaryEvaluatorHelper signature.
 *
 * @param[in] node: vector right shift node (vshr, vushr, vmshr or vmushr)
 * @param[in] resultReg: the result register; also holds the negated shift amounts temporarily
 * @param[in] lhsReg: register holding the value to be shifted
 * @param[in] rhsReg: register holding the shift amounts
 * @param[in] cg: CodeGenerator
 * @return the result register
 */
static TR::Register *
vectorRightShiftHelper(TR::Node *node, TR::Register *resultReg, TR::Register *lhsReg, TR::Register *rhsReg, TR::CodeGenerator *cg)
   {
   TR::VectorOperation opKind = node->getOpCode().getVectorOperation();
   TR::DataType laneType = node->getDataType().getVectorElementType();

   const bool arithmeticShift = (opKind == TR::vshr) || (opKind == TR::vmshr);
   const bool logicalShift = (opKind == TR::vushr) || (opKind == TR::vmushr);
   TR_ASSERT_FATAL_WITH_NODE(node, arithmeticShift || logicalShift,
                             "opcode must be vector right shift");
   TR_ASSERT_FATAL_WITH_NODE(node, (laneType >= TR::Int8) && (laneType <= TR::Int64), "elementType must be integer");

   /* Offset from the 16b mnemonic to the mnemonic for this element type. */
   const int32_t laneOffset = laneType - TR::Int8;
   TR::InstOpCode::Mnemonic shiftBase = logicalShift ? TR::InstOpCode::vushl16b : TR::InstOpCode::vsshl16b;
   TR::InstOpCode::Mnemonic negateOp = static_cast<TR::InstOpCode::Mnemonic>(TR::InstOpCode::vneg16b + laneOffset);
   TR::InstOpCode::Mnemonic shiftOp = static_cast<TR::InstOpCode::Mnemonic>(shiftBase + laneOffset);

   /* resultReg = -rhs, then resultReg = lhs shifted by resultReg. */
   generateTrg1Src1Instruction(cg, negateOp, node, resultReg, rhsReg);
   generateTrg1Src2Instruction(cg, shiftOp, node, resultReg, lhsReg, resultReg);

   return resultReg;
   }

/**
 * @brief Evaluator for vshr (vector arithmetic shift right)
 *
 * Only 128-bit vectors are accepted. If vectorShiftImmediateHelper handles the node
 * (returns a non-NULL register), that register is the result. Otherwise
 * inlineVectorBinaryOp is called with vectorRightShiftHelper generating the instructions.
 *
 * @param[in] node: vshr node
 * @param[in] cg: CodeGenerator
 * @return the result register
 */
TR::Register*
OMR::ARM64::TreeEvaluator::vshrEvaluator(TR::Node *node, TR::CodeGenerator *cg)
   {
   TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
                             "Only 128-bit vectors are supported %s", node->getDataType().toString());

   /* Prefer the shift-by-immediate form when the shift amount is a splatted constant. */
   TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
   if (resultReg != NULL)
      {
      return resultReg;
      }

   /* The helper callback is passed, so TR::InstOpCode::bad is given as the opcode argument. */
   return inlineVectorBinaryOp(node, cg, TR::InstOpCode::bad, vectorRightShiftHelper);
   }

/**
 * @brief Evaluator for vmshr (masked vector arithmetic shift right)
 *
 * Only 128-bit vectors are accepted. If vectorShiftImmediateHelper handles the node
 * (returns a non-NULL register), that register is the result. Otherwise
 * inlineVectorMaskedBinaryOp is called with vectorRightShiftHelper generating the instructions.
 *
 * @param[in] node: vmshr node
 * @param[in] cg: CodeGenerator
 * @return the result register
 */
TR::Register*
OMR::ARM64::TreeEvaluator::vmshrEvaluator(TR::Node *node, TR::CodeGenerator *cg)
   {
   TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
                             "Only 128-bit vectors are supported %s", node->getDataType().toString());

   /* Prefer the shift-by-immediate form when the shift amount is a splatted constant. */
   TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
   if (resultReg != NULL)
      {
      return resultReg;
      }

   /* The helper callback is passed, so TR::InstOpCode::bad is given as the opcode argument. */
   return inlineVectorMaskedBinaryOp(node, cg, TR::InstOpCode::bad, vectorRightShiftHelper);
   }

/**
 * @brief Evaluator for vushr (vector logical shift right)
 *
 * Only 128-bit vectors are accepted. If vectorShiftImmediateHelper handles the node
 * (returns a non-NULL register), that register is the result. Otherwise
 * inlineVectorBinaryOp is called with vectorRightShiftHelper generating the instructions.
 *
 * @param[in] node: vushr node
 * @param[in] cg: CodeGenerator
 * @return the result register
 */
TR::Register*
OMR::ARM64::TreeEvaluator::vushrEvaluator(TR::Node *node, TR::CodeGenerator *cg)
   {
   TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
                             "Only 128-bit vectors are supported %s", node->getDataType().toString());

   /* Prefer the shift-by-immediate form when the shift amount is a splatted constant. */
   TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
   if (resultReg != NULL)
      {
      return resultReg;
      }

   /* The helper callback is passed, so TR::InstOpCode::bad is given as the opcode argument. */
   return inlineVectorBinaryOp(node, cg, TR::InstOpCode::bad, vectorRightShiftHelper);
   }

/**
 * @brief Evaluator for vmushr (masked vector logical shift right)
 *
 * Only 128-bit vectors are accepted. If vectorShiftImmediateHelper handles the node
 * (returns a non-NULL register), that register is the result. Otherwise
 * inlineVectorMaskedBinaryOp is called with vectorRightShiftHelper generating the instructions.
 *
 * @param[in] node: vmushr node
 * @param[in] cg: CodeGenerator
 * @return the result register
 */
TR::Register*
OMR::ARM64::TreeEvaluator::vmushrEvaluator(TR::Node *node, TR::CodeGenerator *cg)
   {
   TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
                             "Only 128-bit vectors are supported %s", node->getDataType().toString());

   /* Prefer the shift-by-immediate form when the shift amount is a splatted constant. */
   TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
   if (resultReg != NULL)
      {
      return resultReg;
      }

   /* The helper callback is passed, so TR::InstOpCode::bad is given as the opcode argument. */
   return inlineVectorMaskedBinaryOp(node, cg, TR::InstOpCode::bad, vectorRightShiftHelper);
   }

/**
 * @brief Generates the instruction sequence for a vector rotate left
 *
 * The rotate is composed as (lhs << rhs) | (lhs >>> (sizeInBits - rhs)), where the
 * right shift is obtained by shifting with the amount (rhs - sizeInBits).
 * A temporary vector register is allocated for the second shift amount and released
 * before returning. This helper matches the binaryEvaluatorHelper signature.
 *
 * @param[in] node: vector rotate node
 * @param[in] resultReg: the result register
 * @param[in] lhsReg: register holding the value to be rotated
 * @param[in] rhsReg: register holding the rotate amounts
 * @param[in] cg: CodeGenerator
 * @return the result register
 */
static TR::Register *
vectorRotateHelper(TR::Node *node, TR::Register *resultReg, TR::Register *lhsReg, TR::Register *rhsReg, TR::CodeGenerator *cg)
   {
   TR::DataType laneType = node->getDataType().getVectorElementType();
   TR_ASSERT_FATAL_WITH_NODE(node, (laneType >= TR::Int8) && (laneType <= TR::Int64), "elementType must be integer");

   TR::Register *scratchReg = cg->allocateRegister(TR_VRF);

   /* Offset from the 16b mnemonic to the mnemonic for this element type. */
   const int32_t laneOffset = laneType - TR::Int8;
   TR::InstOpCode::Mnemonic negOp = static_cast<TR::InstOpCode::Mnemonic>(TR::InstOpCode::vneg16b + laneOffset);
   TR::InstOpCode::Mnemonic shiftOp = static_cast<TR::InstOpCode::Mnemonic>(TR::InstOpCode::vushl16b + laneOffset);
   TR::InstOpCode::Mnemonic subOp = static_cast<TR::InstOpCode::Mnemonic>(TR::InstOpCode::vsub16b + laneOffset);

   /* Step 1: fill every element of scratchReg with the element size in bits. */
   const bool is64BitLane = (laneType == TR::Int64);
   if (!is64BitLane)
      {
      const int32_t sizeInBits = TR::DataType::getSize(laneType) * 8;
      TR::InstOpCode::Mnemonic movOp = TR::InstOpCode::vmovi4s;
      if (laneType == TR::Int8)
         {
         movOp = TR::InstOpCode::vmovi16b;
         }
      else if (laneType == TR::Int16)
         {
         movOp = TR::InstOpCode::vmovi8h;
         }
      generateTrg1ImmInstruction(cg, movOp, node, scratchReg, sizeInBits);
      }
   else
      {
      /*
       * AArch64 does not have instructions for loading arbitrary immediate 8bits value into a vector of 64-bit integer elements.
       * Loading the value to a vector of 32-bit integer elements and using UXTL to extend elements to 64-bit.
       */
      generateTrg1ImmInstruction(cg, TR::InstOpCode::vmovi4s, node, scratchReg, 64);
      generateVectorUXTLInstruction(cg, TR::Int32, node, scratchReg, scratchReg, false);
      }

   /* Step 2: (lhs << rhs) | (lhs >>> (sizeInBits - rhs)) */
   generateTrg1Src2Instruction(cg, subOp, node, scratchReg, rhsReg, scratchReg);      /* scratch = rhs - sizeInBits */
   generateTrg1Src2Instruction(cg, shiftOp, node, resultReg, lhsReg, rhsReg);         /* result = lhs shifted by rhs */
   generateTrg1Src2Instruction(cg, shiftOp, node, scratchReg, lhsReg, scratchReg);    /* scratch = lhs shifted by (rhs - sizeInBits) */
   generateTrg1Src2Instruction(cg, TR::InstOpCode::vorr16b, node, resultReg, resultReg, scratchReg);

   cg->stopUsingRegister(scratchReg);

   return resultReg;
   }

/**
 * @brief Evaluator for vrol (vector rotate left)
 *
 * Only 128-bit vectors are accepted. The instruction sequence is generated by
 * vectorRotateHelper, invoked through inlineVectorBinaryOp.
 *
 * @param[in] node: vrol node
 * @param[in] cg: CodeGenerator
 * @return the result register
 */
TR::Register*
OMR::ARM64::TreeEvaluator::vrolEvaluator(TR::Node *node, TR::CodeGenerator *cg)
   {
   TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
                             "Only 128-bit vectors are supported %s", node->getDataType().toString());

   /* The helper callback is passed, so TR::InstOpCode::bad is given as the opcode argument. */
   return inlineVectorBinaryOp(node, cg, TR::InstOpCode::bad, vectorRotateHelper);
   }

/**
 * @brief Evaluator for vmrol (masked vector rotate left)
 *
 * Only 128-bit vectors are accepted. The instruction sequence is generated by
 * vectorRotateHelper, invoked through inlineVectorMaskedBinaryOp.
 *
 * @param[in] node: vmrol node
 * @param[in] cg: CodeGenerator
 * @return the result register
 */
TR::Register*
OMR::ARM64::TreeEvaluator::vmrolEvaluator(TR::Node *node, TR::CodeGenerator *cg)
   {
   TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
                             "Only 128-bit vectors are supported %s", node->getDataType().toString());

   /* The helper callback is passed, so TR::InstOpCode::bad is given as the opcode argument. */
   return inlineVectorMaskedBinaryOp(node, cg, TR::InstOpCode::bad, vectorRotateHelper);
   }

TR::Register*
Expand Down
22 changes: 22 additions & 0 deletions compiler/aarch64/codegen/OMRTreeEvaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,28 @@ class OMR_EXTENSIBLE TreeEvaluator: public OMR::TreeEvaluator
*/
static TR::Register *vmaxInt64Helper(TR::Node *node, TR::Register *resReg, TR::Register *lhsReg, TR::Register *rhsReg, TR::CodeGenerator *cg);

/**
* @brief Signature of the optional helper callback accepted by inlineVectorBinaryOp and inlineVectorMaskedBinaryOp
*
* The arguments are, in order: the node, the result register, the register of the first (lhs) operand,
* the register of the second (rhs) operand, and the CodeGenerator. The callback returns a register.
*/
typedef TR::Register *(*binaryEvaluatorHelper)(TR::Node *node, TR::Register *resReg, TR::Register *lhsRes, TR::Register *rhsReg, TR::CodeGenerator *cg);
/**
* @brief Helper function for generating instruction sequence for binary operations
*
* @param[in] node: node
* @param[in] cg: CodeGenerator
* @param[in] op: binary opcode
* @param[in] evaluatorHelper: optional pointer to helper function which generates instruction stream for operation (defaults to NULL)
* @return vector register containing the result
*/
static TR::Register *inlineVectorBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper = NULL);
/**
* @brief Helper function for generating instruction sequence for masked binary operations
*
* @param[in] node: node
* @param[in] cg: CodeGenerator
* @param[in] op: binary opcode
* @param[in] evaluatorHelper: optional pointer to helper function which generates instruction stream for operation (defaults to NULL)
* @return vector register containing the result
*/
static TR::Register *inlineVectorMaskedBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper = NULL);

static TR::Register *f2iuEvaluator(TR::Node *node, TR::CodeGenerator *cg);
static TR::Register *f2luEvaluator(TR::Node *node, TR::CodeGenerator *cg);
static TR::Register *f2buEvaluator(TR::Node *node, TR::CodeGenerator *cg);
Expand Down

0 comments on commit 044e8ce

Please sign in to comment.