Skip to content

gh-119726: generate and patch AArch64 trampolines #123872

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
The JIT now generates more efficient code for calls to C functions, resulting
in up to 0.8% memory savings and a 1.5% speed improvement on AArch64. Patch by Diego Russo.
92 changes: 84 additions & 8 deletions Python/jit.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "Python.h"

#include "pycore_abstract.h"
#include "pycore_bitutils.h"
#include "pycore_call.h"
#include "pycore_ceval.h"
#include "pycore_critical_section.h"
Expand Down Expand Up @@ -113,6 +114,21 @@ mark_executable(unsigned char *memory, size_t size)

// JIT compiler stuff: /////////////////////////////////////////////////////////

// Number of 32-bit words that make up a symbol_mask.
#define SYMBOL_MASK_WORDS 4

// Bitmask indexed by symbol ordinal: bit (ordinal % 32) of word (ordinal / 32).
typedef uint32_t symbol_mask[SYMBOL_MASK_WORDS];

// Bookkeeping for the trampolines emitted alongside a compiled trace.
typedef struct {
// Start of the trampoline area inside the JIT allocation.
unsigned char *mem;
// Union of the trampoline masks of every stencil group in the trace.
symbol_mask mask;
// Total size of the trampoline area, in bytes.
size_t size;
} trampoline_state;

// State shared with the stencil emit functions while compiling one trace.
typedef struct {
// Trampoline area bookkeeping for this trace.
trampoline_state trampolines;
// Per-instruction code start: first filled with offsets into the code
// region, then rebased to absolute addresses once memory is allocated.
uintptr_t instruction_starts[UOP_MAX_TRACE_LENGTH];
} jit_state;

// Warning! AArch64 requires you to get your hands dirty. These are your gloves:

// value[value_start : value_start + len]
Expand Down Expand Up @@ -390,66 +406,126 @@ patch_x86_64_32rx(unsigned char *location, uint64_t value)
patch_32r(location, value);
}

void patch_aarch64_trampoline(unsigned char *location, int ordinal, jit_state *state);

#include "jit_stencils.h"

// Size in bytes of a single trampoline. On AArch64 a trampoline is four
// 32-bit words: two instructions followed by the 64-bit address to jump to
// (see patch_aarch64_trampoline). On every other target it is zero, so the
// computed trampoline area is empty.
#if defined(__aarch64__) || defined(_M_ARM64)
#define TRAMPOLINE_SIZE 16
#else
#define TRAMPOLINE_SIZE 0
#endif

// Generate and patch AArch64 trampolines. The symbols to jump to are stored
// in the jit_stencils.h in the symbols_map.
//
//   location: address of the branch instruction to redirect to the trampoline.
//   ordinal:  index of the target symbol in symbols_map; its bit must be set
//             in state->trampolines.mask.
//   state:    per-trace JIT state holding the trampoline area and mask.
void
patch_aarch64_trampoline(unsigned char *location, int ordinal, jit_state *state)
{
    // A negative ordinal would make both the shift count and the word index
    // below negative.
    assert(ordinal >= 0);

    // Masking is done modulo 32 as the mask is stored as an array of uint32_t.
    // The shifted constant must be unsigned: "1 << 31" on a signed int is
    // undefined behaviour, and ordinal % 32 can legitimately be 31.
    const uint32_t symbol_bit = (uint32_t)1 << (ordinal % 32);
    const uint32_t trampoline_mask = state->trampolines.mask[ordinal / 32];
    assert(symbol_bit & trampoline_mask);

    // Count the number of set bits in the trampoline mask lower than ordinal,
    // this gives the index into the array of trampolines.
    int index = _Py_popcount32(trampoline_mask & (symbol_bit - 1));
    for (int i = 0; i < ordinal / 32; i++) {
        index += _Py_popcount32(state->trampolines.mask[i]);
    }

    // The slot must lie entirely inside the reserved trampoline area.
    assert((size_t)(index + 1) * TRAMPOLINE_SIZE <= state->trampolines.size);
    uint32_t *p = (uint32_t*)(state->trampolines.mem + index * TRAMPOLINE_SIZE);

    uint64_t value = (uintptr_t)symbols_map[ordinal];

    /* Generate the trampoline
       0: 58000048      ldr     x8, 8
       4: d61f0100      br      x8
       8: 00000000      // The next two words contain the 64-bit address to jump to.
       c: 00000000
    */
    p[0] = 0x58000048;
    p[1] = 0xD61F0100;
    p[2] = (uint32_t)(value & 0xffffffff);
    p[3] = (uint32_t)(value >> 32);

    // Point the original branch at the freshly written trampoline.
    patch_aarch64_26r(location, (uintptr_t)p);
}

// OR every word of src into dest, so that dest ends up holding the union of
// the trampolines required by each StencilGroup merged into it so far.
static void
combine_symbol_mask(const symbol_mask src, symbol_mask dest)
{
    const uint32_t *in = src;
    uint32_t *out = dest;
    uint32_t *const end = dest + SYMBOL_MASK_WORDS;
    while (out != end) {
        *out++ |= *in++;
    }
}

// Compiles executor in-place. Don't forget to call _PyJIT_Free later!
int
_PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], size_t length)
{
const StencilGroup *group;
// Loop once to find the total compiled size:
uintptr_t instruction_starts[UOP_MAX_TRACE_LENGTH];
size_t code_size = 0;
size_t data_size = 0;
jit_state state = {};
group = &trampoline;
code_size += group->code_size;
data_size += group->data_size;
for (size_t i = 0; i < length; i++) {
const _PyUOpInstruction *instruction = &trace[i];
group = &stencil_groups[instruction->opcode];
instruction_starts[i] = code_size;
state.instruction_starts[i] = code_size;
code_size += group->code_size;
data_size += group->data_size;
combine_symbol_mask(group->trampoline_mask, state.trampolines.mask);
}
group = &stencil_groups[_FATAL_ERROR];
code_size += group->code_size;
data_size += group->data_size;
combine_symbol_mask(group->trampoline_mask, state.trampolines.mask);
// Calculate the size of the trampolines required by the whole trace
for (size_t i = 0; i < Py_ARRAY_LENGTH(state.trampolines.mask); i++) {
state.trampolines.size += _Py_popcount32(state.trampolines.mask[i]) * TRAMPOLINE_SIZE;
}
// Round up to the nearest page:
size_t page_size = get_page_size();
assert((page_size & (page_size - 1)) == 0);
size_t padding = page_size - ((code_size + data_size) & (page_size - 1));
size_t total_size = code_size + data_size + padding;
size_t padding = page_size - ((code_size + data_size + state.trampolines.size) & (page_size - 1));
size_t total_size = code_size + data_size + state.trampolines.size + padding;
unsigned char *memory = jit_alloc(total_size);
if (memory == NULL) {
return -1;
}
// Update the offsets of each instruction:
for (size_t i = 0; i < length; i++) {
instruction_starts[i] += (uintptr_t)memory;
state.instruction_starts[i] += (uintptr_t)memory;
}
// Loop again to emit the code:
unsigned char *code = memory;
unsigned char *data = memory + code_size;
state.trampolines.mem = memory + code_size + data_size;
// Compile the trampoline, which handles converting between the native
// calling convention and the calling convention used by jitted code
// (which may be different for efficiency reasons). On platforms where
// we don't change calling conventions, the trampoline is empty and
// nothing is emitted here:
group = &trampoline;
group->emit(code, data, executor, NULL, instruction_starts);
group->emit(code, data, executor, NULL, &state);
code += group->code_size;
data += group->data_size;
assert(trace[0].opcode == _START_EXECUTOR);
for (size_t i = 0; i < length; i++) {
const _PyUOpInstruction *instruction = &trace[i];
group = &stencil_groups[instruction->opcode];
group->emit(code, data, executor, instruction, instruction_starts);
group->emit(code, data, executor, instruction, &state);
code += group->code_size;
data += group->data_size;
}
// Protect against accidental buffer overrun into data:
group = &stencil_groups[_FATAL_ERROR];
group->emit(code, data, executor, NULL, instruction_starts);
group->emit(code, data, executor, NULL, &state);
code += group->code_size;
data += group->data_size;
assert(code == memory + code_size);
Expand Down
81 changes: 37 additions & 44 deletions Tools/jit/_stencils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import dataclasses
import enum
import sys
import typing

import _schema
Expand Down Expand Up @@ -103,8 +102,8 @@ class HoleValue(enum.Enum):
HoleValue.OPERAND_HI: "(instruction->operand >> 32)",
HoleValue.OPERAND_LO: "(instruction->operand & UINT32_MAX)",
HoleValue.TARGET: "instruction->target",
HoleValue.JUMP_TARGET: "instruction_starts[instruction->jump_target]",
HoleValue.ERROR_TARGET: "instruction_starts[instruction->error_target]",
HoleValue.JUMP_TARGET: "state->instruction_starts[instruction->jump_target]",
HoleValue.ERROR_TARGET: "state->instruction_starts[instruction->error_target]",
HoleValue.ZERO: "",
}

Expand All @@ -125,6 +124,7 @@ class Hole:
symbol: str | None
# ...plus this addend:
addend: int
need_state: bool = False
func: str = dataclasses.field(init=False)
# Convenience method:
replace = dataclasses.replace
Expand Down Expand Up @@ -157,10 +157,12 @@ def as_c(self, where: str) -> str:
if value:
value += " + "
value += f"(uintptr_t)&{self.symbol}"
if _signed(self.addend):
if _signed(self.addend) or not value:
if value:
value += " + "
value += f"{_signed(self.addend):#x}"
if self.need_state:
return f"{self.func}({location}, {value}, state);"
return f"{self.func}({location}, {value});"


Expand All @@ -175,7 +177,6 @@ class Stencil:
body: bytearray = dataclasses.field(default_factory=bytearray, init=False)
holes: list[Hole] = dataclasses.field(default_factory=list, init=False)
disassembly: list[str] = dataclasses.field(default_factory=list, init=False)
trampolines: dict[str, int] = dataclasses.field(default_factory=dict, init=False)

def pad(self, alignment: int) -> None:
"""Pad the stencil to the given alignment."""
Expand All @@ -184,39 +185,6 @@ def pad(self, alignment: int) -> None:
self.disassembly.append(f"{offset:x}: {' '.join(['00'] * padding)}")
self.body.extend([0] * padding)

def emit_aarch64_trampoline(self, hole: Hole, alignment: int) -> Hole:
    """Route a branch hole through a trampoline stored in this stencil.

    Even with the large code model, AArch64 Linux insists on 28-bit jumps,
    so the branch is redirected to a small stub that loads the full 64-bit
    address and jumps to it. One stub is emitted per symbol and re-used by
    later holes that target the same symbol.

    Returns a copy of *hole* that points at the stub's offset in the data
    section instead of at the symbol.
    """
    symbol = hole.symbol
    assert symbol is not None
    existing = self.trampolines.get(symbol)
    if existing is not None:
        # Re-use the base address of the previously created trampoline
        return hole.replace(addend=existing, symbol=None, value=HoleValue.DATA)

    self.pad(alignment)
    start = len(self.body)
    redirected = hole.replace(addend=start, symbol=None, value=HoleValue.DATA)

    self.disassembly.extend(
        [
            f"{start:x}: 58000048 ldr x8, 8",
            f"{start + 4:x}: d61f0100 br x8",
            f"{start + 8:x}: 00000000",
            f"{start + 8:016x}: R_AARCH64_ABS64 {symbol}",
            f"{start + 12:x}: 00000000",
        ]
    )
    # Two instructions, then a zeroed 64-bit slot for the target address.
    for word in (0x58000048, 0xD61F0100, 0x00000000, 0x00000000):
        self.body.extend(word.to_bytes(4, sys.byteorder))
    # The address slot is filled in later through an absolute relocation.
    self.holes.append(hole.replace(offset=start + 8, kind="R_AARCH64_ABS64"))
    self.trampolines[symbol] = start
    return redirected

def remove_jump(self, *, alignment: int = 1) -> None:
"""Remove a zero-length continuation jump, if it exists."""
hole = max(self.holes, key=lambda hole: hole.offset)
Expand Down Expand Up @@ -282,18 +250,32 @@ class StencilGroup:
default_factory=dict, init=False
)
_got: dict[str, int] = dataclasses.field(default_factory=dict, init=False)

def process_relocations(self, *, alignment: int = 1) -> None:
_trampolines: set[int] = dataclasses.field(default_factory=set, init=False)

def process_relocations(
self,
known_symbols: dict[str, int],
*,
alignment: int = 1,
) -> None:
"""Fix up all GOT and internal relocations for this stencil group."""
for hole in self.code.holes.copy():
if (
hole.kind
in {"R_AARCH64_CALL26", "R_AARCH64_JUMP26", "ARM64_RELOC_BRANCH26"}
and hole.value is HoleValue.ZERO
):
new_hole = self.data.emit_aarch64_trampoline(hole, alignment)
self.code.holes.remove(hole)
self.code.holes.append(new_hole)
hole.func = "patch_aarch64_trampoline"
hole.need_state = True
assert hole.symbol is not None
if hole.symbol in known_symbols:
ordinal = known_symbols[hole.symbol]
else:
ordinal = len(known_symbols)
known_symbols[hole.symbol] = ordinal
self._trampolines.add(ordinal)
hole.addend = ordinal
hole.symbol = None
self.code.remove_jump(alignment=alignment)
self.code.pad(alignment)
self.data.pad(8)
Expand Down Expand Up @@ -348,9 +330,20 @@ def _emit_global_offset_table(self) -> None:
)
self.data.body.extend([0] * 8)

def _get_trampoline_mask(self) -> str:
    """Return the C initializer for this group's trampoline bitmask.

    Each ordinal in ``self._trampolines`` sets one bit; the bits are emitted
    as a brace-enclosed list of 32-bit words, lowest word first. A group
    that needs no trampolines yields ``{0}``.
    """
    bitmask = 0
    for ordinal in self._trampolines:
        bitmask |= 1 << ordinal
    words: list[str] = []
    while bitmask:
        words.append(f"{bitmask & 0xFFFFFFFF:#04x}")
        bitmask >>= 32
    # An empty initializer list ("{}") is only valid C from C23 onwards, so
    # fall back to an explicit zero when no trampolines are required.
    return "{" + (", ".join(words) or "0") + "}"

def as_c(self, opname: str) -> str:
    """Dump this group as a StencilGroup initializer.

    The initializer lists the emit function, the code size, the data size
    and the trampoline mask, in that order.
    """
    # The trampoline mask must be part of the initializer: the C side reads
    # group->trampoline_mask when sizing the trampoline area.
    code_size = len(self.code.body)
    data_size = len(self.data.body)
    mask = self._get_trampoline_mask()
    return f"{{emit_{opname}, {code_size}, {data_size}, {mask}}}"


def symbol_to_value(symbol: str) -> tuple[HoleValue, str | None]:
Expand Down
7 changes: 5 additions & 2 deletions Tools/jit/_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class _Target(typing.Generic[_S, _R]):
stable: bool = False
debug: bool = False
verbose: bool = False
known_symbols: dict[str, int] = dataclasses.field(default_factory=dict)

def _compute_digest(self, out: pathlib.Path) -> str:
hasher = hashlib.sha256()
Expand Down Expand Up @@ -95,7 +96,9 @@ async def _parse(self, path: pathlib.Path) -> _stencils.StencilGroup:
if group.data.body:
line = f"0: {str(bytes(group.data.body)).removeprefix('b')}"
group.data.disassembly.append(line)
group.process_relocations(alignment=self.alignment)
group.process_relocations(
known_symbols=self.known_symbols, alignment=self.alignment
)
return group

def _handle_section(self, section: _S, group: _stencils.StencilGroup) -> None:
Expand Down Expand Up @@ -231,7 +234,7 @@ def build(
if comment:
file.write(f"// {comment}\n")
file.write("\n")
for line in _writer.dump(stencil_groups):
for line in _writer.dump(stencil_groups, self.known_symbols):
file.write(f"{line}\n")
try:
jit_stencils_new.replace(jit_stencils)
Expand Down
Loading
Loading