Skip to content

Commit 13c384c

Browse files
committed
[InferAlignment] Increase alignment in masked load / store intrinsics if known
Summary: The masked load / store LLVM intrinsics take the alignment as an argument. A user who is pessimistic about alignment can pass a value of `1`, which describes an unaligned access. This patch updates infer-alignment to raise that alignment argument when the pointer is known to have a greater alignment than the one provided. The gather / scatter versions are ignored for now since they take many pointers.
1 parent ca14a8a commit 13c384c

File tree

2 files changed

+69
-2
lines changed

2 files changed

+69
-2
lines changed

llvm/lib/Transforms/Scalar/InferAlignment.cpp

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
#include "llvm/Transforms/Scalar/InferAlignment.h"
1515
#include "llvm/Analysis/AssumptionCache.h"
1616
#include "llvm/Analysis/ValueTracking.h"
17+
#include "llvm/IR/IRBuilder.h"
1718
#include "llvm/IR/Instructions.h"
19+
#include "llvm/IR/IntrinsicInst.h"
1820
#include "llvm/Support/KnownBits.h"
1921
#include "llvm/Transforms/Scalar.h"
2022
#include "llvm/Transforms/Utils/Local.h"
@@ -35,8 +37,39 @@ static bool tryToImproveAlign(
3537
return true;
3638
}
3739
}
38-
// TODO: Also handle memory intrinsics.
39-
return false;
40+
41+
IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
42+
if (!II)
43+
return false;
44+
45+
// TODO: Handle more memory intrinsics.
46+
switch (II->getIntrinsicID()) {
47+
case Intrinsic::masked_load:
48+
case Intrinsic::masked_store: {
49+
Value *PtrOp = II->getIntrinsicID() == Intrinsic::masked_load
50+
? II->getArgOperand(0)
51+
: II->getArgOperand(1);
52+
Value *AlignOp = II->getIntrinsicID() == Intrinsic::masked_load
53+
? II->getArgOperand(1)
54+
: II->getArgOperand(2);
55+
56+
Align OldAlign = cast<ConstantInt>(AlignOp)->getAlignValue();
57+
Align PrefAlign = getKnownAlignment(PtrOp, DL, II);
58+
Align NewAlign = Fn(PtrOp, OldAlign, PrefAlign);
59+
if (NewAlign <= OldAlign)
60+
return false;
61+
62+
Value *V = llvm::ConstantInt::get(llvm::Type::getInt32Ty(II->getContext()),
63+
NewAlign.value());
64+
if (II->getIntrinsicID() == Intrinsic::masked_load)
65+
II->setOperand(1, V);
66+
else
67+
II->setOperand(2, V);
68+
return true;
69+
}
70+
default:
71+
return false;
72+
}
4073
}
4174

4275
bool inferAlignment(Function &F, AssumptionCache &AC, DominatorTree &DT) {
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt < %s -passes=infer-alignment -S | FileCheck %s
3+
4+
define <2 x i32> @load(<2 x i1> %mask, ptr %ptr) {
5+
; CHECK-LABEL: define <2 x i32> @load(
6+
; CHECK-SAME: <2 x i1> [[MASK:%.*]], ptr [[PTR:%.*]]) {
7+
; CHECK-NEXT: [[ENTRY:.*:]]
8+
; CHECK-NEXT: call void @llvm.assume(i1 true) [ "align"(ptr [[PTR]], i64 64) ]
9+
; CHECK-NEXT: [[MASKED_LOAD:%.*]] = call <2 x i32> @llvm.masked.load.v2i32.p0(ptr [[PTR]], i32 64, <2 x i1> [[MASK]], <2 x i32> poison)
10+
; CHECK-NEXT: ret <2 x i32> [[MASKED_LOAD]]
11+
;
12+
entry:
13+
call void @llvm.assume(i1 true) [ "align"(ptr %ptr, i64 64) ]
14+
%masked_load = call <2 x i32> @llvm.masked.load.v2i32.p0(ptr %ptr, i32 1, <2 x i1> %mask, <2 x i32> poison)
15+
ret <2 x i32> %masked_load
16+
}
17+
18+
define void @store(<2 x i1> %mask, <2 x i32> %val, ptr %ptr) {
19+
; CHECK-LABEL: define void @store(
20+
; CHECK-SAME: <2 x i1> [[MASK:%.*]], <2 x i32> [[VAL:%.*]], ptr [[PTR:%.*]]) {
21+
; CHECK-NEXT: [[ENTRY:.*:]]
22+
; CHECK-NEXT: call void @llvm.assume(i1 true) [ "align"(ptr [[PTR]], i64 64) ]
23+
; CHECK-NEXT: tail call void @llvm.masked.store.v2i32.p0(<2 x i32> [[VAL]], ptr [[PTR]], i32 64, <2 x i1> [[MASK]])
24+
; CHECK-NEXT: ret void
25+
;
26+
entry:
27+
call void @llvm.assume(i1 true) [ "align"(ptr %ptr, i64 64) ]
28+
tail call void @llvm.masked.store.v2i32.p0(<2 x i32> %val, ptr %ptr, i32 1, <2 x i1> %mask)
29+
ret void
30+
}
31+
32+
declare void @llvm.assume(i1)
33+
declare <2 x i32> @llvm.masked.load.v2i32.p0(ptr, i32, <2 x i1>, <2 x i32>)
34+
declare void @llvm.masked.store.v2i32.p0(<2 x i32>, ptr, i32, <2 x i1>)

0 commit comments

Comments
 (0)