AArch64: Implement vector shift and rotate evaluators
This commit implements evaluators for the vector shift and rotate opcodes.
Evaluators for the masked versions of those opcodes are also added.

Signed-off-by: Akira Saitoh <saiaki@jp.ibm.com>
Akira Saitoh committed Jun 1, 2023
1 parent 0509e46 commit 87130f9
Showing 4 changed files with 287 additions and 32 deletions.
14 changes: 2 additions & 12 deletions compiler/aarch64/codegen/BinaryEvaluator.cpp
@@ -523,18 +523,8 @@ OMR::ARM64::TreeEvaluator::lsubEvaluator(TR::Node *node, TR::CodeGenerator *cg)
return genericBinaryEvaluator(node, TR::InstOpCode::subx, TR::InstOpCode::subimmx, true, cg);
}

typedef TR::Register *(*binaryEvaluatorHelper)(TR::Node *node, TR::Register *resReg, TR::Register *lhsRes, TR::Register *rhsReg, TR::CodeGenerator *cg);
/**
* @brief Helper functions for generating instruction sequence for masked binary operations
*
* @param[in] node: node
* @param[in] cg: CodeGenerator
* @param[in] op: binary opcode
* @param[in] evaluatorHelper: optional pointer to helper function which generates instruction stream for operation
* @return vector register containing the result
*/
static TR::Register *
inlineVectorBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper = NULL)
TR::Register *
OMR::ARM64::TreeEvaluator::inlineVectorBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper)
{
TR::Node *firstChild = node->getFirstChild();
TR::Node *secondChild = node->getSecondChild();
8 changes: 8 additions & 0 deletions compiler/aarch64/codegen/OMRCodeGenerator.cpp
@@ -692,6 +692,14 @@ bool OMR::ARM64::CodeGenerator::getSupportsOpCodeForAutoSIMD(TR::CPU *cpu, TR::I
case TR::vmor:
case TR::vmxor:
case TR::vmnot:
case TR::vshl:
case TR::vmshl:
case TR::vshr:
case TR::vmshr:
case TR::vushr:
case TR::vmushr:
case TR::vrol:
case TR::vmrol:
// Float/Double are not supported
return (et == TR::Int8 || et == TR::Int16 || et == TR::Int32 || et == TR::Int64);
case TR::vload:
275 changes: 255 additions & 20 deletions compiler/aarch64/codegen/OMRTreeEvaluator.cpp
@@ -3106,18 +3106,8 @@ OMR::ARM64::TreeEvaluator::vRegStoreEvaluator(TR::Node *node, TR::CodeGenerator
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
}

typedef TR::Register *(*binaryEvaluatorHelper)(TR::Node *node, TR::Register *resReg, TR::Register *lhsRes, TR::Register *rhsReg, TR::CodeGenerator *cg);
/**
* @brief Helper functions for generating instruction sequence for masked binary operations
*
* @param[in] node: node
* @param[in] cg: CodeGenerator
* @param[in] op: binary opcode
* @param[in] evaluatorHelper: optional pointer to helper function which generates instruction stream for operation
* @return vector register containing the result
*/
static TR::Register *
inlineVectorMaskedBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper = NULL)
TR::Register *
OMR::ARM64::TreeEvaluator::inlineVectorMaskedBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper)
{
TR::Node *firstChild = node->getFirstChild();
TR::Node *secondChild = node->getSecondChild();
@@ -4149,52 +4139,295 @@ OMR::ARM64::TreeEvaluator::vexpandEvaluator(TR::Node *node, TR::CodeGenerator *c
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
}

/**
* @brief Helper function for vector shift with immediate amount
*
* @param[in] node : node
* @param[in] cg : CodeGenerator
* @return the result register, or NULL if the shift amount is not a splatted compile-time constant within the valid immediate range.
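* For example, a tree of the form vshl(vload, vsplats(iconst 3)) matches and is encoded as a shift-by-immediate instruction.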
*/
static TR::Register *
vectorShiftImmediateHelper(TR::Node *node, TR::CodeGenerator *cg)
{
TR::VectorOperation vectorOp = node->getOpCode().getVectorOperation();
TR::DataType elementType = node->getDataType().getVectorElementType();
const bool isVectorShift = (vectorOp == TR::vshl) || (vectorOp == TR::vshr) || (vectorOp == TR::vushr);
const bool isVectorMaskedShift = (vectorOp == TR::vmshl) || (vectorOp == TR::vmshr) || (vectorOp == TR::vmushr);
TR_ASSERT_FATAL_WITH_NODE(node, isVectorShift || isVectorMaskedShift, "opcode must be vector shift");
TR_ASSERT_FATAL_WITH_NODE(node, (elementType >= TR::Int8) && (elementType <= TR::Int64), "elementType must be integer");

TR::Node *firstChild = node->getFirstChild();
TR::Node *secondChild = node->getSecondChild();
TR::Node *thirdChild = isVectorMaskedShift ? node->getThirdChild() : NULL;

if (!((secondChild->getOpCode().getVectorOperation() == TR::vsplats) &&
(secondChild->getRegister() == NULL) &&
secondChild->getFirstChild()->getOpCode().isLoadConst()))
{
return NULL;
}

TR::Node *constNode = secondChild->getFirstChild();

const bool isLeftShift = (vectorOp == TR::vshl) || (vectorOp == TR::vmshl);
const bool isLogicalRightShift = (vectorOp == TR::vushr) || (vectorOp == TR::vmushr);
const int64_t value = constNode->getConstValue();
const int32_t elementSizeInBits = TR::DataType::getSize(elementType) * 8;
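/*
* On AArch64, the SHL (immediate) encoding accepts shift amounts of 0 to elementSize - 1,
* while the SSHR/USHR (immediate) encodings accept 1 to elementSize, hence the different bounds.
*/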
const int64_t start = isLeftShift ? 0 : 1;
const int64_t end = isLeftShift ? (elementSizeInBits - 1) : elementSizeInBits;
if ((value >= start) && (value <= end))
{
TR::Register *targetReg = cg->allocateRegister(TR_VRF);
TR::Register *lhsReg = cg->evaluate(firstChild);
TR::InstOpCode::Mnemonic op = static_cast<TR::InstOpCode::Mnemonic>(
(isLeftShift ? TR::InstOpCode::vshl16b :
isLogicalRightShift ? TR::InstOpCode::vushr16b : TR::InstOpCode::vsshr16b) +
(elementType - TR::Int8));
generateVectorShiftImmediateInstruction(cg, op, node, targetReg, lhsReg, value);
if (isVectorMaskedShift)
{
bool flipMask = false;
TR::Register *maskReg = evaluateMaskNode(thirdChild, flipMask, cg);
/*
* BIT inserts each bit from the first source if the corresponding bit of the second source is 1.
* BIF inserts each bit from the first source if the corresponding bit of the second source is 0.
*/
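/* Lanes whose mask bits are not set get the original (unshifted) lhs value re-inserted. */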
generateTrg1Src2Instruction(cg, flipMask ? TR::InstOpCode::vbit16b : TR::InstOpCode::vbif16b, node, targetReg, lhsReg, maskReg);

cg->decReferenceCount(thirdChild);
}

node->setRegister(targetReg);
cg->decReferenceCount(firstChild);
cg->recursivelyDecReferenceCount(secondChild);

return targetReg;
}

return NULL;
}

TR::Register*
OMR::ARM64::TreeEvaluator::vshlEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
"Only 128-bit vectors are supported %s", node->getDataType().toString());

TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
if (resultReg != NULL)
{
return resultReg;
}

TR::InstOpCode::Mnemonic shiftOp;
switch(node->getDataType().getVectorElementType())
{
case TR::Int8:
shiftOp = TR::InstOpCode::vsshl16b;
break;
case TR::Int16:
shiftOp = TR::InstOpCode::vsshl8h;
break;
case TR::Int32:
shiftOp = TR::InstOpCode::vsshl4s;
break;
case TR::Int64:
shiftOp = TR::InstOpCode::vsshl2d;
break;
default:
TR_ASSERT(false, "unrecognized vector type %s", node->getDataType().toString());
return NULL;
}
return inlineVectorBinaryOp(node, cg, shiftOp);
}

TR::Register*
OMR::ARM64::TreeEvaluator::vmshlEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
"Only 128-bit vectors are supported %s", node->getDataType().toString());

TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
if (resultReg != NULL)
{
return resultReg;
}

TR::InstOpCode::Mnemonic shiftOp;
switch(node->getDataType().getVectorElementType())
{
case TR::Int8:
shiftOp = TR::InstOpCode::vsshl16b;
break;
case TR::Int16:
shiftOp = TR::InstOpCode::vsshl8h;
break;
case TR::Int32:
shiftOp = TR::InstOpCode::vsshl4s;
break;
case TR::Int64:
shiftOp = TR::InstOpCode::vsshl2d;
break;
default:
TR_ASSERT(false, "unrecognized vector type %s", node->getDataType().toString());
return NULL;
}
return inlineVectorMaskedBinaryOp(node, cg, shiftOp);
}

/**
* @brief Helper function for vector right shift operation
*
* @param[in] node: node
* @param[in] resultReg: the result register
* @param[in] lhsReg: the first argument register
* @param[in] rhsReg: the second argument register
* @param[in] cg: CodeGenerator
* @return the result register
*/
static TR::Register *
vectorRightShiftHelper(TR::Node *node, TR::Register *resultReg, TR::Register *lhsReg, TR::Register *rhsReg, TR::CodeGenerator *cg)
{
TR::VectorOperation vectorOp = node->getOpCode().getVectorOperation();
TR::DataType elementType = node->getDataType().getVectorElementType();
TR_ASSERT_FATAL_WITH_NODE(node, (vectorOp == TR::vshr) || (vectorOp == TR::vushr) || (vectorOp == TR::vmshr) || (vectorOp == TR::vmushr),
"opcode must be vector right shift");
TR_ASSERT_FATAL_WITH_NODE(node, (elementType >= TR::Int8) && (elementType <= TR::Int64), "elementType must be integer");
const bool isLogicalShift = (vectorOp == TR::vushr) || (vectorOp == TR::vmushr);
TR::InstOpCode::Mnemonic negOp = static_cast<TR::InstOpCode::Mnemonic>(TR::InstOpCode::vneg16b + (elementType - TR::Int8));
TR::InstOpCode::Mnemonic shiftOp = static_cast<TR::InstOpCode::Mnemonic>(
(isLogicalShift ? TR::InstOpCode::vushl16b : TR::InstOpCode::vsshl16b) +
(elementType - TR::Int8));
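/*
* AArch64 has no vector right-shift-by-register instruction; SSHL/USHL shift each lane left
* by a signed, per-lane amount, so the shift amounts are negated and a left shift is used.
*/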
generateTrg1Src1Instruction(cg, negOp, node, resultReg, rhsReg);
generateTrg1Src2Instruction(cg, shiftOp, node, resultReg, lhsReg, resultReg);

return resultReg;
}

TR::Register*
OMR::ARM64::TreeEvaluator::vshrEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
"Only 128-bit vectors are supported %s", node->getDataType().toString());

TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
if (resultReg != NULL)
{
return resultReg;
}

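/* The opcode operand is unused when an evaluatorHelper is supplied, so TR::InstOpCode::bad is passed. */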
return inlineVectorBinaryOp(node, cg, TR::InstOpCode::bad, vectorRightShiftHelper);
}

TR::Register*
OMR::ARM64::TreeEvaluator::vmshrEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
"Only 128-bit vectors are supported %s", node->getDataType().toString());

TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
if (resultReg != NULL)
{
return resultReg;
}

return inlineVectorMaskedBinaryOp(node, cg, TR::InstOpCode::bad, vectorRightShiftHelper);
}

TR::Register*
OMR::ARM64::TreeEvaluator::vushrEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
"Only 128-bit vectors are supported %s", node->getDataType().toString());

TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
if (resultReg != NULL)
{
return resultReg;
}

return inlineVectorBinaryOp(node, cg, TR::InstOpCode::bad, vectorRightShiftHelper);
}

TR::Register*
OMR::ARM64::TreeEvaluator::vmushrEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
"Only 128-bit vectors are supported %s", node->getDataType().toString());

TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
if (resultReg != NULL)
{
return resultReg;
}

return inlineVectorMaskedBinaryOp(node, cg, TR::InstOpCode::bad, vectorRightShiftHelper);
}

/**
* @brief Helper function for the vector rotate left operation
*
* @param[in] node: node
* @param[in] resultReg: the result register
* @param[in] lhsReg: the first argument register
* @param[in] rhsReg: the second argument register
* @param[in] cg: CodeGenerator
* @return the result register
*/
static TR::Register *
vectorRotateHelper(TR::Node *node, TR::Register *resultReg, TR::Register *lhsReg, TR::Register *rhsReg, TR::CodeGenerator *cg)
{
TR::DataType elementType = node->getDataType().getVectorElementType();
TR_ASSERT_FATAL_WITH_NODE(node, (elementType >= TR::Int8) && (elementType <= TR::Int64), "elementType must be integer");
TR::Register *tempReg = cg->allocateRegister(TR_VRF);
TR::InstOpCode::Mnemonic negOp = static_cast<TR::InstOpCode::Mnemonic>(TR::InstOpCode::vneg16b + (elementType - TR::Int8));
TR::InstOpCode::Mnemonic shiftOp = static_cast<TR::InstOpCode::Mnemonic>(TR::InstOpCode::vushl16b + (elementType - TR::Int8));
TR::InstOpCode::Mnemonic subOp = static_cast<TR::InstOpCode::Mnemonic>(TR::InstOpCode::vsub16b + (elementType - TR::Int8));

if (elementType == TR::Int64)
{
/*
* AArch64 has no instruction for loading an arbitrary 8-bit immediate value into a vector of 64-bit integer elements,
* so we load the value into a vector of 32-bit integer elements and use UXTL to zero-extend each element to 64 bits.
*/
generateTrg1ImmInstruction(cg, TR::InstOpCode::vmovi4s, node, tempReg, 64);
generateVectorUXTLInstruction(cg, TR::Int32, node, tempReg, tempReg, false);
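/* tempReg now holds the element size (64) in each of its two 64-bit lanes. */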
}
else
{
const int32_t sizeInBits = TR::DataType::getSize(elementType) * 8;
TR::InstOpCode::Mnemonic movOp = (elementType == TR::Int8) ? TR::InstOpCode::vmovi16b :
((elementType == TR::Int16) ? TR::InstOpCode::vmovi8h : TR::InstOpCode::vmovi4s);
generateTrg1ImmInstruction(cg, movOp, node, tempReg, sizeInBits);
}

/* (lhs << rhs) | (lhs >>> (sizeInBits - rhs)) */
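/* e.g., for Int8 elements, rol(0xA3, 4) = (0xA3 << 4) | (0xA3 >>> 4) = 0x30 | 0x0A = 0x3A */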
generateTrg1Src2Instruction(cg, subOp, node, tempReg, rhsReg, tempReg);
generateTrg1Src2Instruction(cg, shiftOp, node, resultReg, lhsReg, rhsReg);
generateTrg1Src2Instruction(cg, shiftOp, node, tempReg, lhsReg, tempReg);
generateTrg1Src2Instruction(cg, TR::InstOpCode::vorr16b, node, resultReg, resultReg, tempReg);

cg->stopUsingRegister(tempReg);

return resultReg;
}

TR::Register*
OMR::ARM64::TreeEvaluator::vrolEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
"Only 128-bit vectors are supported %s", node->getDataType().toString());

return inlineVectorBinaryOp(node, cg, TR::InstOpCode::bad, vectorRotateHelper);
}

TR::Register*
OMR::ARM64::TreeEvaluator::vmrolEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
"Only 128-bit vectors are supported %s", node->getDataType().toString());

return inlineVectorMaskedBinaryOp(node, cg, TR::InstOpCode::bad, vectorRotateHelper);
}

TR::Register*
@@ -4207,6 +4440,7 @@ TR::Register*
OMR::ARM64::TreeEvaluator::vnotzEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);

}

TR::Register*
@@ -4219,6 +4453,7 @@ TR::Register*
OMR::ARM64::TreeEvaluator::vnolzEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);

}

TR::Register*
