From a3355e91d4de3517fd31bf51e467ef5c554d57fc Mon Sep 17 00:00:00 2001
From: Akira Saitoh
Date: Wed, 10 May 2023 11:40:00 +0900
Subject: [PATCH] AArch64: Implement vector shift and rotate evaluators

This commit implements evaluators for vector shift and rotate opcodes.
Evaluators for the masked versions of those opcodes are also added.

Signed-off-by: Akira Saitoh
---
 compiler/aarch64/codegen/BinaryEvaluator.cpp  |  14 +-
 compiler/aarch64/codegen/OMRCodeGenerator.cpp |   8 +
 compiler/aarch64/codegen/OMRTreeEvaluator.cpp | 273 ++++++++++++++++--
 compiler/aarch64/codegen/OMRTreeEvaluator.hpp |  22 ++
 4 files changed, 285 insertions(+), 32 deletions(-)
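As a reference for what the new opcodes compute, the per-lane behaviour can be modelled in
scalar code. The sketch below is illustrative only and is not part of the diff; the function
names are invented for the example, lanes are modelled as 64-bit values, and amounts are
assumed in range (0 < amount < w for the rotate, 0 <= amount < 64 otherwise).

    #include <cstdint>

    // vshl: elementwise left shift (modelled on a 64-bit lane)
    uint64_t laneShl(uint64_t lane, int64_t amount) { return lane << amount; }

    // vshr: elementwise arithmetic right shift (sign fill)
    int64_t laneShr(int64_t lane, int64_t amount) { return lane >> amount; }

    // vushr: elementwise logical right shift (zero fill)
    uint64_t laneUshr(uint64_t lane, int64_t amount) { return lane >> amount; }

    // vrol: elementwise rotate left by amount, for a lane of width w bits
    // (w is 8, 16, 32 or 64; the final mask keeps only the low w bits)
    uint64_t laneRol(uint64_t lane, int64_t amount, int w)
       {
       return ((lane << amount) | (lane >> (w - amount))) & (~0ULL >> (64 - w));
       }

    // vmshl/vmshr/vmushr/vmrol: the masked forms keep the original lane
    // wherever the corresponding mask lane is false
    uint64_t maskedLane(uint64_t shifted, uint64_t original, bool maskLane)
       {
       return maskLane ? shifted : original;
       }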
diff --git a/compiler/aarch64/codegen/BinaryEvaluator.cpp b/compiler/aarch64/codegen/BinaryEvaluator.cpp
index a350766dee..b5eeb27f21 100644
--- a/compiler/aarch64/codegen/BinaryEvaluator.cpp
+++ b/compiler/aarch64/codegen/BinaryEvaluator.cpp
@@ -523,18 +523,8 @@ OMR::ARM64::TreeEvaluator::lsubEvaluator(TR::Node *node, TR::CodeGenerator *cg)
    return genericBinaryEvaluator(node, TR::InstOpCode::subx, TR::InstOpCode::subimmx, true, cg);
    }
 
-typedef TR::Register *(*binaryEvaluatorHelper)(TR::Node *node, TR::Register *resReg, TR::Register *lhsRes, TR::Register *rhsReg, TR::CodeGenerator *cg);
-/**
- * @brief Helper functions for generating instruction sequence for masked binary operations
- *
- * @param[in] node: node
- * @param[in] cg: CodeGenerator
- * @param[in] op: binary opcode
- * @param[in] evaluatorHelper: optional pointer to helper function which generates instruction stream for operation
- * @return vector register containing the result
- */
-static TR::Register *
-inlineVectorBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper = NULL)
+TR::Register *
+OMR::ARM64::TreeEvaluator::inlineVectorBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper)
    {
    TR::Node *firstChild = node->getFirstChild();
    TR::Node *secondChild = node->getSecondChild();
diff --git a/compiler/aarch64/codegen/OMRCodeGenerator.cpp b/compiler/aarch64/codegen/OMRCodeGenerator.cpp
index 66674e4a39..7fc191512d 100644
--- a/compiler/aarch64/codegen/OMRCodeGenerator.cpp
+++ b/compiler/aarch64/codegen/OMRCodeGenerator.cpp
@@ -692,6 +692,14 @@ bool OMR::ARM64::CodeGenerator::getSupportsOpCodeForAutoSIMD(TR::CPU *cpu, TR::I
       case TR::vmor:
       case TR::vmxor:
       case TR::vmnot:
+      case TR::vshl:
+      case TR::vmshl:
+      case TR::vshr:
+      case TR::vmshr:
+      case TR::vushr:
+      case TR::vmushr:
+      case TR::vrol:
+      case TR::vmrol:
          // Float/ Double are not supported
          return (et == TR::Int8 || et == TR::Int16 || et == TR::Int32 || et == TR::Int64);
       case TR::vload:
diff --git a/compiler/aarch64/codegen/OMRTreeEvaluator.cpp b/compiler/aarch64/codegen/OMRTreeEvaluator.cpp
index 1c95f5e55b..1675993439 100644
--- a/compiler/aarch64/codegen/OMRTreeEvaluator.cpp
+++ b/compiler/aarch64/codegen/OMRTreeEvaluator.cpp
@@ -3106,18 +3106,8 @@ OMR::ARM64::TreeEvaluator::vRegStoreEvaluator(TR::Node *node, TR::CodeGenerator
    return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
    }
 
-typedef TR::Register *(*binaryEvaluatorHelper)(TR::Node *node, TR::Register *resReg, TR::Register *lhsRes, TR::Register *rhsReg, TR::CodeGenerator *cg);
-/**
- * @brief Helper functions for generating instruction sequence for masked binary operations
- *
- * @param[in] node: node
- * @param[in] cg: CodeGenerator
- * @param[in] op: binary opcode
- * @param[in] evaluatorHelper: optional pointer to helper function which generates instruction stream for operation
- * @return vector register containing the result
- */
-static TR::Register *
-inlineVectorMaskedBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper = NULL)
+TR::Register *
+OMR::ARM64::TreeEvaluator::inlineVectorMaskedBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper)
    {
    TR::Node *firstChild = node->getFirstChild();
    TR::Node *secondChild = node->getSecondChild();
@@ -4149,52 +4139,295 @@ OMR::ARM64::TreeEvaluator::vexpandEvaluator(TR::Node *node, TR::CodeGenerator *c
    return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
    }
 
+/**
+ * @brief Helper function for vector shift with immediate amount
+ *
+ * @param[in] node : node
+ * @param[in] cg : CodeGenerator
+ * @return the result register. NULL is returned if the node is not a vector shift by an encodable immediate.
+ */
+static TR::Register *
+vectorShiftImmediateHelper(TR::Node *node, TR::CodeGenerator *cg)
+   {
+   TR::VectorOperation vectorOp = node->getOpCode().getVectorOperation();
+   TR::DataType elementType = node->getDataType().getVectorElementType();
+   const bool isVectorShift = (vectorOp == TR::vshl) || (vectorOp == TR::vshr) || (vectorOp == TR::vushr);
+   const bool isVectorMaskedShift = (vectorOp == TR::vmshl) || (vectorOp == TR::vmshr) || (vectorOp == TR::vmushr);
+   TR_ASSERT_FATAL_WITH_NODE(node, isVectorShift || isVectorMaskedShift, "opcode must be vector shift");
+   TR_ASSERT_FATAL_WITH_NODE(node, (elementType >= TR::Int8) && (elementType <= TR::Int64), "elementType must be integer");
+
+   TR::Node *firstChild = node->getFirstChild();
+   TR::Node *secondChild = node->getSecondChild();
+   TR::Node *thirdChild = isVectorMaskedShift ? node->getThirdChild() : NULL;
+
+   if (!((secondChild->getOpCode().getVectorOperation() == TR::vsplats) &&
+         (secondChild->getRegister() == NULL) &&
+         secondChild->getFirstChild()->getOpCode().isLoadConst()))
+      {
+      return NULL;
+      }
+
+   TR::Node *constNode = secondChild->getFirstChild();
+
+   const bool isLeftShift = (vectorOp == TR::vshl) || (vectorOp == TR::vmshl);
+   const bool isLogicalRightShift = (vectorOp == TR::vushr) || (vectorOp == TR::vmushr);
+   const int64_t value = constNode->getConstValue();
+   const int32_t elementSizeInBits = TR::DataType::getSize(elementType) * 8;
+   const int64_t start = isLeftShift ? 0 : 1;
+   const int64_t end = isLeftShift ? (elementSizeInBits - 1) : elementSizeInBits;
+   if ((value >= start) && (value <= end))
+      {
+      TR::Register *targetReg = cg->allocateRegister(TR_VRF);
+      TR::Register *lhsReg = cg->evaluate(firstChild);
+      TR::InstOpCode::Mnemonic op = static_cast<TR::InstOpCode::Mnemonic>(
+         (isLeftShift ? TR::InstOpCode::vshl16b :
+          isLogicalRightShift ? TR::InstOpCode::vushr16b : TR::InstOpCode::vsshr16b) + (elementType - TR::Int8));
+      generateVectorShiftImmediateInstruction(cg, op, node, targetReg, lhsReg, value);
+      if (isVectorMaskedShift)
+         {
+         bool flipMask = false;
+         TR::Register *maskReg = evaluateMaskNode(thirdChild, flipMask, cg);
+         /*
+          * BIT inserts each bit from the first source if the corresponding bit of the second source is 1.
+          * BIF inserts each bit from the first source if the corresponding bit of the second source is 0.
+          */
+         generateTrg1Src2Instruction(cg, flipMask ? TR::InstOpCode::vbit16b : TR::InstOpCode::vbif16b, node, targetReg, lhsReg, maskReg);
+
+         cg->decReferenceCount(thirdChild);
+         }
+
+      node->setRegister(targetReg);
+      cg->decReferenceCount(firstChild);
+      cg->recursivelyDecReferenceCount(secondChild);
+
+      return targetReg;
+      }
+
+   return NULL;
+   }
+
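Two properties the helper above relies on are worth making explicit. First, the immediate
forms of the AArch64 shifts encode different ranges: SHL accepts 0 to w-1, while SSHR/USHR
accept 1 to w, which is exactly the start/end check. Second, the TR::InstOpCode mnemonics
for the 16B/8H/4S/2D lane variants are assumed to be enumerated consecutively, so the
variant is selected by adding (elementType - TR::Int8) to the 16B base. A minimal sketch of
both, using illustrative enum values rather than the real TR::InstOpCode layout:

    #include <cstdint>

    enum Mnemonic { vshl16b, vshl8h, vshl4s, vshl2d };   // assumed consecutive
    enum ElementType { Int8, Int16, Int32, Int64 };      // assumed consecutive

    // Int8 -> vshl16b, Int16 -> vshl8h, Int32 -> vshl4s, Int64 -> vshl2d
    Mnemonic selectLaneVariant(ElementType et)
       {
       return static_cast<Mnemonic>(vshl16b + (et - Int8));
       }

    // Mirrors the (value >= start) && (value <= end) test in the helper
    bool isEncodableShiftAmount(bool isLeftShift, int64_t value, int32_t elementSizeInBits)
       {
       const int64_t start = isLeftShift ? 0 : 1;
       const int64_t end = isLeftShift ? (elementSizeInBits - 1) : elementSizeInBits;
       return (value >= start) && (value <= end);
       }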
 TR::Register*
 OMR::ARM64::TreeEvaluator::vshlEvaluator(TR::Node *node, TR::CodeGenerator *cg)
    {
-   return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
+   TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
+                             "Only 128-bit vectors are supported %s", node->getDataType().toString());
+
+   TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
+   if (resultReg != NULL)
+      {
+      return resultReg;
+      }
+
+   TR::InstOpCode::Mnemonic shiftOp;
+   switch (node->getDataType().getVectorElementType())
+      {
+      case TR::Int8:
+         shiftOp = TR::InstOpCode::vsshl16b;
+         break;
+      case TR::Int16:
+         shiftOp = TR::InstOpCode::vsshl8h;
+         break;
+      case TR::Int32:
+         shiftOp = TR::InstOpCode::vsshl4s;
+         break;
+      case TR::Int64:
+         shiftOp = TR::InstOpCode::vsshl2d;
+         break;
+      default:
+         TR_ASSERT(false, "unrecognized vector type %s", node->getDataType().toString());
+         return NULL;
+      }
+   return inlineVectorBinaryOp(node, cg, shiftOp);
    }
 
 TR::Register*
 OMR::ARM64::TreeEvaluator::vmshlEvaluator(TR::Node *node, TR::CodeGenerator *cg)
    {
-   return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
+   TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
+                             "Only 128-bit vectors are supported %s", node->getDataType().toString());
+
+   TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
+   if (resultReg != NULL)
+      {
+      return resultReg;
+      }
+
+   TR::InstOpCode::Mnemonic shiftOp;
+   switch (node->getDataType().getVectorElementType())
+      {
+      case TR::Int8:
+         shiftOp = TR::InstOpCode::vsshl16b;
+         break;
+      case TR::Int16:
+         shiftOp = TR::InstOpCode::vsshl8h;
+         break;
+      case TR::Int32:
+         shiftOp = TR::InstOpCode::vsshl4s;
+         break;
+      case TR::Int64:
+         shiftOp = TR::InstOpCode::vsshl2d;
+         break;
+      default:
+         TR_ASSERT(false, "unrecognized vector type %s", node->getDataType().toString());
+         return NULL;
+      }
+   return inlineVectorMaskedBinaryOp(node, cg, shiftOp);
    }
 
+/**
+ * @brief Helper function for vector right shift operation
+ *
+ * @param[in] node: node
+ * @param[in] resultReg: the result register
+ * @param[in] lhsReg: the first argument register
+ * @param[in] rhsReg: the second argument register
+ * @param[in] cg: CodeGenerator
+ * @return the result register
+ */
+static TR::Register *
+vectorRightShiftHelper(TR::Node *node, TR::Register *resultReg, TR::Register *lhsReg, TR::Register *rhsReg, TR::CodeGenerator *cg)
+   {
+   TR::VectorOperation vectorOp = node->getOpCode().getVectorOperation();
+   TR::DataType elementType = node->getDataType().getVectorElementType();
+   TR_ASSERT_FATAL_WITH_NODE(node, (vectorOp == TR::vshr) || (vectorOp == TR::vushr) || (vectorOp == TR::vmshr) || (vectorOp == TR::vmushr),
+                             "opcode must be vector right shift");
+   TR_ASSERT_FATAL_WITH_NODE(node, (elementType >= TR::Int8) && (elementType <= TR::Int64), "elementType must be integer");
+   const bool isLogicalShift = (vectorOp == TR::vushr) || (vectorOp == TR::vmushr);
+   TR::InstOpCode::Mnemonic negOp = static_cast<TR::InstOpCode::Mnemonic>(TR::InstOpCode::vneg16b + (elementType - TR::Int8));
+   TR::InstOpCode::Mnemonic shiftOp = static_cast<TR::InstOpCode::Mnemonic>(
+      (isLogicalShift ? TR::InstOpCode::vushl16b : TR::InstOpCode::vsshl16b) + (elementType - TR::Int8));
+   generateTrg1Src1Instruction(cg, negOp, node, resultReg, rhsReg);
+   generateTrg1Src2Instruction(cg, shiftOp, node, resultReg, lhsReg, resultReg);
+
+   return resultReg;
+   }
+
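The helper above exists because AArch64 SIMD has no right-shift-by-register instruction:
SSHL and USHL shift each lane left when the per-lane amount is positive and right when it
is negative, so a right shift is emitted as a NEG of the amount followed by a left shift.
A scalar model of that convention (illustration only; amounts are assumed strictly between
-64 and 64, and the clamping the real instructions apply to out-of-range amounts is omitted):

    #include <cstdint>

    // USHL-like: negative amounts shift right with zero fill
    uint64_t ushl(uint64_t x, int64_t amount)
       {
       return (amount >= 0) ? (x << amount) : (x >> -amount);
       }

    // SSHL-like: negative amounts shift right with sign fill
    int64_t sshl(int64_t x, int64_t amount)
       {
       return (amount >= 0) ? static_cast<int64_t>(static_cast<uint64_t>(x) << amount)
                            : (x >> -amount);
       }

    // vshr as generated above: resultReg = NEG(rhs); resultReg = SSHL(lhs, resultReg)
    int64_t vshrLane(int64_t lhs, int64_t rhs) { return sshl(lhs, -rhs); }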
 TR::Register*
 OMR::ARM64::TreeEvaluator::vshrEvaluator(TR::Node *node, TR::CodeGenerator *cg)
    {
-   return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
+   TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
+                             "Only 128-bit vectors are supported %s", node->getDataType().toString());
+
+   TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
+   if (resultReg != NULL)
+      {
+      return resultReg;
+      }
+
+   return inlineVectorBinaryOp(node, cg, TR::InstOpCode::bad, vectorRightShiftHelper);
    }
 
 TR::Register*
 OMR::ARM64::TreeEvaluator::vmshrEvaluator(TR::Node *node, TR::CodeGenerator *cg)
    {
-   return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
+   TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
+                             "Only 128-bit vectors are supported %s", node->getDataType().toString());
+
+   TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
+   if (resultReg != NULL)
+      {
+      return resultReg;
+      }
+
+   return inlineVectorMaskedBinaryOp(node, cg, TR::InstOpCode::bad, vectorRightShiftHelper);
    }
 
 TR::Register*
 OMR::ARM64::TreeEvaluator::vushrEvaluator(TR::Node *node, TR::CodeGenerator *cg)
    {
-   return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
+   TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
+                             "Only 128-bit vectors are supported %s", node->getDataType().toString());
+
+   TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
+   if (resultReg != NULL)
+      {
+      return resultReg;
+      }
+
+   return inlineVectorBinaryOp(node, cg, TR::InstOpCode::bad, vectorRightShiftHelper);
    }
 
 TR::Register*
 OMR::ARM64::TreeEvaluator::vmushrEvaluator(TR::Node *node, TR::CodeGenerator *cg)
    {
-   return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
+   TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
+                             "Only 128-bit vectors are supported %s", node->getDataType().toString());
+
+   TR::Register *resultReg = vectorShiftImmediateHelper(node, cg);
+   if (resultReg != NULL)
+      {
+      return resultReg;
+      }
+
+   return inlineVectorMaskedBinaryOp(node, cg, TR::InstOpCode::bad, vectorRightShiftHelper);
    }
 
+/**
+ * @brief Helper function for vector rotate operation
+ *
+ * @param[in] node: node
+ * @param[in] resultReg: the result register
+ * @param[in] lhsReg: the first argument register
+ * @param[in] rhsReg: the second argument register
+ * @param[in] cg: CodeGenerator
+ * @return the result register
+ */
+static TR::Register *
+vectorRotateHelper(TR::Node *node, TR::Register *resultReg, TR::Register *lhsReg, TR::Register *rhsReg, TR::CodeGenerator *cg)
+   {
+   TR::DataType elementType = node->getDataType().getVectorElementType();
+   TR_ASSERT_FATAL_WITH_NODE(node, (elementType >= TR::Int8) && (elementType <= TR::Int64), "elementType must be integer");
+   TR::Register *tempReg = cg->allocateRegister(TR_VRF);
+   TR::InstOpCode::Mnemonic negOp = static_cast<TR::InstOpCode::Mnemonic>(TR::InstOpCode::vneg16b + (elementType - TR::Int8));
+   TR::InstOpCode::Mnemonic shiftOp = static_cast<TR::InstOpCode::Mnemonic>(TR::InstOpCode::vushl16b + (elementType - TR::Int8));
+   TR::InstOpCode::Mnemonic subOp = static_cast<TR::InstOpCode::Mnemonic>(TR::InstOpCode::vsub16b + (elementType - TR::Int8));
+
+   if (elementType == TR::Int64)
+      {
+      /*
+       * AArch64 has no instruction that loads an arbitrary 8-bit immediate value into a vector of 64-bit integer elements,
+       * so load the value into a vector of 32-bit integer elements and use UXTL to extend the elements to 64 bits.
+       */
+      generateTrg1ImmInstruction(cg, TR::InstOpCode::vmovi4s, node, tempReg, 64);
+      generateVectorUXTLInstruction(cg, TR::Int32, node, tempReg, tempReg, false);
+      }
+   else
+      {
+      const int32_t sizeInBits = TR::DataType::getSize(elementType) * 8;
+      TR::InstOpCode::Mnemonic movOp = (elementType == TR::Int8) ? TR::InstOpCode::vmovi16b :
+                                       ((elementType == TR::Int16) ? TR::InstOpCode::vmovi8h : TR::InstOpCode::vmovi4s);
+      generateTrg1ImmInstruction(cg, movOp, node, tempReg, sizeInBits);
+      }
+
+   /* (lhs << rhs) | (lhs >>> (sizeInBits - rhs)) */
+   generateTrg1Src2Instruction(cg, subOp, node, tempReg, rhsReg, tempReg);
+   generateTrg1Src2Instruction(cg, shiftOp, node, resultReg, lhsReg, rhsReg);
+   generateTrg1Src2Instruction(cg, shiftOp, node, tempReg, lhsReg, tempReg);
+   generateTrg1Src2Instruction(cg, TR::InstOpCode::vorr16b, node, resultReg, resultReg, tempReg);
+
+   cg->stopUsingRegister(tempReg);
+
+   return resultReg;
+   }
+
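The sequence above is the classic rotate-left decomposition: rotl(x, r) = (x << r) | (x >>> (w - r)).
Because the only variable-amount shift takes per-lane signed amounts, the right-shift half is
produced by USHL with the negative amount r - w; the MOVI (or MOVI plus UXTL for 64-bit lanes)
materializes w, and the SUB computes r - w. A scalar rendering for 32-bit lanes (illustration
only; assumes 0 < r < 32):

    #include <cstdint>

    uint32_t rotl32(uint32_t x, int32_t r)
       {
       // (x << r) | (x >>> (32 - r)); since r - 32 == -(32 - r), the right
       // shift can equally be expressed as a USHL by the negative amount r - 32
       return (x << r) | (x >> (32 - r));
       }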
 TR::Register*
 OMR::ARM64::TreeEvaluator::vrolEvaluator(TR::Node *node, TR::CodeGenerator *cg)
    {
-   return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
+   TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
+                             "Only 128-bit vectors are supported %s", node->getDataType().toString());
+
+   return inlineVectorBinaryOp(node, cg, TR::InstOpCode::bad, vectorRotateHelper);
    }
 
 TR::Register*
 OMR::ARM64::TreeEvaluator::vmrolEvaluator(TR::Node *node, TR::CodeGenerator *cg)
    {
-   return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
+   TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
+                             "Only 128-bit vectors are supported %s", node->getDataType().toString());
+
+   return inlineVectorMaskedBinaryOp(node, cg, TR::InstOpCode::bad, vectorRotateHelper);
    }
 
 TR::Register*
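A pattern worth noting in the evaluators above: single-instruction operations pass a concrete
mnemonic to inlineVectorBinaryOp/inlineVectorMaskedBinaryOp, while multi-instruction sequences
pass TR::InstOpCode::bad together with a helper callback (vectorRightShiftHelper,
vectorRotateHelper). The bodies of the shared helpers are not part of this diff, so the
following dispatch sketch is only an assumption based on their call sites, with the mask
handling of the masked variant elided:

    // Hypothetical sketch, not the actual OMR implementation
    TR::Register *
    inlineVectorBinaryOpSketch(TR::Node *node, TR::CodeGenerator *cg,
                               TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper)
       {
       TR::Register *lhsReg = cg->evaluate(node->getFirstChild());
       TR::Register *rhsReg = cg->evaluate(node->getSecondChild());
       TR::Register *resReg = cg->allocateRegister(TR_VRF);
       if (evaluatorHelper != NULL)
          (*evaluatorHelper)(node, resReg, lhsReg, rhsReg, cg);   // emit a sequence
       else
          generateTrg1Src2Instruction(cg, op, node, resReg, lhsReg, rhsReg);
       node->setRegister(resReg);
       cg->decReferenceCount(node->getFirstChild());
       cg->decReferenceCount(node->getSecondChild());
       return resReg;
       }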
diff --git a/compiler/aarch64/codegen/OMRTreeEvaluator.hpp b/compiler/aarch64/codegen/OMRTreeEvaluator.hpp
index 5c1745047d..6d4fcf6f6f 100644
--- a/compiler/aarch64/codegen/OMRTreeEvaluator.hpp
+++ b/compiler/aarch64/codegen/OMRTreeEvaluator.hpp
@@ -416,6 +416,28 @@ class OMR_EXTENSIBLE TreeEvaluator: public OMR::TreeEvaluator
     */
    static TR::Register *vmaxInt64Helper(TR::Node *node, TR::Register *resReg, TR::Register *lhsReg, TR::Register *rhsReg, TR::CodeGenerator *cg);
 
+   typedef TR::Register *(*binaryEvaluatorHelper)(TR::Node *node, TR::Register *resReg, TR::Register *lhsRes, TR::Register *rhsReg, TR::CodeGenerator *cg);
+   /**
+    * @brief Helper function for generating instruction sequence for binary operations
+    *
+    * @param[in] node: node
+    * @param[in] cg: CodeGenerator
+    * @param[in] op: binary opcode
+    * @param[in] evaluatorHelper: optional pointer to helper function which generates instruction stream for operation
+    * @return vector register containing the result
+    */
+   static TR::Register *inlineVectorBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper = NULL);
+   /**
+    * @brief Helper function for generating instruction sequence for masked binary operations
+    *
+    * @param[in] node: node
+    * @param[in] cg: CodeGenerator
+    * @param[in] op: binary opcode
+    * @param[in] evaluatorHelper: optional pointer to helper function which generates instruction stream for operation
+    * @return vector register containing the result
+    */
+   static TR::Register *inlineVectorMaskedBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode::Mnemonic op, binaryEvaluatorHelper evaluatorHelper = NULL);
+
    static TR::Register *f2iuEvaluator(TR::Node *node, TR::CodeGenerator *cg);
    static TR::Register *f2luEvaluator(TR::Node *node, TR::CodeGenerator *cg);
    static TR::Register *f2buEvaluator(TR::Node *node, TR::CodeGenerator *cg);