[AArch64] Optimize when storing symmetry constants #93717

Merged · 2 commits · Aug 20, 2024
llvm/lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp (+178, -0)
@@ -226,6 +226,14 @@ struct AArch64LoadStoreOpt : public MachineFunctionPass {
// Find and merge an index ldr/st instruction into a base ld/st instruction.
bool tryToMergeIndexLdSt(MachineBasicBlock::iterator &MBBI, int Scale);

// Find and fold loads of a symmetric constant value into a store pair.
bool tryFoldSymmetryConstantLoad(MachineBasicBlock::iterator &I,
unsigned Limit);
MachineBasicBlock::iterator
doFoldSymmetryConstantLoad(MachineInstr &MI,
SmallVectorImpl<MachineBasicBlock::iterator> &MIs,
int UpperLoadIdx, int Accumulated);

bool optimizeBlock(MachineBasicBlock &MBB, bool EnableNarrowZeroStOpt);

bool runOnMachineFunction(MachineFunction &Fn) override;
@@ -2443,6 +2451,155 @@ AArch64LoadStoreOpt::findMatchingConstOffsetBackward(
return E;
}

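// A "symmetric" 64-bit constant has identical lower and upper 32-bit
// halves, e.g. 0x1234567812345678. It is materialized by a MOVZ/MOVK
// sequence, optionally followed by an ORRXrs with shift 32 that copies
// the lower half into the upper half.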
static bool isSymmetricLoadCandidate(MachineInstr &MI, Register BaseReg) {
auto MatchBaseReg = [&](unsigned Count) {
for (unsigned I = 0; I < Count; I++) {
auto OpI = MI.getOperand(I);
if (OpI.isReg() && OpI.getReg() != BaseReg)
return false;
}
return true;
};

unsigned Opc = MI.getOpcode();
switch (Opc) {
default:
return false;
case AArch64::MOVZXi:
return MatchBaseReg(1);
case AArch64::MOVKXi:
return MatchBaseReg(2);
case AArch64::ORRXrs:
MachineOperand &Imm = MI.getOperand(3);
// The fourth operand of the ORR must be the shift 32 (LSL #32), which
// copies the lower 32 bits into the upper half, i.e. a 32-bit
// symmetric constant, e.g.
// renamable $x8 = ORRXrs $x8, $x8, 32
if (MatchBaseReg(3) && Imm.isImm() && Imm.getImm() == 32)
return true;
}

return false;
}

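// Rewrite the STRXui of a symmetric constant as an STPWi that stores the
// 32-bit W subregister twice, and erase the instructions that only
// materialized the now-redundant upper half.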
MachineBasicBlock::iterator AArch64LoadStoreOpt::doFoldSymmetryConstantLoad(
MachineInstr &MI, SmallVectorImpl<MachineBasicBlock::iterator> &MIs,
int UpperLoadIdx, int Accumulated) {
MachineBasicBlock::iterator I = MI.getIterator();
MachineBasicBlock::iterator E = I->getParent()->end();
MachineBasicBlock::iterator NextI = next_nodbg(I, E);
MachineBasicBlock *MBB = MI.getParent();

if (!UpperLoadIdx) {
// The ORR copies the lower 32 bits into the upper half, so the
// preceding instructions materialize only the lower 32-bit constant.
// Removing the ORR alone is enough.
(*MIs.begin())->eraseFromParent();
} else {
// Remove the MOVs that materialize the upper 32 bits; the constant is
// symmetric, so storing the lower half twice makes them redundant.
int Index = 0;
for (auto MI = MIs.begin(); Index < UpperLoadIdx; ++MI, Index++) {
(*MI)->eraseFromParent();
}
}

Register BaseReg = getLdStRegOp(MI).getReg();
const MachineOperand MO = AArch64InstrInfo::getLdStBaseOp(MI);
Register DstRegW = TRI->getSubReg(BaseReg, AArch64::sub_32);
unsigned DstRegState = getRegState(MI.getOperand(0));
int Offset = AArch64InstrInfo::getLdStOffsetOp(MI).getImm();
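// Build an STPWi that stores the W subregister twice at the same byte
// address: the STRXui immediate is scaled by 8 and the STPWi immediate
// by 4, so doubling the immediate preserves the byte offset.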
BuildMI(*MBB, MI, MI.getDebugLoc(), TII->get(AArch64::STPWi))
.addReg(DstRegW, DstRegState)
.addReg(DstRegW, DstRegState)
.addReg(MO.getReg(), getRegState(MO))
.addImm(Offset * 2)
.setMemRefs(MI.memoperands())
.setMIFlags(MI.getFlags());
I->eraseFromParent();
return NextI;
}

bool AArch64LoadStoreOpt::tryFoldSymmetryConstantLoad(
MachineBasicBlock::iterator &I, unsigned Limit) {
MachineInstr &MI = *I;
if (MI.getOpcode() != AArch64::STRXui)
return false;

MachineBasicBlock::iterator MBBI = I;
MachineBasicBlock::iterator B = I->getParent()->begin();
if (MBBI == B)
return false;

TypeSize Scale(0U, false), Width(0U, false);
int64_t MinOffset, MaxOffset;
if (!AArch64InstrInfo::getMemOpInfo(AArch64::STPWi, Scale, Width, MinOffset,
MaxOffset))
return false;

// We replace the STRXui, which stores 64 bits, with an STPWi, which
// stores two consecutive 32-bit values. The STP immediate is scaled by 4
// rather than 8, so the immediate doubles and must be checked against
// the STPWi offset range.
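// For example, an STRXui with immediate offset 4 stores at byte offset
// 4 * 8 = 32; the matching STPWi needs immediate offset 8 (8 * 4 = 32),
// which must lie within STPWi's signed 7-bit range of [-64, 63].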
int Offset = AArch64InstrInfo::getLdStOffsetOp(MI).getImm();
if (Offset * 2 < MinOffset || Offset * 2 > MaxOffset)
return false;

Register BaseReg = getLdStRegOp(MI).getReg();
unsigned Count = 0, UpperLoadIdx = 0;
uint64_t Accumulated = 0, Mask = 0xFFFFUL;
bool hasORR = false, Found = false;
SmallVector<MachineBasicBlock::iterator> MIs;
ModifiedRegUnits.clear();
UsedRegUnits.clear();
do {
MBBI = prev_nodbg(MBBI, B);
MachineInstr &MI = *MBBI;
if (!MI.isTransient())
++Count;
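// If this instruction is not part of the constant materialization, make
// sure it neither reads nor clobbers the register we are tracking.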
if (!isSymmetricLoadCandidate(MI, BaseReg)) {
LiveRegUnits::accumulateUsedDefed(MI, ModifiedRegUnits, UsedRegUnits,
TRI);
if (!ModifiedRegUnits.available(BaseReg) ||
!UsedRegUnits.available(BaseReg))
return false;
continue;
}

unsigned Opc = MI.getOpcode();
if (Opc == AArch64::ORRXrs) {
hasORR = true;
MIs.push_back(MBBI);
continue;
}
unsigned ValueOrder = Opc == AArch64::MOVZXi ? 1 : 2;
MachineOperand Value = MI.getOperand(ValueOrder);
MachineOperand Shift = MI.getOperand(ValueOrder + 1);
if (!Value.isImm() || !Shift.isImm())
return false;

uint64_t IValue = Value.getImm();
uint64_t IShift = Shift.getImm();
uint64_t Adder = IValue << IShift;
MIs.push_back(MBBI);
if (Adder >> 32)
UpperLoadIdx = MIs.size();

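// Replace the 16-bit field at IShift with this instruction's immediate,
// then check whether the value built so far is symmetric: either the two
// 32-bit halves match, or an ORR will copy the lower half upward.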
Accumulated -= Accumulated & (Mask << IShift);
Accumulated += Adder;
if (Accumulated != 0 &&
(((Accumulated >> 32) == (Accumulated & 0xffffffffULL)) ||
(hasORR && (Accumulated >> 32 == 0)))) {
Found = true;
break;
}
} while (MBBI != B && Count < Limit);

if (Found) {
I = doFoldSymmetryConstantLoad(MI, MIs, UpperLoadIdx, Accumulated);
return true;
}

return false;
}

bool AArch64LoadStoreOpt::tryToPromoteLoadFromStore(
MachineBasicBlock::iterator &MBBI) {
MachineInstr &MI = *MBBI;
@@ -2753,6 +2910,27 @@ bool AArch64LoadStoreOpt::optimizeBlock(MachineBasicBlock &MBB,
++MBBI;
}

// We have an opportunity to optimize an STRXui of a 64-bit constant whose
// lower and upper 32 bits are identical: materialize the 32-bit value once
// and store it twice with a single STPWi.
// Considering:
// renamable $x8 = MOVZXi 49370, 0
// renamable $x8 = MOVKXi $x8, 320, 16
// renamable $x8 = ORRXrs $x8, $x8, 32
// STRXui killed renamable $x8, killed renamable $x0, 0
// Transform into:
// $w8 = MOVZWi 49370, 0
// $w8 = MOVKWi $w8, 320, 16
// STPWi killed renamable $w8, killed renamable $w8, killed renamable $x0, 0
for (MachineBasicBlock::iterator MBBI = MBB.begin(), E = MBB.end();
MBBI != E;) {
if (isMergeableLdStUpdate(*MBBI) &&
tryFoldSymmetryConstantLoad(MBBI, UpdateLimit))
Modified = true;
else
++MBBI;
}

return Modified;
}

llvm/test/CodeGen/AArch64/movimm-expand-ldst.ll (+180, -0)
@@ -93,3 +93,183 @@ define i64 @testuu0xf555f555f555f555() {
; CHECK-NEXT: ret
ret i64 u0xf555f555f555f555
}

define void @test_store_0x1234567812345678(ptr %x) {
; CHECK-LABEL: test_store_0x1234567812345678:
; CHECK: // %bb.0:
; CHECK-NEXT: mov x8, #22136 // =0x5678
; CHECK-NEXT: movk x8, #4660, lsl #16
; CHECK-NEXT: stp w8, w8, [x0]
; CHECK-NEXT: ret
store i64 u0x1234567812345678, ptr %x
ret void
}

define void @test_store_0xff3456ffff3456ff(ptr %x) {
; CHECK-LABEL: test_store_0xff3456ffff3456ff:
; CHECK: // %bb.0:
; CHECK-NEXT: mov x8, #22271 // =0x56ff
; CHECK-NEXT: movk x8, #65332, lsl #16
; CHECK-NEXT: stp w8, w8, [x0]
; CHECK-NEXT: ret
store i64 u0xff3456ffff3456ff, ptr %x
ret void
}

define void @test_store_0x00345600345600(ptr %x) {
; CHECK-LABEL: test_store_0x00345600345600:
; CHECK: // %bb.0:
; CHECK-NEXT: mov x8, #22016 // =0x5600
; CHECK-NEXT: movk x8, #52, lsl #16
; CHECK-NEXT: movk x8, #13398, lsl #32
; CHECK-NEXT: str x8, [x0]
; CHECK-NEXT: ret
store i64 u0x00345600345600, ptr %x
ret void
}

define void @test_store_0x5555555555555555(ptr %x) {
; CHECK-LABEL: test_store_0x5555555555555555:
; CHECK: // %bb.0:
; CHECK-NEXT: mov x8, #6148914691236517205 // =0x5555555555555555
; CHECK-NEXT: str x8, [x0]
; CHECK-NEXT: ret
store i64 u0x5555555555555555, ptr %x
ret void
}

define void @test_store_0x5055555550555555(ptr %x) {
; CHECK-LABEL: test_store_0x5055555550555555:
; CHECK: // %bb.0:
; CHECK-NEXT: mov x8, #6148914691236517205 // =0x5555555555555555
; CHECK-NEXT: and x8, x8, #0xf0fffffff0ffffff
; CHECK-NEXT: str x8, [x0]
; CHECK-NEXT: ret
store i64 u0x5055555550555555, ptr %x
ret void
}

define void @test_store_0x0000555555555555(ptr %x) {
; CHECK-LABEL: test_store_0x0000555555555555:
; CHECK: // %bb.0:
; CHECK-NEXT: mov x8, #6148914691236517205 // =0x5555555555555555
; CHECK-NEXT: movk x8, #0, lsl #48
; CHECK-NEXT: str x8, [x0]
; CHECK-NEXT: ret
store i64 u0x0000555555555555, ptr %x
ret void
}

define void @test_store_0x0000555500005555(ptr %x) {
; CHECK-LABEL: test_store_0x0000555500005555:
; CHECK: // %bb.0:
; CHECK-NEXT: mov x8, #21845 // =0x5555
; CHECK-NEXT: stp w8, w8, [x0]
; CHECK-NEXT: ret
store i64 u0x0000555500005555, ptr %x
ret void
}

define void @test_store_0x5555000055550000(ptr %x) {
; CHECK-LABEL: test_store_0x5555000055550000:
; CHECK: // %bb.0:
; CHECK-NEXT: mov x8, #1431633920 // =0x55550000
; CHECK-NEXT: stp w8, w8, [x0]
; CHECK-NEXT: ret
store i64 u0x5555000055550000, ptr %x
ret void
}

define void @test_store_u0xffff5555ffff5555(ptr %x) {
; CHECK-LABEL: test_store_u0xffff5555ffff5555:
; CHECK: // %bb.0:
; CHECK-NEXT: mov x8, #-43691 // =0xffffffffffff5555
; CHECK-NEXT: movk x8, #21845, lsl #32
; CHECK-NEXT: str x8, [x0]
; CHECK-NEXT: ret
store i64 u0xffff5555ffff5555, ptr %x
ret void
}

define void @test_store_0x8888ffff8888ffff(ptr %x) {
; CHECK-LABEL: test_store_0x8888ffff8888ffff:
; CHECK: // %bb.0:
; CHECK-NEXT: mov x8, #-2004287489 // =0xffffffff8888ffff
; CHECK-NEXT: movk x8, #34952, lsl #48
; CHECK-NEXT: str x8, [x0]
; CHECK-NEXT: ret
store i64 u0x8888ffff8888ffff, ptr %x
ret void
}

define void @test_store_uu0xfffff555f555f555(ptr %x) {
; CHECK-LABEL: test_store_uu0xfffff555f555f555:
; CHECK: // %bb.0:
; CHECK-NEXT: mov x8, #-2731 // =0xfffffffffffff555
; CHECK-NEXT: movk x8, #62805, lsl #16
; CHECK-NEXT: movk x8, #62805, lsl #32
; CHECK-NEXT: str x8, [x0]
; CHECK-NEXT: ret
store i64 u0xfffff555f555f555, ptr %x
ret void
}

define void @test_store_uu0xf555f555f555f555(ptr %x) {
; CHECK-LABEL: test_store_uu0xf555f555f555f555:
; CHECK: // %bb.0:
; CHECK-NEXT: mov x8, #6148914691236517205 // =0x5555555555555555
; CHECK-NEXT: orr x8, x8, #0xe001e001e001e001
; CHECK-NEXT: str x8, [x0]
; CHECK-NEXT: ret
store i64 u0xf555f555f555f555, ptr %x
ret void
}

define void @test_store_0x1234567812345678_offset_range(ptr %x) {
; CHECK-LABEL: test_store_0x1234567812345678_offset_range:
; CHECK: // %bb.0:
; CHECK-NEXT: mov x8, #22136 // =0x5678
; CHECK-NEXT: movk x8, #4660, lsl #16
; CHECK-NEXT: stp w8, w8, [x0, #32]
; CHECK-NEXT: ret
%g = getelementptr i64, ptr %x, i64 4
store i64 u0x1234567812345678, ptr %g
ret void
}

define void @test_store_0x1234567812345678_offset_min(ptr %x) {
; CHECK-LABEL: test_store_0x1234567812345678_offset_min:
; CHECK: // %bb.0:
; CHECK-NEXT: mov x8, #22136 // =0x5678
; CHECK-NEXT: movk x8, #4660, lsl #16
; CHECK-NEXT: stp w8, w8, [x0]
; CHECK-NEXT: ret
%g = getelementptr i8, ptr %x, i32 0
store i64 u0x1234567812345678, ptr %g
ret void
}

define void @test_store_0x1234567812345678_offset_max(ptr %x) {
; CHECK-LABEL: test_store_0x1234567812345678_offset_max:
; CHECK: // %bb.0:
; CHECK-NEXT: mov x8, #22136 // =0x5678
; CHECK-NEXT: movk x8, #4660, lsl #16
; CHECK-NEXT: stp w8, w8, [x0, #248]
; CHECK-NEXT: ret
%g = getelementptr i8, ptr %x, i32 248
store i64 u0x1234567812345678, ptr %g
ret void
}

define void @test_store_0x1234567812345678_offset_max_over(ptr %x) {
; CHECK-LABEL: test_store_0x1234567812345678_offset_max_over:
; CHECK: // %bb.0:
; CHECK-NEXT: mov x8, #22136 // =0x5678
; CHECK-NEXT: movk x8, #4660, lsl #16
; CHECK-NEXT: orr x8, x8, x8, lsl #32
; CHECK-NEXT: stur x8, [x0, #249]
; CHECK-NEXT: ret
%g = getelementptr i8, ptr %x, i32 249
store i64 u0x1234567812345678, ptr %g
ret void
}