Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISC-V] Fix gc-related bugs in risc-v emitter #98226

Merged
150 changes: 86 additions & 64 deletions src/coreclr/jit/emitriscv64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2123,7 +2123,7 @@ unsigned emitter::emitOutput_Instr(BYTE* dst, code_t code) const
return sizeof(code_t);
}

static inline void assertCodeLength(unsigned code, uint8_t size)
static inline void assertCodeLength(size_t code, uint8_t size)
{
assert((code >> size) == 0);
}
Expand Down Expand Up @@ -2298,7 +2298,9 @@ static inline void assertCodeLength(unsigned code, uint8_t size)

static constexpr unsigned kInstructionOpcodeMask = 0x7f;
static constexpr unsigned kInstructionFunct3Mask = 0x7000;
static constexpr unsigned kInstructionFunct5Mask = 0xf8000000;
static constexpr unsigned kInstructionFunct7Mask = 0xfe000000;
static constexpr unsigned kInstructionFunct2Mask = 0x06000000;

#ifdef DEBUG

Expand Down Expand Up @@ -2338,34 +2340,44 @@ static constexpr unsigned kInstructionFunct7Mask = 0xfe000000;
assert(isGeneralRegisterOrR0(rs1));
assert(isGeneralRegisterOrR0(rs2));
break;
case INS_fadd_s:
case INS_fsub_s:
case INS_fmul_s:
case INS_fdiv_s:
case INS_fsgnj_s:
case INS_fsgnjn_s:
case INS_fsgnjx_s:
case INS_fmin_s:
case INS_fmax_s:
case INS_feq_s:
case INS_flt_s:
case INS_fle_s:
case INS_fadd_d:
case INS_fsub_d:
case INS_fmul_d:
case INS_fdiv_d:
case INS_fsgnj_d:
case INS_fsgnjn_d:
case INS_fsgnjx_d:
case INS_fmin_d:
case INS_fmax_d:
assert(isFloatReg(rd));
assert(isFloatReg(rs1));
assert(isFloatReg(rs2));
break;
case INS_feq_s:
case INS_feq_d:
case INS_flt_d:
case INS_flt_s:
case INS_fle_s:
case INS_fle_d:
assert(isFloatReg(rd));
assert(isGeneralRegisterOrR0(rd));
assert(isFloatReg(rs1));
assert(isFloatReg(rs2));
break;
case INS_fmv_w_x:
case INS_fmv_d_x:
assert(isFloatReg(rd));
assert(isGeneralRegisterOrR0(rs1));
assert(rs2 == 0);
break;
case INS_fmv_x_d:
case INS_fmv_x_w:
case INS_fclass_s:
case INS_fclass_d:
assert(isGeneralRegisterOrR0(rd));
assert(isFloatReg(rs1));
assert(rs2 == 0);
break;
default:
NO_WAY("Illegal ins within emitOutput_RTypeInstr!");
break;
Expand All @@ -2377,6 +2389,7 @@ static constexpr unsigned kInstructionFunct7Mask = 0xfe000000;
{
switch (ins)
{
case INS_mov:
case INS_jalr:
case INS_lb:
case INS_lh:
Expand All @@ -2392,7 +2405,6 @@ static constexpr unsigned kInstructionFunct7Mask = 0xfe000000;
case INS_lwu:
case INS_ld:
case INS_addiw:
case INS_fence_i:
case INS_csrrw:
case INS_csrrs:
case INS_csrrc:
Expand Down Expand Up @@ -2427,6 +2439,15 @@ static constexpr unsigned kInstructionFunct7Mask = 0xfe000000;
assert(rs1 < 32);
assert((opcode & kInstructionFunct7Mask) == 0);
break;
case INS_fence:
{
assert(rd == REG_ZERO);
assert(rs1 == REG_ZERO);
ssize_t format = immediate >> 8;
assert((format == 0) || (format == 0x8));
assert((opcode & kInstructionFunct7Mask) == 0);
}
break;
default:
NO_WAY("Illegal ins within emitOutput_ITypeInstr!");
break;
Expand Down Expand Up @@ -2867,7 +2888,7 @@ BYTE* emitter::emitOutputInstr_OptsI8(BYTE* dst, const instrDesc* id, ssize_t im
if (id->idReg2())
{
// special for INT64_MAX or UINT32_MAX
dst += emitOutput_ITypeInstr(dst, INS_addi, reg1, REG_R0, 0xfff);
dst += emitOutput_ITypeInstr(dst, INS_addi, reg1, REG_R0, NBitMask(12));
const ssize_t shiftValue = (immediate == INT64_MAX) ? 1 : 32;
dst += emitOutput_ITypeInstr(dst, INS_srli, reg1, reg1, shiftValue);
}
Expand All @@ -2881,10 +2902,10 @@ BYTE* emitter::emitOutputInstr_OptsI8(BYTE* dst, const instrDesc* id, ssize_t im

BYTE* emitter::emitOutputInstr_OptsI32(BYTE* dst, ssize_t immediate, regNumber reg1)
{
ssize_t upperWord = UpperWordOfDoubleWord(immediate);
const ssize_t upperWord = UpperWordOfDoubleWord(immediate);
dst += emitOutput_UTypeInstr(dst, INS_lui, reg1, UpperNBitsOfWordSignExtend<20>(upperWord));
dst += emitOutput_ITypeInstr(dst, INS_addi, reg1, reg1, LowerNBitsOfWord<12>(upperWord));
ssize_t lowerWord = LowerWordOfDoubleWord(immediate);
const ssize_t lowerWord = LowerWordOfDoubleWord(immediate);
dst += emitOutput_ITypeInstr(dst, INS_slli, reg1, reg1, 11);
dst += emitOutput_ITypeInstr(dst, INS_addi, reg1, reg1, LowerNBitsOfWord<11>(lowerWord >> 21));
dst += emitOutput_ITypeInstr(dst, INS_slli, reg1, reg1, 11);
Expand All @@ -2899,39 +2920,37 @@ BYTE* emitter::emitOutputInstr_OptsRc(BYTE* dst, const instrDesc* id, instructio
assert(id->idAddr()->iiaIsJitDataOffset());
assert(id->idGCref() == GCT_NONE);

int dataOffs = id->idAddr()->iiaGetJitDataOffset();
const int dataOffs = id->idAddr()->iiaGetJitDataOffset();
assert(dataOffs >= 0);

ssize_t immediate = emitGetInsSC(id);
const ssize_t immediate = emitGetInsSC(id);
assert((immediate >= 0) && (immediate < 0x4000)); // 0x4000 is arbitrary, currently 'imm' is always 0.

unsigned offset = static_cast<unsigned>(dataOffs + immediate);
const unsigned offset = static_cast<unsigned>(dataOffs + immediate);
assert(offset < emitDataSize());

*ins = id->idIns();
regNumber reg1 = id->idReg1();
*ins = id->idIns();
const regNumber reg1 = id->idReg1();

if (id->idIsReloc())
{
return emitOutputInstr_OptsRcReloc(dst, ins, reg1);
return emitOutputInstr_OptsRcReloc(dst, ins, offset, reg1);
}
return emitOutputInstr_OptsRcNoReloc(dst, ins, offset, reg1);
}

BYTE* emitter::emitOutputInstr_OptsRcReloc(BYTE* dst, instruction* ins, regNumber reg1)
BYTE* emitter::emitOutputInstr_OptsRcReloc(BYTE* dst, instruction* ins, unsigned offset, regNumber reg1)
{
ssize_t immediate = emitConsBlock - dst;
assert(immediate > 0);
assert((immediate & 0x03) == 0);
const ssize_t immediate = (emitConsBlock - dst) + offset;
assert((immediate > 0) && ((immediate & 0x03) == 0));

regNumber rsvdReg = codeGen->rsGetRsvdReg();
const regNumber rsvdReg = codeGen->rsGetRsvdReg();
dst += emitOutput_UTypeInstr(dst, INS_auipc, rsvdReg, UpperNBitsOfWordSignExtend<20>(immediate));

instruction lastIns = *ins;

if (*ins == INS_jal)
{
assert(isGeneralRegister(reg1));
*ins = lastIns = INS_addi;
}
dst += emitOutput_ITypeInstr(dst, lastIns, reg1, rsvdReg, LowerNBitsOfWord<12>(immediate));
Expand All @@ -2940,12 +2959,12 @@ BYTE* emitter::emitOutputInstr_OptsRcReloc(BYTE* dst, instruction* ins, regNumbe

BYTE* emitter::emitOutputInstr_OptsRcNoReloc(BYTE* dst, instruction* ins, unsigned offset, regNumber reg1)
{
ssize_t immediate = reinterpret_cast<ssize_t>(emitConsBlock) + offset;
assert((immediate >> 40) == 0);
regNumber rsvdReg = codeGen->rsGetRsvdReg();
const ssize_t immediate = reinterpret_cast<ssize_t>(emitConsBlock) + offset;
assertCodeLength(static_cast<size_t>(immediate), 40);
const regNumber rsvdReg = codeGen->rsGetRsvdReg();

instruction lastIns = (*ins == INS_jal) ? (*ins = INS_addi) : *ins;
UINT32 high = immediate >> 11;
const instruction lastIns = (*ins == INS_jal) ? (*ins = INS_addi) : *ins;
const UINT32 high = immediate >> 11;

dst += emitOutput_UTypeInstr(dst, INS_lui, rsvdReg, UpperNBitsOfWordSignExtend<20>(high));
dst += emitOutput_ITypeInstr(dst, INS_addi, rsvdReg, rsvdReg, LowerNBitsOfWord<12>(high));
Expand All @@ -2959,9 +2978,8 @@ BYTE* emitter::emitOutputInstr_OptsRl(BYTE* dst, instrDesc* id, instruction* ins
insGroup* targetInsGroup = static_cast<insGroup*>(emitCodeGetCookie(id->idAddr()->iiaBBlabel));
id->idAddr()->iiaIGlabel = targetInsGroup;

regNumber reg1 = id->idReg1();
assert(isGeneralRegister(reg1));
ssize_t igOffs = targetInsGroup->igOffs;
const regNumber reg1 = id->idReg1();
const ssize_t igOffs = targetInsGroup->igOffs;

if (id->idIsReloc())
{
Expand All @@ -2974,7 +2992,7 @@ BYTE* emitter::emitOutputInstr_OptsRl(BYTE* dst, instrDesc* id, instruction* ins

BYTE* emitter::emitOutputInstr_OptsRlReloc(BYTE* dst, ssize_t igOffs, regNumber reg1)
{
ssize_t immediate = (emitCodeBlock - dst) + igOffs;
const ssize_t immediate = (emitCodeBlock - dst) + igOffs;
assert((immediate & 0x03) == 0);

dst += emitOutput_UTypeInstr(dst, INS_auipc, reg1, UpperNBitsOfWordSignExtend<20>(immediate));
Expand All @@ -2984,11 +3002,11 @@ BYTE* emitter::emitOutputInstr_OptsRlReloc(BYTE* dst, ssize_t igOffs, regNumber

BYTE* emitter::emitOutputInstr_OptsRlNoReloc(BYTE* dst, ssize_t igOffs, regNumber reg1)
{
ssize_t immediate = reinterpret_cast<ssize_t>(emitCodeBlock) + igOffs;
assert((immediate >> (32 + 20)) == 0);
const ssize_t immediate = reinterpret_cast<ssize_t>(emitCodeBlock) + igOffs;
assertCodeLength(static_cast<size_t>(immediate), 32 + 20);

regNumber rsvdReg = codeGen->rsGetRsvdReg();
ssize_t upperSignExt = UpperWordOfDoubleWordDoubleSignExtend<32, 52>(immediate);
const regNumber rsvdReg = codeGen->rsGetRsvdReg();
const ssize_t upperSignExt = UpperWordOfDoubleWordDoubleSignExtend<32, 52>(immediate);

dst += emitOutput_UTypeInstr(dst, INS_lui, rsvdReg, UpperNBitsOfWordSignExtend<20>(immediate));
dst += emitOutput_ITypeInstr(dst, INS_addi, rsvdReg, rsvdReg, LowerNBitsOfWord<12>(immediate));
Expand All @@ -3000,32 +3018,32 @@ BYTE* emitter::emitOutputInstr_OptsRlNoReloc(BYTE* dst, ssize_t igOffs, regNumbe

BYTE* emitter::emitOutputInstr_OptsJalr(BYTE* dst, instrDescJmp* jmp, const insGroup* ig, instruction* ins)
{
ssize_t immediate = emitOutputInstrJumpDistance(dst, ig, jmp) - 4;
const ssize_t immediate = emitOutputInstrJumpDistance(dst, ig, jmp) - 4;
assert((immediate & 0x03) == 0);

*ins = jmp->idIns();
assert(jmp->idCodeSize() > 4); // The original INS_OPTS_JALR: not used by now!!!
switch (jmp->idCodeSize())
{
case 8:
return emitOutputInstr_OptsJalr8(dst, jmp, *ins, immediate);
return emitOutputInstr_OptsJalr8(dst, jmp, immediate);
case 24:
assert((*ins == INS_jal) || (*ins == INS_j));
assert(jmp->idInsIs(INS_jal, INS_j));
return emitOutputInstr_OptsJalr24(dst, immediate);
case 28:
return emitOutputInstr_OptsJalr28(dst, jmp, *ins, immediate);
return emitOutputInstr_OptsJalr28(dst, jmp, immediate);
default:
// case 0 - 4: The original INS_OPTS_JALR: not used by now!!!
break;
}
unreached();
return nullptr;
}

BYTE* emitter::emitOutputInstr_OptsJalr8(BYTE* dst, const instrDescJmp* jmp, instruction ins, ssize_t immediate)
BYTE* emitter::emitOutputInstr_OptsJalr8(BYTE* dst, const instrDescJmp* jmp, ssize_t immediate)
{
regNumber reg2 = ((ins != INS_beqz) && (ins != INS_bnez)) ? jmp->idReg2() : REG_R0;
const regNumber reg2 = jmp->idInsIs(INS_beqz, INS_bnez) ? REG_R0 : jmp->idReg2();

dst += emitOutput_BTypeInstr_InvertComparation(dst, ins, jmp->idReg1(), reg2, 0x8);
dst += emitOutput_BTypeInstr_InvertComparation(dst, jmp->idIns(), jmp->idReg1(), reg2, 0x8);
dst += emitOutput_JTypeInstr(dst, INS_jal, REG_ZERO, TrimSignedToImm21(immediate));
return dst;
}
Expand All @@ -3034,14 +3052,14 @@ BYTE* emitter::emitOutputInstr_OptsJalr24(BYTE* dst, ssize_t immediate)
{
// Make target address with offset, then jump (JALR) with the target address
immediate -= 2 * 4;
ssize_t high = UpperWordOfDoubleWordSingleSignExtend<0>(immediate);
const ssize_t high = UpperWordOfDoubleWordSingleSignExtend<0>(immediate);

dst += emitOutput_UTypeInstr(dst, INS_lui, REG_RA, UpperNBitsOfWordSignExtend<20>(high));
dst += emitOutput_ITypeInstr(dst, INS_addi, REG_RA, REG_RA, LowerNBitsOfWord<12>(high));
dst += emitOutput_ITypeInstr(dst, INS_slli, REG_RA, REG_RA, 32);

regNumber rsvdReg = codeGen->rsGetRsvdReg();
ssize_t low = LowerWordOfDoubleWord(immediate);
const regNumber rsvdReg = codeGen->rsGetRsvdReg();
const ssize_t low = LowerWordOfDoubleWord(immediate);

dst += emitOutput_UTypeInstr(dst, INS_auipc, rsvdReg, UpperNBitsOfWordSignExtend<20>(low));
dst += emitOutput_RTypeInstr(dst, INS_add, rsvdReg, REG_RA, rsvdReg);
Expand All @@ -3050,17 +3068,18 @@ BYTE* emitter::emitOutputInstr_OptsJalr24(BYTE* dst, ssize_t immediate)
return dst;
}

BYTE* emitter::emitOutputInstr_OptsJalr28(BYTE* dst, const instrDescJmp* jmp, instruction ins, ssize_t immediate)
BYTE* emitter::emitOutputInstr_OptsJalr28(BYTE* dst, const instrDescJmp* jmp, ssize_t immediate)
{
regNumber reg2 = ((ins != INS_beqz) && (ins != INS_bnez)) ? jmp->idReg2() : REG_R0;
dst += emitOutput_BTypeInstr_InvertComparation(dst, ins, jmp->idReg1(), reg2, 0x1c);
regNumber reg2 = jmp->idInsIs(INS_beqz, INS_bnez) ? REG_R0 : jmp->idReg2();

dst += emitOutput_BTypeInstr_InvertComparation(dst, jmp->idIns(), jmp->idReg1(), reg2, 0x1c);

return emitOutputInstr_OptsJalr24(dst, immediate);
}

BYTE* emitter::emitOutputInstr_OptsJCond(BYTE* dst, instrDesc* id, const insGroup* ig, instruction* ins)
{
ssize_t immediate = emitOutputInstrJumpDistance(dst, ig, static_cast<instrDescJmp*>(id));
const ssize_t immediate = emitOutputInstrJumpDistance(dst, ig, static_cast<instrDescJmp*>(id));

*ins = id->idIns();

Expand All @@ -3070,7 +3089,7 @@ BYTE* emitter::emitOutputInstr_OptsJCond(BYTE* dst, instrDesc* id, const insGrou

BYTE* emitter::emitOutputInstr_OptsJ(BYTE* dst, instrDesc* id, const insGroup* ig, instruction* ins)
{
ssize_t immediate = emitOutputInstrJumpDistance(dst, ig, static_cast<instrDescJmp*>(id));
const ssize_t immediate = emitOutputInstrJumpDistance(dst, ig, static_cast<instrDescJmp*>(id));
assert((immediate & 0x03) == 0);

*ins = id->idIns();
Expand Down Expand Up @@ -3133,11 +3152,13 @@ BYTE* emitter::emitOutputInstr_OptsC(BYTE* dst, instrDesc* id, const insGroup* i
size_t emitter::emitOutputInstr(insGroup* ig, instrDesc* id, BYTE** dp)
{
BYTE* dst = *dp;
BYTE* dst2 = dst + 4;
const BYTE* const odst = *dp;
instruction ins;
size_t sz = 0;

assert(REG_NA == static_cast<int>(REG_NA));
assert(writeableOffset == 0);

insOpts insOp = id->idInsOpt();

Expand Down Expand Up @@ -3174,8 +3195,9 @@ size_t emitter::emitOutputInstr(insGroup* ig, instrDesc* id, BYTE** dp)
sz = sizeof(instrDescJmp);
break;
case INS_OPTS_C:
dst = emitOutputInstr_OptsC(dst, id, ig, &sz);
ins = INS_nop;
dst = emitOutputInstr_OptsC(dst, id, ig, &sz);
dst2 = dst;
ins = INS_nop;
break;
default: // case INS_OPTS_NONE:
dst += emitOutput_Instr(dst, id->idAddr()->iiaGetInstrEncode());
Expand All @@ -3193,11 +3215,11 @@ size_t emitter::emitOutputInstr(insGroup* ig, instrDesc* id, BYTE** dp)
// We assume that "idReg1" is the primary destination register for all instructions
if (id->idGCref() != GCT_NONE)
{
emitGCregLiveUpd(id->idGCref(), id->idReg1(), dst);
emitGCregLiveUpd(id->idGCref(), id->idReg1(), dst2);
}
else
{
emitGCregDeadUpd(id->idReg1(), dst);
emitGCregDeadUpd(id->idReg1(), dst2);
}
}

Expand All @@ -3211,7 +3233,7 @@ size_t emitter::emitOutputInstr(insGroup* ig, instrDesc* id, BYTE** dp)
int adr = emitComp->lvaFrameAddress(varNum, &FPbased);
if (id->idGCref() != GCT_NONE)
{
emitGCvarLiveUpd(adr + ofs, varNum, id->idGCref(), dst DEBUG_ARG(varNum));
emitGCvarLiveUpd(adr + ofs, varNum, id->idGCref(), dst2 DEBUG_ARG(varNum));
}
else
{
Expand All @@ -3228,7 +3250,7 @@ size_t emitter::emitOutputInstr(insGroup* ig, instrDesc* id, BYTE** dp)
vt = tmpDsc->tdTempType();
}
if (vt == TYP_REF || vt == TYP_BYREF)
emitGCvarDeadUpd(adr + ofs, dst DEBUG_ARG(varNum));
emitGCvarDeadUpd(adr + ofs, dst2 DEBUG_ARG(varNum));
}
// if (emitInsWritesToLclVarStackLocPair(id))
//{
Expand Down
Loading
Loading