Skip to content

[NFC] [AArch64] Refactor predicate register class decode functions #97412

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 63 additions & 16 deletions llvm/lib/Target/AArch64/AArch64RegisterInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -160,25 +160,30 @@ def GPR64common : RegisterClass<"AArch64", [i64], 64,
(add (sequence "X%u", 0, 28), FP, LR)> {
let AltOrders = [(rotl GPR64common, 8)];
let AltOrderSelect = [{ return 1; }];
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::GPR64commonRegClassID, 0, 31>";
}
// GPR register classes which exclude SP/WSP.
def GPR32 : RegisterClass<"AArch64", [i32], 32, (add GPR32common, WZR)> {
let AltOrders = [(rotl GPR32, 8)];
let AltOrderSelect = [{ return 1; }];
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::GPR32RegClassID, 0, 32>";
}
def GPR64 : RegisterClass<"AArch64", [i64], 64, (add GPR64common, XZR)> {
let AltOrders = [(rotl GPR64, 8)];
let AltOrderSelect = [{ return 1; }];
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::GPR64RegClassID, 0, 32>";
}

// GPR register classes which include SP/WSP.
def GPR32sp : RegisterClass<"AArch64", [i32], 32, (add GPR32common, WSP)> {
let AltOrders = [(rotl GPR32sp, 8)];
let AltOrderSelect = [{ return 1; }];
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::GPR32spRegClassID, 0, 32>";
}
def GPR64sp : RegisterClass<"AArch64", [i64], 64, (add GPR64common, SP)> {
let AltOrders = [(rotl GPR64sp, 8)];
let AltOrderSelect = [{ return 1; }];
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::GPR64spRegClassID,0, 32>";
}

def GPR32sponly : RegisterClass<"AArch64", [i32], 32, (add WSP)>;
Expand Down Expand Up @@ -446,18 +451,24 @@ def Q31 : AArch64Reg<31, "q31", [D31], ["v31", ""]>, DwarfRegAlias<B31>;

def FPR8 : RegisterClass<"AArch64", [i8], 8, (sequence "B%u", 0, 31)> {
let Size = 8;
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::FPR8RegClassID, 0, 32>";
}
def FPR16 : RegisterClass<"AArch64", [f16, bf16, i16], 16, (sequence "H%u", 0, 31)> {
let Size = 16;
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::FPR16RegClassID, 0, 32>";
}

def FPR16_lo : RegisterClass<"AArch64", [f16], 16, (trunc FPR16, 16)> {
let Size = 16;
}
def FPR32 : RegisterClass<"AArch64", [f32, i32], 32,(sequence "S%u", 0, 31)>;
def FPR32 : RegisterClass<"AArch64", [f32, i32], 32,(sequence "S%u", 0, 31)> {
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::FPR32RegClassID, 0, 32>";
}
def FPR64 : RegisterClass<"AArch64", [f64, i64, v2f32, v1f64, v8i8, v4i16, v2i32,
v1i64, v4f16, v4bf16],
64, (sequence "D%u", 0, 31)>;
64, (sequence "D%u", 0, 31)> {
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::FPR64RegClassID, 0, 32>";
}
def FPR64_lo : RegisterClass<"AArch64",
[v8i8, v4i16, v2i32, v1i64, v4f16, v4bf16, v2f32,
v1f64],
Expand All @@ -469,21 +480,27 @@ def FPR64_lo : RegisterClass<"AArch64",
def FPR128 : RegisterClass<"AArch64",
[v16i8, v8i16, v4i32, v2i64, v4f32, v2f64, f128,
v8f16, v8bf16],
128, (sequence "Q%u", 0, 31)>;
128, (sequence "Q%u", 0, 31)> {
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::FPR128RegClassID, 0, 32>";
}

// The lower 16 vector registers. Some instructions can only take registers
// in this range.
def FPR128_lo : RegisterClass<"AArch64",
[v16i8, v8i16, v4i32, v2i64, v4f32, v2f64, v8f16,
v8bf16],
128, (trunc FPR128, 16)>;
128, (trunc FPR128, 16)> {
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::FPR128RegClassID, 0, 16>";
}

// The lower 8 vector registers. Some instructions can only take registers
// in this range.
def FPR128_0to7 : RegisterClass<"AArch64",
[v16i8, v8i16, v4i32, v2i64, v4f32, v2f64, v8f16,
v8bf16],
128, (trunc FPR128, 8)>;
128, (trunc FPR128, 8)> {
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::FPR128RegClassID, 0, 8>";
}

// Pairs, triples, and quads of 64-bit vector registers.
def DSeqPairs : RegisterTuples<[dsub0, dsub1], [(rotl FPR64, 0), (rotl FPR64, 1)]>;
Expand All @@ -495,12 +512,15 @@ def DSeqQuads : RegisterTuples<[dsub0, dsub1, dsub2, dsub3],
(rotl FPR64, 2), (rotl FPR64, 3)]>;
def DD : RegisterClass<"AArch64", [untyped], 64, (add DSeqPairs)> {
let Size = 128;
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::DDRegClassID, 0, 32>";
}
def DDD : RegisterClass<"AArch64", [untyped], 64, (add DSeqTriples)> {
let Size = 192;
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::DDDRegClassID, 0, 32>";
}
def DDDD : RegisterClass<"AArch64", [untyped], 64, (add DSeqQuads)> {
let Size = 256;
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::DDDDRegClassID, 0, 32>";
}

// Pairs, triples, and quads of 128-bit vector registers.
Expand All @@ -513,12 +533,15 @@ def QSeqQuads : RegisterTuples<[qsub0, qsub1, qsub2, qsub3],
(rotl FPR128, 2), (rotl FPR128, 3)]>;
def QQ : RegisterClass<"AArch64", [untyped], 128, (add QSeqPairs)> {
let Size = 256;
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::QQRegClassID, 0, 32>";
}
def QQQ : RegisterClass<"AArch64", [untyped], 128, (add QSeqTriples)> {
let Size = 384;
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::QQQRegClassID, 0, 32>";
}
def QQQQ : RegisterClass<"AArch64", [untyped], 128, (add QSeqQuads)> {
let Size = 512;
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::QQQQRegClassID, 0, 32>";
}


Expand Down Expand Up @@ -904,9 +927,15 @@ class PPRClass<int firstreg, int lastreg> : RegisterClass<
let Size = 16;
}

def PPR : PPRClass<0, 15>;
def PPR_3b : PPRClass<0, 7>; // Restricted 3 bit SVE predicate register class.
def PPR_p8to15 : PPRClass<8, 15>;
def PPR : PPRClass<0, 15> {
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::PPRRegClassID, 0, 16>";
}
def PPR_3b : PPRClass<0, 7> { // Restricted 3 bit SVE predicate register class.
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::PPRRegClassID, 0, 8>";
}
def PPR_p8to15 : PPRClass<8, 15> {
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::PNRRegClassID, 8, 8>";
}

class PPRAsmOperand <string name, string RegClass, int Width>: AsmOperandClass {
let Name = "SVE" # name # "Reg";
Expand Down Expand Up @@ -941,7 +970,9 @@ class PNRClass<int firstreg, int lastreg> : RegisterClass<
let Size = 16;
}

def PNR : PNRClass<0, 15>;
def PNR : PNRClass<0, 15> {
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::PNRRegClassID, 0, 16>";
}
def PNR_3b : PNRClass<0, 7>;
def PNR_p8to15 : PNRClass<8, 15>;

Expand Down Expand Up @@ -982,7 +1013,7 @@ class PNRP8to15RegOp<string Suffix, AsmOperandClass C, int Width, RegisterClass
: SVERegOp<Suffix, C, ElementSizeNone, RC> {
let PrintMethod = "printPredicateAsCounter<" # Width # ">";
let EncoderMethod = "EncodePNR_p8to15";
let DecoderMethod = "DecodePNR_p8to15RegisterClass";
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::PNRRegClassID, 8, 8>";
}

def PNRAny_p8to15 : PNRP8to15RegOp<"", PNRAsmAny_p8to15, 0, PNR_p8to15>;
Expand Down Expand Up @@ -1013,7 +1044,9 @@ class PPRorPNRAsmOperand<string name, string RegClass, int Width>: AsmOperandCla
let ParserMethod = "tryParseSVEPredicateOrPredicateAsCounterVector";
}

def PPRorPNR : PPRorPNRClass;
def PPRorPNR : PPRorPNRClass {
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::PPRorPNRRegClassID, 0, 16>";
}
def PPRorPNRAsmOp8 : PPRorPNRAsmOperand<"PPRorPNRB", "PPRorPNR", 8>;
def PPRorPNRAsmOpAny : PPRorPNRAsmOperand<"PPRorPNRAny", "PPRorPNR", 0>;
def PPRorPNRAny : PPRRegOp<"", PPRorPNRAsmOpAny, ElementSizeNone, PPRorPNR>;
Expand All @@ -1024,6 +1057,7 @@ def PSeqPairs : RegisterTuples<[psub0, psub1], [(rotl PPR, 0), (rotl PPR, 1)]>;

def PPR2 : RegisterClass<"AArch64", [untyped], 16, (add PSeqPairs)> {
let Size = 32;
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::PPR2RegClassID, 0, 16>";
}

class PPRVectorList<int ElementWidth, int NumRegs> : AsmOperandClass {
Expand Down Expand Up @@ -1097,9 +1131,15 @@ class ZPRClass<int lastreg> : RegisterClass<"AArch64",
let Size = 128;
}

def ZPR : ZPRClass<31>;
def ZPR_4b : ZPRClass<15>; // Restricted 4 bit SVE vector register class.
def ZPR_3b : ZPRClass<7>; // Restricted 3 bit SVE vector register class.
def ZPR : ZPRClass<31> {
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::ZPRRegClassID, 0, 32>";
}
def ZPR_4b : ZPRClass<15> { // Restricted 4 bit SVE vector register class.
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::ZPRRegClassID, 0, 16>";
}
def ZPR_3b : ZPRClass<7> { // Restricted 3 bit SVE vector register class.
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::ZPRRegClassID, 0, 8>";
}

class ZPRAsmOperand<string name, int Width, string RegClassSuffix = "">
: AsmOperandClass {
Expand Down Expand Up @@ -1176,12 +1216,15 @@ def ZSeqQuads : RegisterTuples<[zsub0, zsub1, zsub2, zsub3], [(rotl ZPR, 0), (

def ZPR2 : RegisterClass<"AArch64", [untyped], 128, (add ZSeqPairs)> {
let Size = 256;
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::ZPR2RegClassID, 0, 32>";
}
def ZPR3 : RegisterClass<"AArch64", [untyped], 128, (add ZSeqTriples)> {
let Size = 384;
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::ZPR3RegClassID, 0, 32>";
}
def ZPR4 : RegisterClass<"AArch64", [untyped], 128, (add ZSeqQuads)> {
let Size = 512;
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::ZPR4RegClassID, 0, 32>";
}

class ZPRVectorList<int ElementWidth, int NumRegs> : AsmOperandClass {
Expand Down Expand Up @@ -1379,10 +1422,12 @@ def ZStridedQuadsHi : RegisterTuples<[zsub0, zsub1, zsub2, zsub3], [
def ZPR2Strided : RegisterClass<"AArch64", [untyped], 128,
(add ZStridedPairsLo, ZStridedPairsHi)> {
let Size = 256;
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::ZPR2StridedRegClassID, 0, 16>";
}
def ZPR4Strided : RegisterClass<"AArch64", [untyped], 128,
(add ZStridedQuadsLo, ZStridedQuadsHi)> {
let Size = 512;
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::ZPR4StridedRegClassID, 0, 8>";
}

def ZPR2StridedOrContiguous : RegisterClass<"AArch64", [untyped], 128,
Expand All @@ -1401,7 +1446,7 @@ class ZPRVectorListStrided<int ElementWidth, int NumRegs, int Stride>
}

let EncoderMethod = "EncodeZPR2StridedRegisterClass",
DecoderMethod = "DecodeZPR2StridedRegisterClass" in {
DecoderMethod = "DecodeSimpleRegisterClass<AArch64::ZPR2StridedRegClassID, 0, 16>" in {
def ZZ_b_strided
: RegisterOperand<ZPR2Strided, "printTypedVectorList<0, 'b'>"> {
let ParserMatchClass = ZPRVectorListStrided<8, 2, 8>;
Expand Down Expand Up @@ -1439,7 +1484,7 @@ def ZPR4StridedOrContiguous : RegisterClass<"AArch64", [untyped], 128,
}

let EncoderMethod = "EncodeZPR4StridedRegisterClass",
DecoderMethod = "DecodeZPR4StridedRegisterClass" in {
DecoderMethod = "DecodeSimpleRegisterClass<AArch64::ZPR4StridedRegClassID, 0, 16>" in {
def ZZZZ_b_strided
: RegisterOperand<ZPR4Strided, "printTypedVectorList<0,'b'>"> {
let ParserMatchClass = ZPRVectorListStrided<8, 4, 4>;
Expand Down Expand Up @@ -1774,9 +1819,11 @@ def MatrixTileList : MatrixTileListOperand<>;

def MatrixIndexGPR32_8_11 : RegisterClass<"AArch64", [i32], 32, (sequence "W%u", 8, 11)> {
let DiagnosticType = "InvalidMatrixIndexGPR32_8_11";
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::MatrixIndexGPR32_8_11RegClassID, 0, 4>";
}
def MatrixIndexGPR32_12_15 : RegisterClass<"AArch64", [i32], 32, (sequence "W%u", 12, 15)> {
let DiagnosticType = "InvalidMatrixIndexGPR32_12_15";
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::MatrixIndexGPR32_12_15RegClassID, 0, 4>";
}
def MatrixIndexGPR32Op8_11 : RegisterOperand<MatrixIndexGPR32_8_11> {
let EncoderMethod = "encodeMatrixIndexGPR32<AArch64::W8>";
Expand Down
Loading
Loading