Skip to content

Commit

Permalink
[WebAssembly] Replace LOAD_SPLAT with SPLAT_VECTOR
Browse files Browse the repository at this point in the history
Splats were selected by matching on uses of `build_vector` with
identical elements, but a while back a target independent node for
vector splatting was added.
This removes the WebAssembly specific LOAD_SPLAT intrinsic, and instead
makes SPLAT_VECTOR legal and adds patterns for splat loads.

Differential Revision: https://reviews.llvm.org/D139871
  • Loading branch information
lukel97 committed Jan 4, 2023
1 parent a26cbd0 commit f841ad3
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 144 deletions.
1 change: 0 additions & 1 deletion llvm/lib/Target/WebAssembly/WebAssemblyISD.def
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ HANDLE_NODETYPE(MEMORY_COPY)
HANDLE_NODETYPE(MEMORY_FILL)

// Memory intrinsics
HANDLE_MEM_NODETYPE(LOAD_SPLAT)
HANDLE_MEM_NODETYPE(GLOBAL_GET)
HANDLE_MEM_NODETYPE(GLOBAL_SET)
HANDLE_MEM_NODETYPE(TABLE_GET)
Expand Down
19 changes: 7 additions & 12 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
MVT::v2f64})
setOperationAction(ISD::VECTOR_SHUFFLE, T, Custom);

// Support splatting
for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v4f32, MVT::v2i64,
MVT::v2f64})
setOperationAction(ISD::SPLAT_VECTOR, T, Legal);

// Custom lowering since wasm shifts must have a scalar shift amount
for (auto Op : {ISD::SHL, ISD::SRA, ISD::SRL})
for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64})
Expand Down Expand Up @@ -2161,18 +2166,8 @@ SDValue WebAssemblyTargetLowering::LowerBUILD_VECTOR(SDValue Op,
return IsConstant(Lane);
};
} else {
// Use a splat, but possibly a load_splat
LoadSDNode *SplattedLoad;
if ((SplattedLoad = dyn_cast<LoadSDNode>(SplatValue)) &&
SplattedLoad->getMemoryVT() == VecT.getVectorElementType()) {
Result = DAG.getMemIntrinsicNode(
WebAssemblyISD::LOAD_SPLAT, DL, DAG.getVTList(VecT),
{SplattedLoad->getChain(), SplattedLoad->getBasePtr(),
SplattedLoad->getOffset()},
SplattedLoad->getMemoryVT(), SplattedLoad->getMemOperand());
} else {
Result = DAG.getSplatBuildVector(VecT, DL, SplatValue);
}
// Use a splat (which might be selected as a load splat)
Result = DAG.getSplatBuildVector(VecT, DL, SplatValue);
IsLaneConstructed = [&SplatValue](size_t _, const SDValue &Lane) {
return Lane == SplatValue;
};
Expand Down
50 changes: 28 additions & 22 deletions llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,14 @@ def ImmI#SIZE : ImmLeaf<i32,
foreach SIZE = [2, 4, 8, 16, 32] in
def LaneIdx#SIZE : ImmLeaf<i32, "return 0 <= Imm && Imm < "#SIZE#";">;

// Create vector with identical lanes: splat
def splat2 : PatFrag<(ops node:$x), (build_vector $x, $x)>;
def splat4 : PatFrag<(ops node:$x), (build_vector $x, $x, $x, $x)>;
def splat8 : PatFrag<(ops node:$x), (build_vector $x, $x, $x, $x,
$x, $x, $x, $x)>;
def splat16 : PatFrag<(ops node:$x),
(build_vector $x, $x, $x, $x, $x, $x, $x, $x,
$x, $x, $x, $x, $x, $x, $x, $x)>;

class Vec {
ValueType vt;
ValueType int_vt;
ValueType lane_vt;
WebAssemblyRegClass lane_rc;
int lane_bits;
ImmLeaf lane_idx;
SDPatternOperator lane_load;
PatFrag splat;
string prefix;
Vec split;
Expand All @@ -82,7 +74,8 @@ def I8x16 : Vec {
let lane_rc = I32;
let lane_bits = 8;
let lane_idx = LaneIdx16;
let splat = splat16;
let lane_load = extloadi8;
let splat = PatFrag<(ops node:$x), (v16i8 (splat_vector (i8 $x)))>;
let prefix = "i8x16";
}

Expand All @@ -93,7 +86,8 @@ def I16x8 : Vec {
let lane_rc = I32;
let lane_bits = 16;
let lane_idx = LaneIdx8;
let splat = splat8;
let lane_load = extloadi16;
let splat = PatFrag<(ops node:$x), (v8i16 (splat_vector (i16 $x)))>;
let prefix = "i16x8";
let split = I8x16;
}
Expand All @@ -105,7 +99,8 @@ def I32x4 : Vec {
let lane_rc = I32;
let lane_bits = 32;
let lane_idx = LaneIdx4;
let splat = splat4;
let lane_load = load;
let splat = PatFrag<(ops node:$x), (v4i32 (splat_vector (i32 $x)))>;
let prefix = "i32x4";
let split = I16x8;
}
Expand All @@ -117,7 +112,8 @@ def I64x2 : Vec {
let lane_rc = I64;
let lane_bits = 64;
let lane_idx = LaneIdx2;
let splat = splat2;
let lane_load = load;
let splat = PatFrag<(ops node:$x), (v2i64 (splat_vector (i64 $x)))>;
let prefix = "i64x2";
let split = I32x4;
}
Expand All @@ -129,7 +125,8 @@ def F32x4 : Vec {
let lane_rc = F32;
let lane_bits = 32;
let lane_idx = LaneIdx4;
let splat = splat4;
let lane_load = load;
let splat = PatFrag<(ops node:$x), (v4f32 (splat_vector (f32 $x)))>;
let prefix = "f32x4";
}

Expand All @@ -140,7 +137,8 @@ def F64x2 : Vec {
let lane_rc = F64;
let lane_bits = 64;
let lane_idx = LaneIdx2;
let splat = splat2;
let lane_load = load;
let splat = PatFrag<(ops node:$x), (v2f64 (splat_vector (f64 $x)))>;
let prefix = "f64x2";
}

Expand Down Expand Up @@ -195,14 +193,11 @@ defm "" : SIMDLoadSplat<16, 8>;
defm "" : SIMDLoadSplat<32, 9>;
defm "" : SIMDLoadSplat<64, 10>;

def wasm_load_splat_t : SDTypeProfile<1, 1, [SDTCisPtrTy<1>]>;
def wasm_load_splat : SDNode<"WebAssemblyISD::LOAD_SPLAT", wasm_load_splat_t,
[SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
def load_splat : PatFrag<(ops node:$addr), (wasm_load_splat node:$addr)>;

foreach vec = AllVecs in {
defvar inst = "LOAD"#vec.lane_bits#"_SPLAT";
defm : LoadPat<vec.vt, load_splat, inst>;
defvar inst = "LOAD"#vec.lane_bits#"_SPLAT";
defm : LoadPat<vec.vt,
PatFrag<(ops node:$addr), (splat_vector (vec.lane_vt (vec.lane_load node:$addr)))>,
inst>;
}

// Load and extend
Expand Down Expand Up @@ -488,6 +483,17 @@ defm "" : ConstVec<F64x2,
(build_vector (f64 fpimm:$i0), (f64 fpimm:$i1)),
"$i0, $i1">;

// Match splat(x) -> const.v128(x, ..., x)
foreach vec = AllVecs in {
defvar numEls = !div(vec.vt.Size, vec.lane_bits);
defvar isFloat = !or(!eq(vec.lane_vt, f32), !eq(vec.lane_vt, f64));
defvar immKind = !if(isFloat, fpimm, imm);
def : Pat<(vec.splat (vec.lane_vt immKind:$x)),
!dag(!cast<NI>("CONST_V128_"#vec),
!listsplat((vec.lane_vt immKind:$x), numEls),
?)>;
}

// Shuffle lanes: shuffle
defm SHUFFLE :
SIMD_I<(outs V128:$dst),
Expand Down
16 changes: 8 additions & 8 deletions llvm/test/CodeGen/WebAssembly/fpclamptosat_vec.ll
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,9 @@ define <2 x i16> @stest_f64i16(<2 x double> %x) {
; CHECK-NEXT: f64x2.extract_lane 1
; CHECK-NEXT: i32.trunc_sat_f64_s
; CHECK-NEXT: i32x4.replace_lane 1
; CHECK-NEXT: v128.const 32767, 32767, 0, 0
; CHECK-NEXT: v128.const 32767, 32767, 32767, 32767
; CHECK-NEXT: i32x4.min_s
; CHECK-NEXT: v128.const -32768, -32768, 0, 0
; CHECK-NEXT: v128.const -32768, -32768, -32768, -32768
; CHECK-NEXT: i32x4.max_s
; CHECK-NEXT: local.get 0
; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
Expand All @@ -326,7 +326,7 @@ define <2 x i16> @utest_f64i16(<2 x double> %x) {
; CHECK-NEXT: f64x2.extract_lane 1
; CHECK-NEXT: i32.trunc_sat_f64_u
; CHECK-NEXT: i32x4.replace_lane 1
; CHECK-NEXT: v128.const 65535, 65535, 0, 0
; CHECK-NEXT: v128.const 65535, 65535, 65535, 65535
; CHECK-NEXT: i32x4.min_u
; CHECK-NEXT: local.get 0
; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
Expand All @@ -351,7 +351,7 @@ define <2 x i16> @ustest_f64i16(<2 x double> %x) {
; CHECK-NEXT: f64x2.extract_lane 1
; CHECK-NEXT: i32.trunc_sat_f64_s
; CHECK-NEXT: i32x4.replace_lane 1
; CHECK-NEXT: v128.const 65535, 65535, 0, 0
; CHECK-NEXT: v128.const 65535, 65535, 65535, 65535
; CHECK-NEXT: i32x4.min_s
; CHECK-NEXT: v128.const 0, 0, 0, 0
; CHECK-NEXT: i32x4.max_s
Expand Down Expand Up @@ -1790,9 +1790,9 @@ define <2 x i16> @stest_f64i16_mm(<2 x double> %x) {
; CHECK-NEXT: f64x2.extract_lane 1
; CHECK-NEXT: i32.trunc_sat_f64_s
; CHECK-NEXT: i32x4.replace_lane 1
; CHECK-NEXT: v128.const 32767, 32767, 0, 0
; CHECK-NEXT: v128.const 32767, 32767, 32767, 32767
; CHECK-NEXT: i32x4.min_s
; CHECK-NEXT: v128.const -32768, -32768, 0, 0
; CHECK-NEXT: v128.const -32768, -32768, -32768, -32768
; CHECK-NEXT: i32x4.max_s
; CHECK-NEXT: local.get 0
; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
Expand All @@ -1817,7 +1817,7 @@ define <2 x i16> @utest_f64i16_mm(<2 x double> %x) {
; CHECK-NEXT: f64x2.extract_lane 1
; CHECK-NEXT: i32.trunc_sat_f64_u
; CHECK-NEXT: i32x4.replace_lane 1
; CHECK-NEXT: v128.const 65535, 65535, 0, 0
; CHECK-NEXT: v128.const 65535, 65535, 65535, 65535
; CHECK-NEXT: i32x4.min_u
; CHECK-NEXT: local.get 0
; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
Expand All @@ -1841,7 +1841,7 @@ define <2 x i16> @ustest_f64i16_mm(<2 x double> %x) {
; CHECK-NEXT: f64x2.extract_lane 1
; CHECK-NEXT: i32.trunc_sat_f64_s
; CHECK-NEXT: i32x4.replace_lane 1
; CHECK-NEXT: v128.const 65535, 65535, 0, 0
; CHECK-NEXT: v128.const 65535, 65535, 65535, 65535
; CHECK-NEXT: i32x4.min_s
; CHECK-NEXT: v128.const 0, 0, 0, 0
; CHECK-NEXT: i32x4.max_s
Expand Down
Loading

0 comments on commit f841ad3

Please sign in to comment.