Skip to content

[RISCV] Move vmv.v.v peephole from SelectionDAG to RISCVVectorPeephole #100367

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 14 additions & 70 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3664,32 +3664,6 @@ static bool IsVMerge(SDNode *N) {
return RISCV::getRVVMCOpcode(N->getMachineOpcode()) == RISCV::VMERGE_VVM;
}

// Returns true if N is the machine node for an unmasked vmv.v.v pseudo,
// identified by mapping its pseudo opcode back to the MC opcode.
static bool IsVMv(SDNode *N) {
  unsigned MCOpcode = RISCV::getRVVMCOpcode(N->getMachineOpcode());
  return MCOpcode == RISCV::VMV_V_V;
}

// Map an LMUL to the vmset.m pseudo that writes an all-ones mask with the
// matching mask-register layout (MF8 -> B1, ..., M8 -> B64).
static unsigned GetVMSetForLMul(RISCVII::VLMUL LMUL) {
  switch (LMUL) {
  case RISCVII::LMUL_RESERVED:
    llvm_unreachable("Unexpected LMUL");
  case RISCVII::LMUL_F8:
    return RISCV::PseudoVMSET_M_B1;
  case RISCVII::LMUL_F4:
    return RISCV::PseudoVMSET_M_B2;
  case RISCVII::LMUL_F2:
    return RISCV::PseudoVMSET_M_B4;
  case RISCVII::LMUL_1:
    return RISCV::PseudoVMSET_M_B8;
  case RISCVII::LMUL_2:
    return RISCV::PseudoVMSET_M_B16;
  case RISCVII::LMUL_4:
    return RISCV::PseudoVMSET_M_B32;
  case RISCVII::LMUL_8:
    return RISCV::PseudoVMSET_M_B64;
  }
  llvm_unreachable("Unknown VLMUL enum");
}

// Try to fold away VMERGE_VVM instructions into their true operands:
//
// %true = PseudoVADD_VV ...
Expand All @@ -3704,35 +3678,22 @@ static unsigned GetVMSetForLMul(RISCVII::VLMUL LMUL) {
// If %true is masked, then we can use its mask instead of vmerge's if vmerge's
// mask is all ones.
//
// We can also fold a VMV_V_V into its true operand, since it is equivalent to a
// VMERGE_VVM with an all ones mask.
//
// The resulting VL is the minimum of the two VLs.
//
// The resulting policy is the effective policy the vmerge would have had,
// i.e. whether or not it's passthru operand was implicit-def.
bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
SDValue Passthru, False, True, VL, Mask, Glue;
// A vmv.v.v is equivalent to a vmerge with an all-ones mask.
if (IsVMv(N)) {
Passthru = N->getOperand(0);
False = N->getOperand(0);
True = N->getOperand(1);
VL = N->getOperand(2);
// A vmv.v.v won't have a Mask or Glue, instead we'll construct an all-ones
// mask later below.
} else {
assert(IsVMerge(N));
Passthru = N->getOperand(0);
False = N->getOperand(1);
True = N->getOperand(2);
Mask = N->getOperand(3);
VL = N->getOperand(4);
// We always have a glue node for the mask at v0.
Glue = N->getOperand(N->getNumOperands() - 1);
}
assert(!Mask || cast<RegisterSDNode>(Mask)->getReg() == RISCV::V0);
assert(!Glue || Glue.getValueType() == MVT::Glue);
assert(IsVMerge(N));
Passthru = N->getOperand(0);
False = N->getOperand(1);
True = N->getOperand(2);
Mask = N->getOperand(3);
VL = N->getOperand(4);
// We always have a glue node for the mask at v0.
Glue = N->getOperand(N->getNumOperands() - 1);
assert(cast<RegisterSDNode>(Mask)->getReg() == RISCV::V0);
assert(Glue.getValueType() == MVT::Glue);

// If the EEW of True is different from vmerge's SEW, then we can't fold.
if (True.getSimpleValueType() != N->getSimpleValueType(0))
Expand Down Expand Up @@ -3780,7 +3741,7 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {

// If True is masked then the vmerge must have either the same mask or an all
// 1s mask, since we're going to keep the mask from True.
if (IsMasked && Mask) {
if (IsMasked) {
// FIXME: Support mask agnostic True instruction which would have an
// undef passthru operand.
SDValue TrueMask =
Expand Down Expand Up @@ -3810,11 +3771,9 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
SmallVector<const SDNode *, 4> LoopWorklist;
SmallPtrSet<const SDNode *, 16> Visited;
LoopWorklist.push_back(False.getNode());
if (Mask)
LoopWorklist.push_back(Mask.getNode());
LoopWorklist.push_back(Mask.getNode());
LoopWorklist.push_back(VL.getNode());
if (Glue)
LoopWorklist.push_back(Glue.getNode());
LoopWorklist.push_back(Glue.getNode());
if (SDNode::hasPredecessorHelper(True.getNode(), Visited, LoopWorklist))
return false;
}
Expand Down Expand Up @@ -3875,21 +3834,6 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
Glue = True->getOperand(True->getNumOperands() - 1);
assert(Glue.getValueType() == MVT::Glue);
}
// If we end up using the vmerge mask the vmerge is actually a vmv.v.v, create
// an all-ones mask to use.
else if (IsVMv(N)) {
unsigned TSFlags = TII->get(N->getMachineOpcode()).TSFlags;
unsigned VMSetOpc = GetVMSetForLMul(RISCVII::getLMul(TSFlags));
ElementCount EC = N->getValueType(0).getVectorElementCount();
MVT MaskVT = MVT::getVectorVT(MVT::i1, EC);

SDValue AllOnesMask =
SDValue(CurDAG->getMachineNode(VMSetOpc, DL, MaskVT, VL, SEW), 0);
SDValue MaskCopy = CurDAG->getCopyToReg(CurDAG->getEntryNode(), DL,
RISCV::V0, AllOnesMask, SDValue());
Mask = CurDAG->getRegister(RISCV::V0, MaskVT);
Glue = MaskCopy.getValue(1);
}

unsigned MaskedOpc = Info->MaskedPseudo;
#ifndef NDEBUG
Expand Down Expand Up @@ -3968,7 +3912,7 @@ bool RISCVDAGToDAGISel::doPeepholeMergeVVMFold() {
if (N->use_empty() || !N->isMachineOpcode())
continue;

if (IsVMerge(N) || IsVMv(N))
if (IsVMerge(N))
MadeChange |= performCombineVMergeAndVOps(N);
}
return MadeChange;
Expand Down
141 changes: 140 additions & 1 deletion llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
bool convertToWholeRegister(MachineInstr &MI) const;
bool convertToUnmasked(MachineInstr &MI) const;
bool convertVMergeToVMv(MachineInstr &MI) const;
bool foldVMV_V_V(MachineInstr &MI);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a concern about the naming there. It is not consistent.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I noticed this as well. I think the camel case in convertVMergeToVMv came from trying to please clang-tidy. But VMV_V_V is more accurate to the actual instruction name. Do we have a preference?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hard to decide... I don't have a preference, maybe just keep it, we already have a lot of exceptions in RISCVISelLowering now.


bool isAllOnesMask(const MachineInstr *MaskDef) const;
std::optional<unsigned> getConstant(const MachineOperand &VL) const;
Expand Down Expand Up @@ -324,6 +325,143 @@ bool RISCVVectorPeephole::convertToUnmasked(MachineInstr &MI) const {
return true;
}

/// Given two VL operands, returns the one known to be the smallest or nullptr
/// if unknown.
/// Given two VL operands, returns the one known to be the smallest or nullptr
/// if unknown.
static const MachineOperand *getKnownMinVL(const MachineOperand *LHS,
                                           const MachineOperand *RHS) {
  // The same virtual register trivially holds the same (unknown) VL on both
  // sides, so either operand works; return LHS.
  if (LHS->isReg() && RHS->isReg() && LHS->getReg().isVirtual() &&
      LHS->getReg() == RHS->getReg())
    return LHS;
  // VLMAX is an upper bound on any VL, so the other operand is the minimum.
  if (LHS->isImm() && LHS->getImm() == RISCV::VLMaxSentinel)
    return RHS;
  if (RHS->isImm() && RHS->getImm() == RISCV::VLMaxSentinel)
    return LHS;
  // Two plain immediates can be compared directly; any other combination
  // (e.g. two distinct registers) is unknown.
  if (LHS->isImm() && RHS->isImm())
    return LHS->getImm() <= RHS->getImm() ? LHS : RHS;
  return nullptr;
}

/// Check if it's safe to move From down to To, checking that no physical
/// registers are clobbered.
static bool isSafeToMove(const MachineInstr &From, const MachineInstr &To) {
assert(From.getParent() == To.getParent() && !From.hasImplicitDef());
SmallVector<Register> PhysUses;
for (const MachineOperand &MO : From.all_uses())
if (MO.getReg().isPhysical())
PhysUses.push_back(MO.getReg());
bool SawStore = false;
for (auto II = From.getIterator(); II != To.getIterator(); II++) {
for (Register PhysReg : PhysUses)
if (II->definesRegister(PhysReg, nullptr))
return false;
if (II->mayStore()) {
SawStore = true;
break;
}
}
return From.isSafeToMove(SawStore);
}

static unsigned getSEWLMULRatio(const MachineInstr &MI) {
RISCVII::VLMUL LMUL = RISCVII::getLMul(MI.getDesc().TSFlags);
unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
return RISCVVType::getSEWLMULRatio(1 << Log2SEW, LMUL);
}

/// If a PseudoVMV_V_V is the only user of its input, fold its passthru and VL
/// into it.
///
/// %x = PseudoVADD_V_V_M1 %passthru, %a, %b, %vl1, sew, policy
/// %y = PseudoVMV_V_V_M1 %passthru, %x, %vl2, sew, policy
///
/// ->
///
/// %y = PseudoVADD_V_V_M1 %passthru, %a, %b, min(vl1, vl2), sew, policy
bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
if (RISCV::getRVVMCOpcode(MI.getOpcode()) != RISCV::VMV_V_V)
return false;

MachineOperand &Passthru = MI.getOperand(1);

if (!MRI->hasOneUse(MI.getOperand(2).getReg()))
return false;

MachineInstr *Src = MRI->getVRegDef(MI.getOperand(2).getReg());
if (!Src || Src->hasUnmodeledSideEffects() ||
Src->getParent() != MI.getParent() || Src->getNumDefs() != 1 ||
!RISCVII::isFirstDefTiedToFirstUse(Src->getDesc()) ||
!RISCVII::hasVLOp(Src->getDesc().TSFlags) ||
!RISCVII::hasVecPolicyOp(Src->getDesc().TSFlags))
return false;

// Src needs to have the same VLMAX as MI
if (getSEWLMULRatio(MI) != getSEWLMULRatio(*Src))
return false;

// Src needs to have the same passthru as VMV_V_V
if (Src->getOperand(1).getReg() != RISCV::NoRegister &&
Src->getOperand(1).getReg() != Passthru.getReg())
return false;

// Because Src and MI have the same passthru, we can use either AVL as long as
// it's the smaller of the two.
//
// (src pt, ..., vl=5) x x x x x|. . .
// (vmv.v.v pt, src, vl=3) x x x|. . . . .
// ->
// (src pt, ..., vl=3) x x x|. . . . .
//
// (src pt, ..., vl=3) x x x|. . . . .
// (vmv.v.v pt, src, vl=6) x x x . . .|. .
// ->
// (src pt, ..., vl=3) x x x|. . . . .
MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
const MachineOperand *MinVL = getKnownMinVL(&MI.getOperand(3), &SrcVL);
if (!MinVL)
return false;

bool VLChanged = !MinVL->isIdenticalTo(SrcVL);
bool RaisesFPExceptions = MI.getDesc().mayRaiseFPException() &&
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, shouldn't this be checking Src not MI? MI is the vmv.v.v which will never touch fp exceptions.

If so, this implies a significant testing gap here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, that's a typo. But it turns out we do have a test for fp exceptions on vmv.v.v and it was correctly bailing.

Everything still happened to work because the isSafeToMove check below also checks for mayRaiseFPException, but on the correct instruction. So I've just gone ahead and deleted this bogus check here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For anyone reading along, this comment is from a previous version of the code. Luke added back the fp exceptions check when making the movement conditional (as is required for correctness). I was initially confused by this comment, but it looks like the actual code is correct.

!MI.getFlag(MachineInstr::MIFlag::NoFPExcept);
bool ActiveElementsAffectResult = RISCVII::activeElementsAffectResult(
TII->get(RISCV::getRVVMCOpcode(Src->getOpcode())).TSFlags);

if (VLChanged && (ActiveElementsAffectResult || RaisesFPExceptions))
return false;

if (!isSafeToMove(*Src, MI))
return false;

// Move Src down to MI so it can access its passthru/VL, then replace all uses
// of MI with it.
Src->moveBefore(&MI);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given Src must dominate MI, why do we need to do this move at all? Can't we just rewrite the users of VMV_V_V to use the dominating definition?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We change Src's passthru to MI's passthru, and it may not be defined until after Src. So moving it to MI ensures it's always available. This happens when Src's passthru is undef, which we consider equal to MI's passthru

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This movement only needs to happen in the case where the Passthru is changing. Can you move both the movement and the safety check to inside the if block just below? It would make the code a bit easier to follow.

Worth noting here is that we only need to move the Src if-and-only-if MI's passthru is defined between the instructions. It may be illegal to move Src in some cases that the passthru is defined before src. You don't need to change that in this review, it's purely a potential missed optimization. Maybe leave a TODO though.

Copy link
Contributor Author

@lukel97 lukel97 Aug 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated it so we only perform the move if the passthru or VL changed: The VL can get shrunk to MI's VL, in which case it could be a register defined after Src.


if (Src->getOperand(1).getReg() != Passthru.getReg()) {
Src->getOperand(1).setReg(Passthru.getReg());
// If Src is masked then its passthru needs to be in VRNoV0.
if (Passthru.getReg() != RISCV::NoRegister)
MRI->constrainRegClass(Passthru.getReg(),
TII->getRegClass(Src->getDesc(), 1, TRI,
*Src->getParent()->getParent()));
}

if (MinVL->isImm())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reading through this code, I'm left with a question. If the Src instruction uses a different SEW than the vmv.v.v, why is it legal to reduce the VL without accounting for the different size of the elements? I can't find that check in the DAG version of this code either. Am I forgetting something here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yikes good catch. Looks like we're miscompiling this in the DAG version too, and I don't think it's legal to fold in the mask either. Fix incoming

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be fixed now in this PR, but by checking the VLMAXes are the same since we don't have access to the MVTs here.

SrcVL.ChangeToImmediate(MinVL->getImm());
else if (MinVL->isReg())
SrcVL.ChangeToRegister(MinVL->getReg(), false);

// Use a conservative tu,mu policy, RISCVInsertVSETVLI will relax it if
// passthru is undef.
Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc()))
.setImm(RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED);

MRI->replaceRegWith(MI.getOperand(0).getReg(), Src->getOperand(0).getReg());
MI.eraseFromParent();
V0Defs.erase(&MI);

return true;
}

bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
if (skipFunction(MF.getFunction()))
return false;
Expand Down Expand Up @@ -358,11 +496,12 @@ bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
}

for (MachineBasicBlock &MBB : MF) {
for (MachineInstr &MI : MBB) {
for (MachineInstr &MI : make_early_inc_range(MBB)) {
Changed |= convertToVLMAX(MI);
Changed |= convertToUnmasked(MI);
Changed |= convertToWholeRegister(MI);
Changed |= convertVMergeToVMv(MI);
Changed |= foldVMV_V_V(MI);
}
}

Expand Down
Loading