Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kill time based heuristics no versions #1285

Merged
merged 15 commits into from
Jun 5, 2024
2 changes: 0 additions & 2 deletions rir/src/compiler/backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,6 @@ rir::Function* Backend::doCompile(ClosureVersion* cls, ClosureLog& log) {
// here we only set the current version used to compile this function
auto feedback = rir::TypeFeedback::empty();
PROTECT(feedback->container());
feedback->version(
cls->optFunction->dispatchTable()->currentTypeFeedbackVersion());

function.finalize(body, signature, cls->context(), feedback);
for (auto& c : done)
Expand Down
4 changes: 1 addition & 3 deletions rir/src/compiler/native/builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ static FunctionSignature
deoptSentinelSig(FunctionSignature::Environment::CallerProvided,
FunctionSignature::OptimizationLevel::Optimized);
static Function* deoptSentinel;
static SEXP deoptSentinelContainer = []() {
SEXP deoptSentinelContainer = []() {
auto c = rir::Code::NewNative(0);
PROTECT(c->container());
SEXP store = Rf_allocVector(EXTERNALSXP, sizeof(Function));
Expand Down Expand Up @@ -1464,7 +1464,6 @@ static SEXP nativeCallTrampolineImpl(ArglistOrder::CallId callId, rir::Code* c,
R_ReturnedValue = R_NilValue; /* remove restart token */
fun->registerInvocation();
result = code->nativeCode()(code, args, env, callee);
fun->registerEndInvocation();
} else {
result = R_ReturnedValue;
}
Expand All @@ -1482,7 +1481,6 @@ static SEXP nativeCallTrampolineImpl(ArglistOrder::CallId callId, rir::Code* c,
ostack_popn(missing);

SLOWASSERT(t == R_BCNodeStackTop);
fun->registerEndInvocation();
return result;
}

Expand Down
78 changes: 31 additions & 47 deletions rir/src/compiler/native/lower_function_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ using namespace llvm;
extern "C" size_t R_NSize;
extern "C" size_t R_NodesInUse;

extern SEXP deoptSentinelContainer;

static_assert(sizeof(unsigned long) == sizeof(uint64_t),
"sizeof(unsigned long) and sizeof(uint64_t) should match");

Expand Down Expand Up @@ -3402,7 +3404,6 @@ void LowerFunctionLLVM::compile() {
auto calli = StaticCall::Cast(i);
calli->eachArg([](Value* v) { assert(!ExpandDots::Cast(v)); });
auto target = calli->tryDispatch();
auto bestTarget = calli->tryOptimisticDispatch();
std::vector<Value*> args;
calli->eachCallArg([&](Value* v) { args.push_back(v); });
Context asmpt = calli->inferAvailableAssumptions();
Expand All @@ -3424,32 +3425,34 @@ void LowerFunctionLLVM::compile() {
break;
}

if (target == bestTarget) {
auto callee = target->owner()->rirClosure();
auto dt = DispatchTable::check(BODY(callee));
rir::Function* nativeTarget = nullptr;
for (size_t i = 0; i < dt->size(); i++) {
auto entry = dt->get(i);
if (entry->context() == target->context() &&
entry->signature().numArguments >= args.size()) {
nativeTarget = entry;
}
auto callee = target->owner()->rirClosure();
auto dt = DispatchTable::check(BODY(callee));
rir::Function* nativeTarget = nullptr;
for (size_t i = 0; i < dt->size(); i++) {
auto entry = dt->get(i);
if (entry->context() == target->context() &&
entry->signature().numArguments >= args.size() &&
!entry->disabled()) {
nativeTarget = entry;
}
if (nativeTarget) {
assert(
asmpt.includes(Assumption::StaticallyArgmatched));
auto idx = Pool::makeSpace();
NativeBuiltins::targetCaches.push_back(idx);
Pool::patch(idx, nativeTarget->container());
auto missAsmptStore =
Rf_allocVector(RAWSXP, sizeof(Context));
auto missAsmptIdx = Pool::insert(missAsmptStore);
new (DATAPTR(missAsmptStore))
Context(nativeTarget->context() - asmpt);
assert(asmpt.smaller(nativeTarget->context()));
auto res = withCallFrame(args, [&]() {
return call(
NativeBuiltins::get(
}
SEXP container = deoptSentinelContainer;
if (nativeTarget) {
container = nativeTarget->container();
}

assert(asmpt.includes(Assumption::StaticallyArgmatched));
auto idx = Pool::makeSpace();
NativeBuiltins::targetCaches.push_back(idx);
Pool::patch(idx, container);
auto missAsmptStore = Rf_allocVector(RAWSXP, sizeof(Context));
auto missAsmptIdx = Pool::insert(missAsmptStore);
new (DATAPTR(missAsmptStore)) Context();
if (nativeTarget) {
assert(asmpt.smaller(nativeTarget->context()));
}
auto res = withCallFrame(args, [&]() {
return call(NativeBuiltins::get(
NativeBuiltins::Id::nativeCallTrampoline),
{
c(callId),
Expand All @@ -3462,27 +3465,8 @@ void LowerFunctionLLVM::compile() {
c(asmpt.toI()),
c(missAsmptIdx),
});
});
setVal(i, res);
break;
}
}

assert(asmpt.includes(Assumption::StaticallyArgmatched));
setVal(i, withCallFrame(args, [&]() -> llvm::Value* {
return call(
NativeBuiltins::get(NativeBuiltins::Id::call),
{
c(callId),
paramCode(),
c(calli->srcIdx),
builder.CreateIntToPtr(
c(calli->cls()->rirClosure()), t::SEXP),
loadSxp(calli->env()),
c(calli->nCallArgs()),
c(asmpt.toI()),
});
}));
});
setVal(i, res);
break;
}

Expand Down
13 changes: 8 additions & 5 deletions rir/src/compiler/rir2pir/rir2pir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -977,13 +977,16 @@ bool Rir2Pir::compileBC(const BC& bc, Opcode* pos, Opcode* nextPos,
}

if (ti.taken != (size_t)-1 &&
insert.function->optFunction->invocationCount()) {
// the reason to take the baseline version is that we only
// increment the taken type feedback while running baseline
// FIXME: refactor
insert.function->owner()->rirFunction()->invocationCount()) {
if (auto c = CallInstruction::CastCall(top())) {
// invocation count is already incremented before calling jit
c->taken =
(double)ti.taken /
(double)(insert.function->optFunction->invocationCount() -
1);
c->taken = (double)ti.taken / (double)(insert.function->owner()
->rirFunction()
->invocationCount() -
1);
}
}
break;
Expand Down
7 changes: 0 additions & 7 deletions rir/src/interpreter/interp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -818,10 +818,6 @@ static void supplyMissingArgs(CallContext& call, const Function* fun) {

const unsigned pir::Parameter::PIR_WARMUP =
getenv("PIR_WARMUP") ? atoi(getenv("PIR_WARMUP")) : 100;
const unsigned pir::Parameter::PIR_OPT_TIME =
getenv("PIR_OPT_TIME") ? atoi(getenv("PIR_OPT_TIME")) : 3e6;
const unsigned pir::Parameter::PIR_REOPT_TIME =
getenv("PIR_REOPT_TIME") ? atoi(getenv("PIR_REOPT_TIME")) : 5e7;
const unsigned pir::Parameter::DEOPT_ABANDON =
getenv("PIR_DEOPT_ABANDON") ? atoi(getenv("PIR_DEOPT_ABANDON")) : 12;
const unsigned pir::Parameter::PIR_OPT_BC_SIZE =
Expand Down Expand Up @@ -1137,7 +1133,6 @@ SEXP doCall(CallContext& call, bool popArgs) {
assert(result);
if (popArgs)
ostack_popn(call.passedArgs - call.suppliedArgs);
fun->registerEndInvocation();
return result;
}
default:
Expand Down Expand Up @@ -3982,14 +3977,12 @@ SEXP rirEval(SEXP what, SEXP env) {
Function* fun = table->baseline();
fun->registerInvocation();
auto res = evalRirCodeExtCaller(fun->body(), env);
fun->registerEndInvocation();
return res;
}

if (auto fun = Function::check(what)) {
fun->registerInvocation();
auto res = evalRirCodeExtCaller(fun->body(), env);
fun->registerEndInvocation();
return res;
}

Expand Down
30 changes: 11 additions & 19 deletions rir/src/interpreter/interp.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ inline RCNTXT* findFunctionContextFor(SEXP e) {
return nullptr;
}

inline bool RecompileHeuristic(Function* fun,
Function* funMaybeDisabled = nullptr) {
inline bool RecompileHeuristic(Function* fun, Function* disabledFun = nullptr) {

auto flags = fun->flags;
if (flags.contains(Function::MarkOpt)) {
Expand All @@ -67,31 +66,24 @@ inline bool RecompileHeuristic(Function* fun,
if (flags.contains(Function::NotOptimizable))
return false;

if (!funMaybeDisabled)
funMaybeDisabled = fun;
if (!disabledFun)
disabledFun = fun;

auto abandon =
funMaybeDisabled->deoptCount() >= pir::Parameter::DEOPT_ABANDON;

auto wt = fun->isOptimized() ? pir::Parameter::PIR_REOPT_TIME
: pir::Parameter::PIR_OPT_TIME;
if (fun->invocationCount() >= 3 && fun->invocationTime() > wt) {
REC_HOOK(recording::recordInvocationCountTimeReason(
fun->invocationCount(), 3, fun->invocationTime(), wt));

fun->clearInvocationTime();
return !abandon;
}

if (abandon || fun->isOptimized())
if (disabledFun->deoptCount() >= pir::Parameter::DEOPT_ABANDON) {
return false;
}

auto wu = pir::Parameter::PIR_WARMUP;
if (wu == 0 || fun->invocationCount() == wu) {
if (wu == 0) {
REC_HOOK(recording::recordPirWarmupReason(wu));
return true;
}

if (fun->invocationCount() % wu == 0) {
REC_HOOK(recording::recordPirWarmupReason(fun->invocationCount()));
return true;
}

return false;
}

Expand Down
47 changes: 13 additions & 34 deletions rir/src/recording.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,29 +25,6 @@
namespace rir {
namespace recording {

SEXP InvocationCountTimeReason::toSEXP() const {
auto vec = PROTECT(this->CompileReasonImpl::toSEXP());

size_t i = 0;
SET_VECTOR_ELT(vec, i++, serialization::to_sexp(count));
SET_VECTOR_ELT(vec, i++, serialization::to_sexp(minimalCount));
SET_VECTOR_ELT(vec, i++, serialization::to_sexp(time));
SET_VECTOR_ELT(vec, i++, serialization::to_sexp(minimalTime));

UNPROTECT(1);
return vec;
}

void InvocationCountTimeReason::fromSEXP(SEXP sexp){
this->CompileReasonImpl::fromSEXP(sexp);

size_t i = 0;
this->count = serialization::uint64_t_from_sexp(VECTOR_ELT(sexp, i++));
this->minimalCount = serialization::uint64_t_from_sexp(VECTOR_ELT(sexp, i++));
this->time = serialization::uint64_t_from_sexp(VECTOR_ELT(sexp, i++));
this->minimalTime = serialization::uint64_t_from_sexp(VECTOR_ELT(sexp, i++));
}

SEXP PirWarmupReason::toSEXP() const {
auto vec = PROTECT(this->CompileReasonImpl::toSEXP());

Expand All @@ -57,10 +34,11 @@ SEXP PirWarmupReason::toSEXP() const {
return vec;
}

void PirWarmupReason::fromSEXP(SEXP sexp){
void PirWarmupReason::fromSEXP(SEXP sexp) {
this->CompileReasonImpl::fromSEXP(sexp);

this->invocationCount = serialization::uint64_t_from_sexp(VECTOR_ELT(sexp, 0));
this->invocationCount =
serialization::uint64_t_from_sexp(VECTOR_ELT(sexp, 0));
}

SEXP OSRLoopReason::toSEXP() const {
Expand All @@ -72,7 +50,7 @@ SEXP OSRLoopReason::toSEXP() const {
return vec;
}

void OSRLoopReason::fromSEXP(SEXP sexp){
void OSRLoopReason::fromSEXP(SEXP sexp) {
this->CompileReasonImpl::fromSEXP(sexp);

this->loopCount = serialization::uint64_t_from_sexp(VECTOR_ELT(sexp, 0));
Expand Down Expand Up @@ -155,8 +133,8 @@ void Record::recordSpeculativeContext(const Code* code,
}
}

std::pair<size_t, FunRecording&> Record::initOrGetRecording(const SEXP cls,
const std::string& name) {
std::pair<size_t, FunRecording&>
Record::initOrGetRecording(const SEXP cls, const std::string& name) {
assert(Rf_isFunction(cls));
auto& body = *BODY(cls);

Expand Down Expand Up @@ -329,11 +307,13 @@ std::ostream& operator<<(std::ostream& out, const FunRecording& that) {
return out;
}

const char* ClosureEvent::targetName(const std::vector<FunRecording>& mapping) const {
const char*
ClosureEvent::targetName(const std::vector<FunRecording>& mapping) const {
return mapping[closureIndex].name.c_str();
}

const char* DtEvent::targetName(const std::vector<FunRecording>& mapping) const {
const char*
DtEvent::targetName(const std::vector<FunRecording>& mapping) const {
return mapping[dispatchTableIndex].name.c_str();
}

Expand Down Expand Up @@ -455,19 +435,19 @@ void CompilationEvent::print(const std::vector<FunRecording>& mapping,
out << "\n";
}
out << " ],\n opt_reasons=[\n";
if(this->compile_reasons.heuristic){
if (this->compile_reasons.heuristic) {
out << " heuristic=";
this->compile_reasons.heuristic->print(out);
out << "\n";
}

if(this->compile_reasons.condition){
if (this->compile_reasons.condition) {
out << " condition=";
this->compile_reasons.condition->print(out);
out << "\n";
}

if(this->compile_reasons.osr){
if (this->compile_reasons.osr) {
out << " osr_reason=";
this->compile_reasons.osr->print(out);
out << "\n";
Expand Down Expand Up @@ -681,7 +661,6 @@ void InvocationEvent::print(const std::vector<FunRecording>& mapping,
out << " }";
}


std::string getEnvironmentName(SEXP env) {
if (env == R_GlobalEnv) {
return GLOBAL_ENV_NAME;
Expand Down
28 changes: 0 additions & 28 deletions rir/src/recording.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,34 +99,6 @@ struct MarkOptReason : public CompileReasonImpl<MarkOptReason, 0> {
virtual ~MarkOptReason() = default;
};

struct InvocationCountTimeReason
: public CompileReasonImpl<InvocationCountTimeReason, 4> {
static constexpr const char* NAME = "InvocationCountTime";
virtual ~InvocationCountTimeReason() = default;

InvocationCountTimeReason(size_t count, size_t minimalCount,
unsigned long time, unsigned long minimalTime)
: count(count), minimalCount(minimalCount), time(time),
minimalTime(minimalTime) {}

InvocationCountTimeReason() {}

size_t count = 0;
size_t minimalCount = 0;
unsigned long time = 0;
unsigned long minimalTime = 0;

virtual SEXP toSEXP() const override;
virtual void fromSEXP(SEXP sexp) override;

virtual void print(std::ostream& out) const override {
this->CompileReasonImpl::print(out);

out << ", count=" << count << ", minimalCount=" << minimalCount
<< ", time=" << time << ", minimalTime=" << minimalTime;
}
};

struct PirWarmupReason : public CompileReasonImpl<PirWarmupReason, 1> {
static constexpr const char* NAME = "PirWarmupReason";
virtual ~PirWarmupReason() = default;
Expand Down
Loading