Skip to content

Commit a70843a

Browse files
committed
[Matrix] Propagate shape information through (f)abs insts
1 parent c647c58 commit a70843a

File tree

2 files changed

+94
-1
lines changed

2 files changed

+94
-1
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,15 @@ static bool isUniformShape(Value *V) {
229229
if (!I)
230230
return true;
231231

232+
if (auto *II = dyn_cast<IntrinsicInst>(V))
233+
switch (II->getIntrinsicID()) {
234+
case Intrinsic::abs:
235+
case Intrinsic::fabs:
236+
return true;
237+
default:
238+
return false;
239+
}
240+
232241
switch (I->getOpcode()) {
233242
case Instruction::FAdd:
234243
case Instruction::FSub:
@@ -625,7 +634,7 @@ class LowerMatrixIntrinsics {
625634
case Intrinsic::matrix_column_major_store:
626635
return true;
627636
default:
628-
return false;
637+
return isUniformShape(II);
629638
}
630639
return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
631640
}
@@ -1131,6 +1140,9 @@ class LowerMatrixIntrinsics {
11311140
case Intrinsic::matrix_column_major_store:
11321141
LowerColumnMajorStore(Inst);
11331142
break;
1143+
case Intrinsic::abs:
1144+
case Intrinsic::fabs:
1145+
return VisitUniformIntrinsic(cast<IntrinsicInst>(Inst));
11341146
default:
11351147
return false;
11361148
}
@@ -2223,6 +2235,44 @@ class LowerMatrixIntrinsics {
22232235
return true;
22242236
}
22252237

2238+
/// Lower uniform shape intrinsics, if shape information is available.
2239+
bool VisitUniformIntrinsic(IntrinsicInst *Inst) {
2240+
auto I = ShapeMap.find(Inst);
2241+
if (I == ShapeMap.end())
2242+
return false;
2243+
2244+
IRBuilder<> Builder(Inst);
2245+
ShapeInfo &Shape = I->second;
2246+
2247+
MatrixTy Result;
2248+
2249+
switch (Inst->getIntrinsicID()) {
2250+
case Intrinsic::abs:
2251+
case Intrinsic::fabs: {
2252+
Value *Op = Inst->getOperand(0);
2253+
2254+
MatrixTy M = getMatrix(Op, Shape, Builder);
2255+
2256+
Builder.setFastMathFlags(getFastMathFlags(Inst));
2257+
2258+
for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2259+
switch (Inst->getIntrinsicID()) {
2260+
case Intrinsic::abs:
2261+
Result.addVector(Builder.CreateBinaryIntrinsic(Intrinsic::abs, M.getVector(I), Inst->getOperand(1)));
2262+
break;
2263+
case Intrinsic::fabs:
2264+
Result.addVector(Builder.CreateUnaryIntrinsic(Inst->getIntrinsicID(), M.getVector(I)));
2265+
break;
2266+
}
2267+
2268+
finalizeLowering(Inst, Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * Result.getNumVectors()), Builder);
2269+
return true;
2270+
}
2271+
default:
2272+
llvm_unreachable("unexpected intrinsic");
2273+
}
2274+
}
2275+
22262276
/// Helper to linearize a matrix expression tree into a string. Currently
22272277
/// matrix expressions are linarized by starting at an expression leaf and
22282278
/// linearizing bottom up.

llvm/test/Transforms/LowerMatrixIntrinsics/binop.ll

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,46 @@ define void @fdiv_2x2(ptr %num, ptr %denom, ptr %out) {
2424
store <4 x double> %divtt, ptr %out
2525
ret void
2626
}
27+
28+
define void @fabs_2x2f64(ptr %in, ptr %out) {
29+
; CHECK-LABEL: @fabs_2x2f64(
30+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x double>, ptr [[IN:%.*]], align 32
31+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 2
32+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 16
33+
; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @llvm.fabs.v2f64(<2 x double> [[COL_LOAD]])
34+
; CHECK-NEXT: [[TMP2:%.*]] = call <2 x double> @llvm.fabs.v2f64(<2 x double> [[COL_LOAD1]])
35+
; CHECK-NEXT: store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 32
36+
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
37+
; CHECK-NEXT: store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 16
38+
; CHECK-NEXT: ret void
39+
;
40+
%load = load <4 x double>, ptr %in
41+
%fabs = call <4 x double> @llvm.fabs.v4f64(<4 x double> %load)
42+
%fabst = call <4 x double> @llvm.matrix.transpose(<4 x double> %fabs, i32 2, i32 2)
43+
%fabstt = call <4 x double> @llvm.matrix.transpose(<4 x double> %fabst, i32 2, i32 2)
44+
store <4 x double> %fabstt, ptr %out
45+
ret void
46+
}
47+
48+
define void @fabs_2x2i32(ptr %in, ptr %out) {
49+
; CHECK-LABEL: @fabs_2x2i32(
50+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x i32>, ptr [[IN:%.*]], align 16
51+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 2
52+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x i32>, ptr [[VEC_GEP]], align 8
53+
; CHECK-NEXT: [[TMP1:%.*]] = call <2 x i32> @llvm.abs.v2i32(<2 x i32> [[COL_LOAD]], i1 false)
54+
; CHECK-NEXT: [[TMP2:%.*]] = call <2 x i32> @llvm.abs.v2i32(<2 x i32> [[COL_LOAD1]], i1 false)
55+
; CHECK-NEXT: [[TMP3:%.*]] = call <2 x i32> @llvm.abs.v2i32(<2 x i32> [[TMP1]], i1 true)
56+
; CHECK-NEXT: [[TMP4:%.*]] = call <2 x i32> @llvm.abs.v2i32(<2 x i32> [[TMP2]], i1 true)
57+
; CHECK-NEXT: store <2 x i32> [[TMP3]], ptr [[OUT:%.*]], align 16
58+
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
59+
; CHECK-NEXT: store <2 x i32> [[TMP4]], ptr [[VEC_GEP2]], align 8
60+
; CHECK-NEXT: ret void
61+
;
62+
%load = load <4 x i32>, ptr %in
63+
%abs = call <4 x i32> @llvm.abs.v4i32(<4 x i32> %load, i1 false)
64+
%abst = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %abs, i32 2, i32 2)
65+
%abstt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %abst, i32 2, i32 2)
66+
%absabstt = call <4 x i32> @llvm.abs.v4i32(<4 x i32> %abstt, i1 true)
67+
store <4 x i32> %absabstt, ptr %out
68+
ret void
69+
}

0 commit comments

Comments
 (0)