
[RISC-V] Adjust trampoline code for branch control flow protection #141949


Open: jaidTw wants to merge 7 commits into main from trampoline_lpad

@jaidTw
Contributor

jaidTw commented May 29, 2025

It is tricky to observe the trampoline code in the lit test file because the instructions are encoded as 32-bit words and stored onto the stack. I don't have a better way to check it for now. The stack in the test is laid out as follows (byte offset from sp, stored word or source register, meaning):

   56 $ra
   48 $a0      f
   40 $a1      p
   36 00028067 jalr  t0
   32 000003b7 lui   t2, 0
   28 014e3e03 ld    t3, 20(t3)
   24 01ce3283 ld    t0, 28(t3)
   20 00000e17 auipc t3, 0
sp+16 00000017 lpad  0
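
As a cross-check of the words in the layout above, here is a minimal sketch that re-derives them from the standard RISC-V U-type and I-type formats. It is independent of LLVM: `encodeU`, `encodeI` and the small per-instruction helpers are hypothetical names for illustration, not the `GetEncoding`/`MCInstBuilder` path the patch actually uses.

```cpp
#include <cassert>
#include <cstdint>

// Register numbers for the ABI names that appear in the layout.
constexpr uint32_t Zero = 0; // x0
constexpr uint32_t T0 = 5;   // x5
constexpr uint32_t T2 = 7;   // x7
constexpr uint32_t T3 = 28;  // x28

// U-type: imm[31:12] | rd | opcode (LUI = 0x37, AUIPC = 0x17).
constexpr uint32_t encodeU(uint32_t Opcode, uint32_t Rd, uint32_t Imm20) {
  assert(Rd < 32 && Imm20 < (1u << 20));
  return (Imm20 << 12) | (Rd << 7) | Opcode;
}

// I-type: imm[11:0] | rs1 | funct3 | rd | opcode.
// Only non-negative offsets are handled, which is all the trampoline needs.
constexpr uint32_t encodeI(uint32_t Opcode, uint32_t Funct3, uint32_t Rd,
                           uint32_t Rs1, uint32_t Imm12) {
  assert(Rd < 32 && Rs1 < 32 && Imm12 < (1u << 11));
  return (Imm12 << 20) | (Rs1 << 15) | (Funct3 << 12) | (Rd << 7) | Opcode;
}

// lpad <label> shares its encoding with auipc x0, <label>.
constexpr uint32_t lpad(uint32_t Label) { return encodeU(0x17, Zero, Label); }
constexpr uint32_t auipc(uint32_t Rd, uint32_t Imm20) {
  return encodeU(0x17, Rd, Imm20);
}
constexpr uint32_t lui(uint32_t Rd, uint32_t Imm20) {
  return encodeU(0x37, Rd, Imm20);
}
constexpr uint32_t ld(uint32_t Rd, uint32_t Rs1, uint32_t Off) {
  return encodeI(0x03, 3, Rd, Rs1, Off);
}
constexpr uint32_t jalr(uint32_t Rd, uint32_t Rs1, uint32_t Off) {
  return encodeI(0x67, 0, Rd, Rs1, Off);
}

// One check per trampoline slot, lowest address first.
static_assert(lpad(0) == 0x00000017, "sp+16: lpad 0");
static_assert(auipc(T3, 0) == 0x00000e17, "sp+20: auipc t3, 0");
static_assert(ld(T0, T3, 28) == 0x01ce3283, "sp+24: ld t0, 28(t3)");
static_assert(ld(T3, T3, 20) == 0x014e3e03, "sp+28: ld t3, 20(t3)");
static_assert(lui(T2, 0) == 0x000003b7, "sp+32: lui t2, 0");
static_assert(jalr(Zero, T0, 0) == 0x00028067, "sp+36: jalr t0");
```

The `static_assert`s make the compiler do the comparison, so a mismatch between a mnemonic and its hex word shows up as a build error.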

@llvmbot
Member

llvmbot commented May 29, 2025

@llvm/pr-subscribers-backend-risc-v

Author: Jesse Huang (jaidTw)

Changes

It is tricky to observe the trampoline code from the lit test file, because the instructions are encoded as 32-bit words and stored onto the stack.
The stack in the test is laid out as follows (byte offset from sp, stored word or source register, meaning):

   56 $ra
   48 $a0      f
   40 $a1      p
   36 00028067 jalr  t0
   32 000003b7 lui   t2, 0
   28 014e3e03 ld    t3, 20(t3)
   24 01ce3283 ld    t0, 28(t3)
   20 00000e17 auipc t3, 0
sp+16 00000017 lpad  0
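
To connect this layout to the CHECK lines in the test, the sketch below recombines each `lui`/`addiw` (or single `li`) pair from the expected output into the 32-bit word that the following `sw` stores. `materialize` is a hypothetical helper name; the sketch assumes the usual `lui`/`addiw` semantics, where the sign-extended 12-bit immediate is added to the upper 20 bits and the low 32 bits of the result are what `sw` writes.

```cpp
#include <cassert>
#include <cstdint>

// Low 32 bits produced by `lui rd, Hi20; addiw rd, rd, Lo12`.
// Unsigned arithmetic wraps, which is exactly what a negative Lo12 needs.
constexpr uint32_t materialize(uint32_t Hi20, int32_t Lo12) {
  assert(Hi20 < (1u << 20) && Lo12 >= -2048 && Lo12 <= 2047);
  return (Hi20 << 12) + static_cast<uint32_t>(Lo12);
}

// Values taken from the RV64 CHECK lines, keyed by the sw offset that follows.
static_assert(materialize(0, 23) == 0x00000017, "sw 16(sp): lpad 0");
static_assert(materialize(1, -489) == 0x00000e17, "sw 20(sp): auipc t3, 0");
static_assert(materialize(7395, 643) == 0x01ce3283, "sw 24(sp): ld t0, 28(t3)");
static_assert(materialize(5348, -509) == 0x014e3e03, "sw 28(sp): ld t3, 20(t3)");
static_assert(materialize(0, 951) == 0x000003b7, "sw 32(sp): lui t2, 0");
static_assert(materialize(40, 103) == 0x00028067, "sw 36(sp): jalr t0");
```

Reading the test this way gives one stored word per instruction slot, which is easier to compare against the layout than the raw decimal immediates.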


---
Full diff: https://github.com/llvm/llvm-project/pull/141949.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+89-28) 
- (added) llvm/test/CodeGen/RISCV/rv64-trampoline-cfi.ll (+95) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 0a849f49116ee..2bcbe18e9beed 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -29,6 +29,7 @@
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineJumpTableInfo.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/SDPatternMatch.h"
 #include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
@@ -8295,9 +8296,23 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
   //     16: <StaticChainOffset>
   //     24: <FunctionAddressOffset>
   //     32:
-
-  constexpr unsigned StaticChainOffset = 16;
-  constexpr unsigned FunctionAddressOffset = 24;
+  // Offset with branch control flow protection enabled:
+  //      0: lpad    <imm20>
+  //      4: auipc   t3, 0
+  //      8: ld      t0, 28(t3)
+  //     12: ld      t3, 20(t3)
+  //     16: lui     t2, <imm20>
+  //     20: jalr    t0
+  //     24: <StaticChainOffset>
+  //     32: <FunctionAddressOffset>
+  //     40:
+
+  const bool HasCFBranch =
+      Subtarget.hasStdExtZicfilp() &&
+      DAG.getMMI()->getModule()->getModuleFlag("cf-protection-branch");
+  const unsigned StaticChainIdx = HasCFBranch ? 6 : 4;
+  const unsigned StaticChainOffset = StaticChainIdx * 4;
+  const unsigned FunctionAddressOffset = StaticChainOffset + 8;
 
   const MCSubtargetInfo *STI = getTargetMachine().getMCSubtargetInfo();
   assert(STI);
@@ -8310,35 +8325,77 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
   };
 
   SDValue OutChains[6];
-
-  uint32_t Encodings[] = {
-      // auipc t2, 0
-      // Loads the current PC into t2.
-      GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X7).addImm(0)),
-      // ld t0, 24(t2)
-      // Loads the function address into t0. Note that we are using offsets
-      // pc-relative to the first instruction of the trampoline.
-      GetEncoding(
-          MCInstBuilder(RISCV::LD).addReg(RISCV::X5).addReg(RISCV::X7).addImm(
-              FunctionAddressOffset)),
-      // ld t2, 16(t2)
-      // Load the value of the static chain.
-      GetEncoding(
-          MCInstBuilder(RISCV::LD).addReg(RISCV::X7).addReg(RISCV::X7).addImm(
-              StaticChainOffset)),
-      // jalr t0
-      // Jump to the function.
-      GetEncoding(MCInstBuilder(RISCV::JALR)
-                      .addReg(RISCV::X0)
-                      .addReg(RISCV::X5)
-                      .addImm(0))};
+  SDValue OutChainsLPAD[8];
+  if (HasCFBranch)
+    assert(std::size(OutChainsLPAD) == StaticChainIdx + 2);
+  else
+    assert(std::size(OutChains) == StaticChainIdx + 2);
+
+  SmallVector<uint32_t> Encodings;
+  if (!HasCFBranch) {
+    Encodings.append(
+        {// auipc t2, 0
+         // Loads the current PC into t2.
+         GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X7).addImm(0)),
+         // ld t0, 24(t2)
+         // Loads the function address into t0. Note that we are using offsets
+         // pc-relative to the first instruction of the trampoline.
+         GetEncoding(MCInstBuilder(RISCV::LD)
+                         .addReg(RISCV::X5)
+                         .addReg(RISCV::X7)
+                         .addImm(FunctionAddressOffset)),
+         // ld t2, 16(t2)
+         // Load the value of the static chain.
+         GetEncoding(MCInstBuilder(RISCV::LD)
+                         .addReg(RISCV::X7)
+                         .addReg(RISCV::X7)
+                         .addImm(StaticChainOffset)),
+         // jalr t0
+         // Jump to the function.
+         GetEncoding(MCInstBuilder(RISCV::JALR)
+                         .addReg(RISCV::X0)
+                         .addReg(RISCV::X5)
+                         .addImm(0))});
+  } else {
+    Encodings.append(
+        {// auipc x0, <imm20> (lpad <imm20>)
+         // Landing pad.
+         GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X0).addImm(0)),
+         // auipc t3, 0
+         // Loads the current PC into t3.
+         GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X28).addImm(0)),
+         // ld t0, (FunctionAddressOffset - 4)(t3)
+         // Loads the function address into t0. Note that we are using offsets
+         // pc-relative to the SECOND instruction of the trampoline.
+         GetEncoding(MCInstBuilder(RISCV::LD)
+                         .addReg(RISCV::X5)
+                         .addReg(RISCV::X28)
+                         .addImm(FunctionAddressOffset - 4)),
+         // ld t3, (StaticChainOffset - 4)(t3)
+         // Load the value of the static chain.
+         GetEncoding(MCInstBuilder(RISCV::LD)
+                         .addReg(RISCV::X28)
+                         .addReg(RISCV::X28)
+                         .addImm(StaticChainOffset - 4)),
+         // lui t2, <imm20>
+         // Setup the landing pad value.
+         GetEncoding(MCInstBuilder(RISCV::LUI).addReg(RISCV::X7).addImm(0)),
+         // jalr t0
+         // Jump to the function.
+         GetEncoding(MCInstBuilder(RISCV::JALR)
+                         .addReg(RISCV::X0)
+                         .addReg(RISCV::X5)
+                         .addImm(0))});
+  }
+
+  SDValue *OutChainsUsed = HasCFBranch ? OutChainsLPAD : OutChains;
 
   // Store encoded instructions.
   for (auto [Idx, Encoding] : llvm::enumerate(Encodings)) {
     SDValue Addr = Idx > 0 ? DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
                                          DAG.getConstant(Idx * 4, dl, MVT::i64))
                            : Trmp;
-    OutChains[Idx] = DAG.getTruncStore(
+    OutChainsUsed[Idx] = DAG.getTruncStore(
         Root, dl, DAG.getConstant(Encoding, dl, MVT::i64), Addr,
         MachinePointerInfo(TrmpAddr, Idx * 4), MVT::i32);
   }
@@ -8361,12 +8418,16 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
         DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
                     DAG.getConstant(OffsetValue.Offset, dl, MVT::i64));
     OffsetValue.Addr = Addr;
-    OutChains[Idx + 4] =
+    OutChainsUsed[Idx + StaticChainIdx] =
         DAG.getStore(Root, dl, OffsetValue.Value, Addr,
                      MachinePointerInfo(TrmpAddr, OffsetValue.Offset));
   }
 
-  SDValue StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
+  SDValue StoreToken;
+  if (HasCFBranch)
+    StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChainsLPAD);
+  else
+    StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
 
   // The end of instructions of trampoline is the same as the static chain
   // address that we computed earlier.
diff --git a/llvm/test/CodeGen/RISCV/rv64-trampoline-cfi.ll b/llvm/test/CodeGen/RISCV/rv64-trampoline-cfi.ll
new file mode 100644
index 0000000000000..304018ca0db56
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rv64-trampoline-cfi.ll
@@ -0,0 +1,95 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -O0 -mtriple=riscv64 -mattr=+experimental-zicfilp -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV64 %s
+; RUN: llc -O0 -mtriple=riscv64-unknown-linux-gnu -mattr=+experimental-zicfilp -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV64-LINUX %s
+
+declare void @llvm.init.trampoline(ptr, ptr, ptr)
+declare ptr @llvm.adjust.trampoline(ptr)
+declare i64 @f(ptr nest, i64)
+
+define i64 @test0(i64 %n, ptr %p) nounwind {
+; RV64-LABEL: test0:
+; RV64:       # %bb.0:
+; RV64-NEXT:    lpad 0
+; RV64-NEXT:    addi sp, sp, -64
+; RV64-NEXT:    sd ra, 56(sp) # 8-byte Folded Spill
+; RV64-NEXT:    sd a0, 0(sp) # 8-byte Folded Spill
+; RV64-NEXT:    lui a0, %hi(f)
+; RV64-NEXT:    addi a0, a0, %lo(f)
+; RV64-NEXT:    sd a0, 48(sp)
+; RV64-NEXT:    sd a1, 40(sp)
+; RV64-NEXT:    li a0, 951
+; RV64-NEXT:    sw a0, 32(sp)
+; RV64-NEXT:    li a0, 23
+; RV64-NEXT:    sw a0, 16(sp)
+; RV64-NEXT:    lui a0, 40
+; RV64-NEXT:    addiw a0, a0, 103
+; RV64-NEXT:    sw a0, 36(sp)
+; RV64-NEXT:    lui a0, 5348
+; RV64-NEXT:    addiw a0, a0, -509
+; RV64-NEXT:    sw a0, 28(sp)
+; RV64-NEXT:    lui a0, 7395
+; RV64-NEXT:    addiw a0, a0, 643
+; RV64-NEXT:    sw a0, 24(sp)
+; RV64-NEXT:    lui a0, 1
+; RV64-NEXT:    addiw a0, a0, -489
+; RV64-NEXT:    sw a0, 20(sp)
+; RV64-NEXT:    addi a1, sp, 40
+; RV64-NEXT:    addi a0, sp, 16
+; RV64-NEXT:    sd a0, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT:    call __clear_cache
+; RV64-NEXT:    ld a0, 0(sp) # 8-byte Folded Reload
+; RV64-NEXT:    ld a1, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT:    jalr a1
+; RV64-NEXT:    ld ra, 56(sp) # 8-byte Folded Reload
+; RV64-NEXT:    addi sp, sp, 64
+; RV64-NEXT:    ret
+;
+; RV64-LINUX-LABEL: test0:
+; RV64-LINUX:       # %bb.0:
+; RV64-LINUX-NEXT:    lpad 0
+; RV64-LINUX-NEXT:    addi sp, sp, -64
+; RV64-LINUX-NEXT:    sd ra, 56(sp) # 8-byte Folded Spill
+; RV64-LINUX-NEXT:    sd a0, 0(sp) # 8-byte Folded Spill
+; RV64-LINUX-NEXT:    lui a0, %hi(f)
+; RV64-LINUX-NEXT:    addi a0, a0, %lo(f)
+; RV64-LINUX-NEXT:    sd a0, 48(sp)
+; RV64-LINUX-NEXT:    sd a1, 40(sp)
+; RV64-LINUX-NEXT:    li a0, 951
+; RV64-LINUX-NEXT:    sw a0, 32(sp)
+; RV64-LINUX-NEXT:    li a0, 23
+; RV64-LINUX-NEXT:    sw a0, 16(sp)
+; RV64-LINUX-NEXT:    lui a0, 40
+; RV64-LINUX-NEXT:    addiw a0, a0, 103
+; RV64-LINUX-NEXT:    sw a0, 36(sp)
+; RV64-LINUX-NEXT:    lui a0, 5348
+; RV64-LINUX-NEXT:    addiw a0, a0, -509
+; RV64-LINUX-NEXT:    sw a0, 28(sp)
+; RV64-LINUX-NEXT:    lui a0, 7395
+; RV64-LINUX-NEXT:    addiw a0, a0, 643
+; RV64-LINUX-NEXT:    sw a0, 24(sp)
+; RV64-LINUX-NEXT:    lui a0, 1
+; RV64-LINUX-NEXT:    addiw a0, a0, -489
+; RV64-LINUX-NEXT:    sw a0, 20(sp)
+; RV64-LINUX-NEXT:    addi a1, sp, 40
+; RV64-LINUX-NEXT:    addi a0, sp, 16
+; RV64-LINUX-NEXT:    sd a0, 8(sp) # 8-byte Folded Spill
+; RV64-LINUX-NEXT:    li a2, 0
+; RV64-LINUX-NEXT:    call __riscv_flush_icache
+; RV64-LINUX-NEXT:    ld a0, 0(sp) # 8-byte Folded Reload
+; RV64-LINUX-NEXT:    ld a1, 8(sp) # 8-byte Folded Reload
+; RV64-LINUX-NEXT:    jalr a1
+; RV64-LINUX-NEXT:    ld ra, 56(sp) # 8-byte Folded Reload
+; RV64-LINUX-NEXT:    addi sp, sp, 64
+; RV64-LINUX-NEXT:    ret
+  %alloca = alloca [40 x i8], align 8
+  call void @llvm.init.trampoline(ptr %alloca, ptr @f, ptr %p)
+  %tramp = call ptr @llvm.adjust.trampoline(ptr %alloca)
+  %ret = call i64 %tramp(i64 %n)
+  ret i64 %ret
+}
+
+!llvm.module.flags = !{!0}
+
+!0 = !{i32 8, !"cf-protection-branch", i32 1}
``````````

@topperc topperc requested a review from rofirrim May 29, 2025 15:16
@jaidTw jaidTw force-pushed the trampoline_lpad branch from e43db77 to a286cf6 Compare June 2, 2025 11:24
@jaidTw
Contributor Author

jaidTw commented Jun 2, 2025

The trampoline code now uses a software-guarded jump (through t2), so the stack layout in the test is now:

   56 $ra
   44 $a0      f
   36 $a1      p
   32 00038067 jalr  t2
   28 010e3e03 ld    t3, 16(t3)
   24 018e3383 ld    t2, 24(t3)
   20 00000e17 auipc t3, 0
sp+16 00000017 lpad  0
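
The offsets in both layouts follow from the instruction count and the position of the `auipc`. The sketch below is a hypothetical stand-alone helper (`computeLayout` and `TrampolineLayout` are not names from the patch) that mirrors the arithmetic of `StaticChainOffset` and `FunctionAddressOffset` in the diff, generalised to the position of the `auipc`.

```cpp
#include <cassert>
#include <cstdint>

struct TrampolineLayout {
  unsigned StaticChainOffset;     // bytes from the start of the trampoline
  unsigned FunctionAddressOffset; // bytes from the start of the trampoline
  unsigned StaticChainImm;        // ld immediate, relative to the auipc
  unsigned FunctionAddressImm;    // ld immediate, relative to the auipc
};

// NumInsts: number of 4-byte instructions placed before the data slots.
// AuipcIdx: index of the auipc whose PC the two loads are relative to.
constexpr TrampolineLayout computeLayout(unsigned NumInsts, unsigned AuipcIdx) {
  assert(AuipcIdx < NumInsts);
  const unsigned StaticChain = NumInsts * 4;
  const unsigned FunctionAddress = StaticChain + 8;
  const unsigned AuipcOffset = AuipcIdx * 4;
  return {StaticChain, FunctionAddress, StaticChain - AuipcOffset,
          FunctionAddress - AuipcOffset};
}

// Without CFI: auipc, ld, ld, jalr.
constexpr TrampolineLayout Plain = computeLayout(4, 0);
// First revision: lpad, auipc, ld, ld, lui, jalr.
constexpr TrampolineLayout WithLui = computeLayout(6, 1);
// Revised sequence: lpad, auipc, ld, ld, jalr.
constexpr TrampolineLayout SwGuarded = computeLayout(5, 1);

static_assert(Plain.StaticChainImm == 16 && Plain.FunctionAddressImm == 24, "");
static_assert(WithLui.StaticChainImm == 20 && WithLui.FunctionAddressImm == 28,
              "");
static_assert(SwGuarded.StaticChainImm == 16 &&
                  SwGuarded.FunctionAddressImm == 24,
              "");
// The trampoline starts at sp+16 in the test, giving slots 36 (p) and 44 (f).
static_assert(16 + SwGuarded.StaticChainOffset == 36, "");
static_assert(16 + SwGuarded.FunctionAddressOffset == 44, "");
```

Dropping the `lui` shortens the sequence by one instruction, which is why both data slots move down by 4 bytes compared with the first revision.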

@jaidTw jaidTw requested a review from jrtc27 June 2, 2025 11:25

github-actions bot commented Jun 3, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@jaidTw jaidTw force-pushed the trampoline_lpad branch 2 times, most recently from 5782af0 to c3de1c7 Compare June 3, 2025 16:15
@topperc
Collaborator

topperc commented Jun 3, 2025

> The trampoline code now uses a software-guarded jump (through t2), so the stack layout in the test is now:
>
>    56 $ra
>    44 $a0      f
>    36 $a1      p
>    32 00038067 jalr  t2
>    28 010e3e03 ld    t3, 16(t3)
>    24 018e3383 ld    t2, 24(t3)
>    20 00000e17 auipc t3, 0
> sp+16 00000017 lpad  0

Is gcc going to make the same change?

@jaidTw
Contributor Author

jaidTw commented Jun 3, 2025

> Is gcc going to make the same change?

Yes, they plan to change it, but the change might not have landed yet. @kito-cheng should know more.

@jaidTw
Contributor Author

jaidTw commented Jun 4, 2025

CI is failing on the new test, but I can't reproduce the failure locally yet.

@jaidTw jaidTw force-pushed the trampoline_lpad branch from 686889b to 0779b88 Compare June 4, 2025 09:07
@kito-cheng
Member

> Is gcc going to make the same change?

Yes. The code sequence used in GCC was designed for our internal fixed-one label scheme, which does not work for the func-sig-based scheme, so I think the only way is to use a software-guarded jump here.
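
For readers less familiar with the term, the sketch below models the rule that makes `jalr t2` a software-guarded jump. It reflects my reading of the Zicfilp specification (an indirect jump arms the landing-pad check unless `rs1` is x1, x5 or x7) and is an assumption for illustration, not code from LLVM or GCC; `isLandingPadExpected` is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>

constexpr uint32_t RA = 1;  // x1
constexpr uint32_t T0 = 5;  // x5
constexpr uint32_t T2 = 7;  // x7
constexpr uint32_t T3 = 28; // x28

// Returns true if `jalr rd, imm(rs1)` makes the hart expect an lpad at the
// target. Jumps through x1/x5 are treated as returns and jumps through x7 as
// software guarded, so none of the three arms the check.
constexpr bool isLandingPadExpected(uint32_t Rs1) {
  assert(Rs1 < 32);
  return Rs1 != RA && Rs1 != T0 && Rs1 != T2;
}

// The revised trampoline ends in `jalr t2`, so the callee's lpad label is not
// checked against a value the trampoline would have to know in advance.
static_assert(!isLandingPadExpected(T2), "jalr t2 is software guarded");
// An indirect jump through any other temporary, e.g. t3, arms the check.
static_assert(isLandingPadExpected(T3), "jalr t3 expects a landing pad");
```

Under this model the trampoline does not need to materialize a per-function label in t2, which is the property the func-sig-based scheme depends on.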

Collaborator

@rofirrim rofirrim left a comment


LGTM. Thanks @jaidTw !
