Skip to content

[wasm] Jiterpreter monitoring phase take 2 #83489

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 21, 2023
24 changes: 17 additions & 7 deletions src/mono/mono/mini/interp/interp.c
Original file line number Diff line number Diff line change
Expand Up @@ -3798,6 +3798,11 @@ max_d (double lhs, double rhs)
return fmax (lhs, rhs);
}

#if HOST_BROWSER
// Dummy JiterpreterCallInfo that is passed to traces invoked outside of the
// monitoring phase. Traces always take a call info pointer, but in that case
// we don't care what ends up in it, so one shared zero-initialized instance is enough.
static JiterpreterCallInfo jiterpreter_call_info = { 0 };
#endif

/*
* If CLAUSE_ARGS is non-null, start executing from it.
* The ERROR argument is used to avoid declaring an error object for every interp frame, it's not used
Expand Down Expand Up @@ -7782,15 +7787,11 @@ MINT_IN_CASE(MINT_BRTRUE_I8_SP) ZEROP_SP(gint64, !=); MINT_IN_BREAK;
* (note that right now threading doesn't work, but it's worth being correct
* here so that implementing thread support will be easier later.)
*/
*mutable_ip = MINT_TIER_NOP_JITERPRETER;
mono_memory_barrier ();
*(volatile JiterpreterThunk*)(ip + 1) = prepare_result;
mono_memory_barrier ();
*mutable_ip = MINT_TIER_ENTER_JITERPRETER;
*mutable_ip = MINT_TIER_MONITOR_JITERPRETER;
// now execute the trace
// this isn't important for performance, but it makes it easier to use the
// jiterpreter early in automated tests where code only runs once
offset = prepare_result(frame, locals);
offset = prepare_result(frame, locals, &jiterpreter_call_info);
ip = (guint16*) (((guint8*)ip) + offset);
break;
}
Expand All @@ -7801,9 +7802,18 @@ MINT_IN_CASE(MINT_BRTRUE_I8_SP) ZEROP_SP(gint64, !=); MINT_IN_BREAK;
MINT_IN_BREAK;
}

MINT_IN_CASE(MINT_TIER_MONITOR_JITERPRETER) {
// The trace is in monitoring mode, where we track how far it actually goes
// each time it is executed for a while. After N more hits, we either
// turn it into an ENTER or a NOP depending on how well it is working
ptrdiff_t offset = mono_jiterp_monitor_trace (ip, frame, locals);
ip = (guint16*) (((guint8*)ip) + offset);
MINT_IN_BREAK;
}

MINT_IN_CASE(MINT_TIER_ENTER_JITERPRETER) {
JiterpreterThunk thunk = (void*)READ32(ip + 1);
ptrdiff_t offset = thunk(frame, locals);
ptrdiff_t offset = thunk(frame, locals, &jiterpreter_call_info);
ip = (guint16*) (((guint8*)ip) + offset);
MINT_IN_BREAK;
}
Expand Down
125 changes: 107 additions & 18 deletions src/mono/mono/mini/interp/jiterpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -877,13 +877,17 @@ typedef struct {
// 64-bits because it can get very high if estimate heat is turned on
gint64 hit_count;
JiterpreterThunk thunk;
int penalty_total;
} TraceInfo;

#define MAX_TRACE_SEGMENTS 256
// The maximum number of trace segments used to store TraceInfo. This limits
// the maximum total number of traces to MAX_TRACE_SEGMENTS * TRACE_SEGMENT_SIZE
#define MAX_TRACE_SEGMENTS 1024
#define TRACE_SEGMENT_SIZE 1024

static volatile gint32 trace_count = 0;
static TraceInfo *trace_segments[MAX_TRACE_SEGMENTS] = { NULL };
static gint32 traces_rejected = 0;

static TraceInfo *
trace_info_allocate_segment (gint32 index) {
Expand Down Expand Up @@ -917,7 +921,14 @@ trace_info_get (gint32 index) {

static gint32
trace_info_alloc () {
gint32 index = trace_count++;
gint32 index = trace_count++,
limit = (MAX_TRACE_SEGMENTS * TRACE_SEGMENT_SIZE);
// Make sure we're not out of space in the trace info table.
if (index == limit)
g_print ("MONO_WASM: Reached maximum number of jiterpreter trace entry points (%d).\n", limit);
if (index >= limit)
return -1;

TraceInfo *info = trace_info_get (index);
info->hit_count = 0;
info->thunk = NULL;
Expand Down Expand Up @@ -984,20 +995,24 @@ jiterp_insert_entry_points (void *_imethod, void *_td)

if (enabled && should_generate) {
gint32 trace_index = trace_info_alloc ();

td->cbb = bb;
imethod->contains_traces = TRUE;
InterpInst *ins = mono_jiterp_insert_ins (td, NULL, MINT_TIER_PREPARE_JITERPRETER);
memcpy(ins->data, &trace_index, sizeof (trace_index));

// Clear the instruction counter
instruction_count = 0;

// Note that we only clear enter_at_next here, after generating a trace.
// This means that the flag will stay set intentionally if we keep failing
// to generate traces, perhaps due to a string of small basic blocks
// or multiple call instructions.
enter_at_next = bb->contains_call_instruction;
if (trace_index < 0) {
// We're out of space in the TraceInfo table.
return;
} else {
td->cbb = bb;
imethod->contains_traces = TRUE;
InterpInst *ins = mono_jiterp_insert_ins (td, NULL, MINT_TIER_PREPARE_JITERPRETER);
memcpy(ins->data, &trace_index, sizeof (trace_index));

// Clear the instruction counter
instruction_count = 0;

// Note that we only clear enter_at_next here, after generating a trace.
// This means that the flag will stay set intentionally if we keep failing
// to generate traces, perhaps due to a string of small basic blocks
// or multiple call instructions.
enter_at_next = bb->contains_call_instruction;
}
} else if (is_backwards_branch && enabled && !should_generate) {
// We failed to start a trace at a backwards branch target, but that might just mean
// that the loop body starts with one or two unsupported opcodes, so it may be
Expand Down Expand Up @@ -1233,7 +1248,7 @@ mono_jiterp_stelem_ref (

EMSCRIPTEN_KEEPALIVE int
mono_jiterp_trace_transfer (
int displacement, JiterpreterThunk trace, void *frame, void *pLocals
int displacement, JiterpreterThunk trace, void *frame, void *pLocals, JiterpreterCallInfo *cinfo
) {
// This indicates that we lost a race condition, so there's no trace to call. Just bail out.
// FIXME: Detect this at trace generation time and spin until the trace is available
Expand All @@ -1245,7 +1260,7 @@ mono_jiterp_trace_transfer (
// safepoint was already performed by the trace.
int relative_displacement = 0;
while (relative_displacement == 0)
relative_displacement = trace(frame, pLocals);
relative_displacement = trace(frame, pLocals, cinfo);

// We got a relative displacement other than 0, so the trace bailed out somewhere or
// branched to another branch target. Time to return (and our caller will return too.)
Expand Down Expand Up @@ -1326,6 +1341,80 @@ mono_jiterp_write_number_unaligned (void *dest, double value, int mode) {
}
}

#define TRACE_PENALTY_LIMIT 200
#define TRACE_MONITORING_DETAILED FALSE

// Runs a jitted trace while it is in its monitoring phase and records how well it did.
//
// ip      points at the MINT_TIER_MONITOR_JITERPRETER opcode; its 32-bit operand
//         (at ip + 1) is the index of the trace's TraceInfo.
// _frame  the current InterpFrame, passed through to the trace.
// locals  the frame's locals, passed through to the trace.
//
// Each call invokes the trace with a fresh JiterpreterCallInfo and, if the trace
// exited early without ever taking a backward branch, adds a penalty to the
// trace's running total. Once the trace has been hit
// mono_opt_jiterpreter_trace_monitoring_period times past the minimum hit count,
// the opcode at ip is rewritten to either MINT_TIER_ENTER_JITERPRETER (trace
// accepted) or left as MINT_TIER_NOP_JITERPRETER (trace rejected).
//
// Returns whatever displacement the trace returned, unmodified.
ptrdiff_t
mono_jiterp_monitor_trace (const guint16 *ip, void *_frame, void *locals)
{
	gint32 index = READ32 (ip + 1);
	TraceInfo *info = trace_info_get (index);
	g_assert (info);

	JiterpreterThunk thunk = info->thunk;
	// FIXME: This shouldn't be possible
	g_assert (((guint32)(void *)thunk) > JITERPRETER_NOT_JITTED);

	// A bailout_opcode_count of -1 means "the trace did not record an exit"; the
	// check below only assigns a penalty when the trace overwrote it with a value >= 0.
	JiterpreterCallInfo cinfo;
	cinfo.backward_branch_taken = 0;
	cinfo.bailout_opcode_count = -1;

	InterpFrame *frame = _frame;

	ptrdiff_t result = thunk (frame, locals, &cinfo);
	// If a backward branch was taken, we can treat the trace as if it successfully
	// executed at least one time. We don't know how long it actually ran, but back
	// branches are almost always going to be loops. It's fine if a bailout happens
	// after multiple loop iterations.
	if (
		(cinfo.bailout_opcode_count >= 0) &&
		!cinfo.backward_branch_taken &&
		(cinfo.bailout_opcode_count < mono_opt_jiterpreter_trace_monitoring_long_distance)
	) {
		// Penalties are stored in hundredths: an exit at or before the short distance
		// costs TRACE_PENALTY_LIMIT (i.e. 2.0 once the average is divided by 100 below)
		// and the cost is lerped down towards 0 as the exit approaches the long distance.
		//
		// Both distances are user-configurable, so the range may be zero or negative.
		// Dividing by it would then yield inf/NaN (whose conversion to int is undefined)
		// or a negative penalty, so in that case every early exit gets the full penalty.
		float distance_range = (float)(mono_opt_jiterpreter_trace_monitoring_long_distance - mono_opt_jiterpreter_trace_monitoring_short_distance);
		float scaled = 0.0f;
		if (distance_range > 0.0f)
			scaled = (float)(cinfo.bailout_opcode_count - mono_opt_jiterpreter_trace_monitoring_short_distance) / distance_range;
		int penalty = MIN ((int)((1.0f - scaled) * TRACE_PENALTY_LIMIT), TRACE_PENALTY_LIMIT);
		info->penalty_total += penalty;

		// g_print ("trace #%d @%d '%s' bailout recorded at opcode #%d, penalty=%d\n", index, ip, frame->imethod->method->name, cinfo.bailout_opcode_count, penalty);
	}

	// hit_count here is the number of hits recorded before this call, relative to
	// the minimum hit count.
	gint64 hit_count = info->hit_count++ - mono_opt_jiterpreter_minimum_trace_hit_count;
	if (hit_count == mono_opt_jiterpreter_trace_monitoring_period) {
		// Prepare to enable the trace. The opcode is turned into a NOP first so the
		// operand can be replaced before the opcode is switched to ENTER.
		volatile guint16 *mutable_ip = (volatile guint16*)ip;
		*mutable_ip = MINT_TIER_NOP_JITERPRETER;

		mono_memory_barrier ();
		// A monitoring period of 0 makes hit_count 0 here; dividing by it would produce
		// NaN/inf, which compares as "above threshold" and rejects every trace. Treat a
		// non-positive sample count as "no penalty data" instead.
		float average_penalty = (hit_count > 0)
				? info->penalty_total / (float)hit_count / 100.0f
				: 0.0f,
			threshold = (mono_opt_jiterpreter_trace_monitoring_max_average_penalty / 100.0f);

		if (average_penalty <= threshold) {
			// Accepted: the operand becomes the thunk pointer, which is what
			// MINT_TIER_ENTER_JITERPRETER expects to find at ip + 1.
			*(volatile JiterpreterThunk*)(ip + 1) = thunk;
			mono_memory_barrier ();
			*mutable_ip = MINT_TIER_ENTER_JITERPRETER;
			// NOTE(review): ip is a pointer printed with %d here and below; that relies
			// on pointers being int-sized on this target - confirm.
			if (mono_opt_jiterpreter_stats_enabled && TRACE_MONITORING_DETAILED)
				g_print ("trace #%d @%d '%s' accepted; average_penalty %f <= %f\n", index, ip, frame->imethod->method->name, average_penalty, threshold);
		} else {
			// Rejected: the opcode stays a NOP.
			traces_rejected++;
			if (mono_opt_jiterpreter_stats_enabled) {
				char * full_name = mono_method_get_full_name (frame->imethod->method);
				g_print ("trace #%d @%d '%s' rejected; average_penalty %f > %f\n", index, ip, full_name, average_penalty, threshold);
				g_free (full_name);
			}
		}
	}

	return result;
}

// Returns the current value of the traces_rejected counter, i.e. how many traces
// have been rejected so far.
EMSCRIPTEN_KEEPALIVE gint32
mono_jiterp_get_rejected_trace_count ()
{
	gint32 rejected_so_far = traces_rejected;
	return rejected_so_far;
}

// HACK: fix C4206
EMSCRIPTEN_KEEPALIVE
#endif // HOST_BROWSER
Expand Down
10 changes: 9 additions & 1 deletion src/mono/mono/mini/interp/jiterpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
// NOT_JITTED indicates that the trace was not jitted and it should be turned into a NOP
#define JITERPRETER_NOT_JITTED 1

typedef const ptrdiff_t (*JiterpreterThunk) (void *frame, void *pLocals);
// Per-call information that a trace can write into while it runs.
// NOTE(review): generated code appears to address these fields by byte offset
// (0 and 4), so the field order and sizes should not be changed - confirm.
typedef struct {
// Non-zero once the trace has taken a backward branch during this call.
gint32 backward_branch_taken;
// Number of opcodes the trace had processed when it recorded an exit.
gint32 bailout_opcode_count;
} JiterpreterCallInfo;

typedef const ptrdiff_t (*JiterpreterThunk) (void *frame, void *pLocals, JiterpreterCallInfo *cinfo);
typedef void (*WasmJitCallThunk) (void *ret_sp, void *sp, void *ftndesc, gboolean *thrown);
typedef void (*WasmDoJitCall) (gpointer cb, gpointer arg, gboolean *out_thrown);

Expand Down Expand Up @@ -139,6 +144,9 @@ mono_jiterp_imethod_to_ftnptr (InterpMethod *imethod);
void
mono_jiterp_enum_hasflag (MonoClass *klass, gint32 *dest, stackval *sp1, stackval *sp2);

ptrdiff_t
mono_jiterp_monitor_trace (const guint16 *ip, void *frame, void *locals);

#endif // __MONO_MINI_INTERPRETER_INTERNALS_H__

extern WasmDoJitCall jiterpreter_do_jit_call;
Expand Down
1 change: 1 addition & 0 deletions src/mono/mono/mini/interp/mintops.def
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,7 @@ OPDEF(MINT_METADATA_UPDATE_LDFLDA, "metadata_update.ldflda", 5, 1, 1, MintOpTwoS
OPDEF(MINT_TIER_PREPARE_JITERPRETER, "tier_prepare_jiterpreter", 3, 0, 0, MintOpInt)
OPDEF(MINT_TIER_NOP_JITERPRETER, "tier_nop_jiterpreter", 3, 0, 0, MintOpInt)
OPDEF(MINT_TIER_ENTER_JITERPRETER, "tier_enter_jiterpreter", 3, 0, 0, MintOpInt)
OPDEF(MINT_TIER_MONITOR_JITERPRETER, "tier_monitor_jiterpreter", 3, 0, 0, MintOpInt)
#endif // HOST_BROWSER

IROPDEF(MINT_NOP, "nop", 1, 0, 0, MintOpNoArgs)
Expand Down
8 changes: 8 additions & 0 deletions src/mono/mono/utils/options-def.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,14 @@ DEFINE_INT(jiterpreter_minimum_trace_length, "jiterpreter-minimum-trace-length",
DEFINE_INT(jiterpreter_minimum_distance_between_traces, "jiterpreter-minimum-distance-between-traces", 4, "Don't insert entry points closer together than this")
// once a trace entry point is inserted, we only actually JIT code for it once it's been hit this many times
DEFINE_INT(jiterpreter_minimum_trace_hit_count, "jiterpreter-minimum-trace-hit-count", 5000, "JIT trace entry points once they are hit this many times")
// trace prepares turn into a monitor opcode and stay that way for this many hits before being converted to enter or nop
DEFINE_INT(jiterpreter_trace_monitoring_period, "jiterpreter-trace-monitoring-period", 1000, "Monitor jitted traces for this many calls to determine whether to keep them")
// traces that process fewer than this many opcodes have a high exit penalty, more than this have a low exit penalty
DEFINE_INT(jiterpreter_trace_monitoring_short_distance, "jiterpreter-trace-monitoring-short-distance", 4, "Traces that exit after processing this many opcodes have a reduced exit penalty")
// traces that process this many opcodes have no exit penalty
DEFINE_INT(jiterpreter_trace_monitoring_long_distance, "jiterpreter-trace-monitoring-long-distance", 10, "Traces that exit after processing this many opcodes have no exit penalty")
// the average penalty value for a trace is compared against this threshold / 100 to decide whether to discard it
DEFINE_INT(jiterpreter_trace_monitoring_max_average_penalty, "jiterpreter-trace-monitoring-max-average-penalty", 75, "If the average penalty value for a trace is above this value it will be rejected")
// After a do_jit_call call site is hit this many times, we will queue it to be jitted
DEFINE_INT(jiterpreter_jit_call_trampoline_hit_count, "jiterpreter-jit-call-hit-count", 1000, "Queue specialized do_jit_call trampoline for JIT after this many hits")
// After a do_jit_call call site is hit this many times without being jitted, we will flush the JIT queue
Expand Down
2 changes: 2 additions & 0 deletions src/mono/wasm/runtime/cwraps.ts
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ const fn_signatures: SigLine[] = [
[true, "mono_jiterp_debug_count", "number", []],
[true, "mono_jiterp_get_trace_hit_count", "number", ["number"]],
[true, "mono_jiterp_get_polling_required_address", "number", []],
[true, "mono_jiterp_get_rejected_trace_count", "number", []],
...legacy_interop_cwraps
];

Expand Down Expand Up @@ -238,6 +239,7 @@ export interface t_Cwraps {
mono_jiterp_get_trace_hit_count(traceIndex: number): number;
mono_jiterp_get_polling_required_address(): Int32Ptr;
mono_jiterp_write_number_unaligned(destination: VoidPtr, value: number, mode: number): void;
mono_jiterp_get_rejected_trace_count(): number;
}

const wrapped_c_functions: t_Cwraps = <any>{};
Expand Down
37 changes: 37 additions & 0 deletions src/mono/wasm/runtime/jiterpreter-support.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,17 @@ class Cfg {
const disp = this.dispatchTable.get(segment.target)!;
if (this.trace)
console.log(`backward br from ${(<any>segment.from).toString(16)} to ${(<any>segment.target).toString(16)}: disp=${disp}`);

// set the backward branch taken flag in the cinfo so that the monitoring phase
// knows we took a backward branch. this is unfortunate but unavoidable overhead
// we just make it a flag instead of an increment to reduce the cost
this.builder.local("cinfo");
// TODO: Store the offset in opcodes instead? Probably not useful information
this.builder.i32_const(1);
this.builder.appendU8(WasmOpcode.i32_store);
this.builder.appendMemarg(0, 0); // JiterpreterCallInfo.backward_branch_taken

// set the dispatch index for the br_table
this.builder.i32_const(disp);
this.builder.local("disp", WasmOpcode.set_local);
} else {
Expand Down Expand Up @@ -1276,6 +1287,24 @@ export function append_bailout (builder: WasmBuilder, ip: MintOpcodePtr, reason:
builder.appendU8(WasmOpcode.return_);
}

// Emits the code for leaving a trace at `ip`, in a form the monitoring phase can
// observe as a possible early exit.
//
// When opcodeCounter is no greater than monitoringLongDistance + 1, the emitted
// code first stores opcodeCounter into the bailout_opcode_count field (offset 4)
// of the structure referenced by the "cinfo" local. Exits further into the trace
// skip that store. After that the ip constant is emitted, optionally followed by
// a call to the "bailout" import when countBailouts is enabled, and finally a return.
export function append_exit (builder: WasmBuilder, ip: MintOpcodePtr, opcodeCounter: number, reason: BailoutReason) {
    const recordingLimit = builder.options.monitoringLongDistance + 1;
    const shouldRecordExit = opcodeCounter <= recordingLimit;

    if (shouldRecordExit) {
        // cinfo.bailout_opcode_count = opcodeCounter
        builder.local("cinfo");
        builder.i32_const(opcodeCounter);
        builder.appendU8(WasmOpcode.i32_store);
        builder.appendMemarg(4, 0); // bailout_opcode_count
    }

    builder.ip_const(ip);

    if (builder.options.countBailouts) {
        // report the exit (trace base + reason) before returning
        builder.i32_const(builder.base);
        builder.i32_const(reason);
        builder.callImport("bailout");
    }

    builder.appendU8(WasmOpcode.return_);
}

export function copyIntoScratchBuffer (src: NativePointer, size: number) : NativePointer {
if (!scratchBuffer)
scratchBuffer = Module._malloc(64);
Expand Down Expand Up @@ -1551,6 +1580,10 @@ export type JiterpreterOptions = {
eliminateNullChecks: boolean;
minimumTraceLength: number;
minimumTraceHitCount: number;
monitoringPeriod: number;
monitoringShortDistance: number;
monitoringLongDistance: number;
monitoringMaxAveragePenalty: number;
jitCallHitCount: number;
jitCallFlushThreshold: number;
interpEntryHitCount: number;
Expand All @@ -1577,6 +1610,10 @@ const optionNames : { [jsName: string] : string } = {
"directJitCalls": "jiterpreter-direct-jit-calls",
"minimumTraceLength": "jiterpreter-minimum-trace-length",
"minimumTraceHitCount": "jiterpreter-minimum-trace-hit-count",
"monitoringPeriod": "jiterpreter-trace-monitoring-period",
"monitoringShortDistance": "jiterpreter-trace-monitoring-short-distance",
"monitoringLongDistance": "jiterpreter-trace-monitoring-long-distance",
"monitoringMaxAveragePenalty": "jiterpreter-trace-monitoring-max-average-penalty",
"jitCallHitCount": "jiterpreter-jit-call-hit-count",
"jitCallFlushThreshold": "jiterpreter-jit-call-queue-flush-threshold",
"interpEntryHitCount": "jiterpreter-interp-entry-hit-count",
Expand Down
Loading