From d71bd325e69826e70ef246db12b8bffd32de89c2 Mon Sep 17 00:00:00 2001 From: Jeff Bezanson Date: Thu, 19 Aug 2021 12:41:32 -0400 Subject: [PATCH] small optimization to subtyping (#41672) Zero and copy only the used portion of the union state buffer. (cherry picked from commit 0258553a82aba0a609978d1719e05a20ebdf4826) --- src/subtype.c | 101 ++++++++++++++++++++++++++++++++------------------ 1 file changed, 65 insertions(+), 36 deletions(-) diff --git a/src/subtype.c b/src/subtype.c index 158a9dd70b3f3..152d17daeaaaa 100644 --- a/src/subtype.c +++ b/src/subtype.c @@ -42,11 +42,19 @@ extern "C" { // TODO: the stack probably needs to be artificially large because of some // deeper problem (see #21191) and could be shrunk once that is fixed typedef struct { - int depth; - int more; + int16_t depth; + int16_t more; + int16_t used; uint32_t stack[100]; // stack of bits represented as a bit vector } jl_unionstate_t; +typedef struct { + int16_t depth; + int16_t more; + int16_t used; + void *stack; +} jl_saved_unionstate_t; + // Linked list storing the type variable environment. A new jl_varbinding_t // is pushed for each UnionAll type we encounter. `lb` and `ub` are updated // during the computation. @@ -68,14 +76,14 @@ typedef struct jl_varbinding_t { // and we would need to return `intersect(var,other)`. in this case // we choose to over-estimate the intersection by returning the var. int8_t constraintkind; - int depth0; // # of invariant constructors nested around the UnionAll type for this var + int8_t intvalued; // must be integer-valued; i.e. occurs as N in Vararg{_,N} + int16_t depth0; // # of invariant constructors nested around the UnionAll type for this var // when this variable's integer value is compared to that of another, // it equals `other + offset`. used by vararg length parameters. - int offset; + int16_t offset; // array of typevars that our bounds depend on, whose UnionAlls need to be // moved outside ours. jl_array_t *innervars; - int intvalued; // must be integer-valued; i.e. occurs as N in Vararg{_,N} struct jl_varbinding_t *prev; } jl_varbinding_t; @@ -129,6 +137,23 @@ static void statestack_set(jl_unionstate_t *st, int i, int val) JL_NOTSAFEPOINT st->stack[i>>5] &= ~(1u<<(i&31)); } +#define push_unionstate(saved, src) \ + do { \ + (saved)->depth = (src)->depth; \ + (saved)->more = (src)->more; \ + (saved)->used = (src)->used; \ + (saved)->stack = alloca(((src)->used+7)/8); \ + memcpy((saved)->stack, &(src)->stack, ((src)->used+7)/8); \ + } while (0); + +#define pop_unionstate(dst, saved) \ + do { \ + (dst)->depth = (saved)->depth; \ + (dst)->more = (saved)->more; \ + (dst)->used = (saved)->used; \ + memcpy(&(dst)->stack, (saved)->stack, ((saved)->used+7)/8); \ + } while (0); + typedef struct { int8_t *buf; int rdepth; @@ -486,6 +511,10 @@ static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv { jl_unionstate_t *state = R ? &e->Runions : &e->Lunions; do { + if (state->depth >= state->used) { + statestack_set(state, state->used, 0); + state->used++; + } int ui = statestack_get(state, state->depth); state->depth++; if (ui == 0) { @@ -514,11 +543,10 @@ static int subtype_ccheck(jl_value_t *x, jl_value_t *y, jl_stenv_t *e) return 1; if (x == (jl_value_t*)jl_any_type && jl_is_datatype(y)) return 0; - jl_unionstate_t oldLunions = e->Lunions; - jl_unionstate_t oldRunions = e->Runions; + jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions); + jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions); int sub; - memset(e->Lunions.stack, 0, sizeof(e->Lunions.stack)); - memset(e->Runions.stack, 0, sizeof(e->Runions.stack)); + e->Lunions.used = e->Runions.used = 0; e->Runions.depth = 0; e->Runions.more = 0; e->Lunions.depth = 0; @@ -526,8 +554,8 @@ static int subtype_ccheck(jl_value_t *x, jl_value_t *y, jl_stenv_t *e) sub = forall_exists_subtype(x, y, e, 0); - e->Runions = oldRunions; - e->Lunions = oldLunions; + pop_unionstate(&e->Runions, &oldRunions); + pop_unionstate(&e->Lunions, &oldLunions); return sub; } @@ -731,8 +759,8 @@ static jl_unionall_t *unalias_unionall(jl_unionall_t *u, jl_stenv_t *e) static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8_t R, int param) { u = unalias_unionall(u, e); - jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, - R ? e->Rinvdepth : e->invdepth, 0, NULL, 0, e->vars }; + jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, + R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars }; JL_GC_PUSH4(&u, &vb.lb, &vb.ub, &vb.innervars); e->vars = &vb; int ans; @@ -1148,6 +1176,10 @@ static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param) // union against the variable before trying to take it apart to see if there are any // variables lurking inside. jl_unionstate_t *state = &e->Runions; + if (state->depth >= state->used) { + statestack_set(state, state->used, 0); + state->used++; + } ui = statestack_get(state, state->depth); state->depth++; if (ui == 0) @@ -1310,13 +1342,13 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e) (is_definite_length_tuple_type(x) && is_indefinite_length_tuple_type(y))) return 0; - jl_unionstate_t oldLunions = e->Lunions; - memset(e->Lunions.stack, 0, sizeof(e->Lunions.stack)); + jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions); + e->Lunions.used = 0; int sub; if (!jl_has_free_typevars(x) || !jl_has_free_typevars(y)) { - jl_unionstate_t oldRunions = e->Runions; - memset(e->Runions.stack, 0, sizeof(e->Runions.stack)); + jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions); + e->Runions.used = 0; e->Runions.depth = 0; e->Runions.more = 0; e->Lunions.depth = 0; @@ -1324,7 +1356,7 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e) sub = forall_exists_subtype(x, y, e, 2); - e->Runions = oldRunions; + pop_unionstate(&e->Runions, &oldRunions); } else { int lastset = 0; @@ -1342,13 +1374,13 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e) } } - e->Lunions = oldLunions; + pop_unionstate(&e->Lunions, &oldLunions); return sub && subtype(y, x, e, 0); } static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_value_t *saved, jl_savedenv_t *se, int param) { - memset(e->Runions.stack, 0, sizeof(e->Runions.stack)); + e->Runions.used = 0; int lastset = 0; while (1) { e->Runions.depth = 0; @@ -1379,7 +1411,7 @@ static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, in JL_GC_PUSH1(&saved); save_env(e, &saved, &se); - memset(e->Lunions.stack, 0, sizeof(e->Lunions.stack)); + e->Lunions.used = 0; int lastset = 0; int sub; while (1) { @@ -1415,6 +1447,7 @@ static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz) e->emptiness_only = 0; e->Lunions.depth = 0; e->Runions.depth = 0; e->Lunions.more = 0; e->Runions.more = 0; + e->Lunions.used = 0; e->Runions.used = 0; } // subtyping entry points @@ -2084,14 +2117,14 @@ static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, if (y == (jl_value_t*)jl_any_type && !jl_is_typevar(x)) return x; - jl_unionstate_t oldRunions = e->Runions; + jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions); int savedepth = e->invdepth, Rsavedepth = e->Rinvdepth; // TODO: this doesn't quite make sense e->invdepth = e->Rinvdepth = d; jl_value_t *res = intersect_all(x, y, e); - e->Runions = oldRunions; + pop_unionstate(&e->Runions, &oldRunions); e->invdepth = savedepth; e->Rinvdepth = Rsavedepth; return res; @@ -2102,10 +2135,10 @@ static jl_value_t *intersect_union(jl_value_t *x, jl_uniontype_t *u, jl_stenv_t if (param == 2 || (!jl_has_free_typevars(x) && !jl_has_free_typevars((jl_value_t*)u))) { jl_value_t *a=NULL, *b=NULL; JL_GC_PUSH2(&a, &b); - jl_unionstate_t oldRunions = e->Runions; + jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions); a = R ? intersect_all(x, u->a, e) : intersect_all(u->a, x, e); b = R ? intersect_all(x, u->b, e) : intersect_all(u->b, x, e); - e->Runions = oldRunions; + pop_unionstate(&e->Runions, &oldRunions); jl_value_t *i = simple_join(a,b); JL_GC_POP(); return i; @@ -2600,8 +2633,8 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_ { jl_value_t *res=NULL, *res2=NULL, *save=NULL, *save2=NULL; jl_savedenv_t se, se2; - jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, - R ? e->Rinvdepth : e->invdepth, 0, NULL, 0, e->vars }; + jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, + R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars }; JL_GC_PUSH6(&res, &save2, &vb.lb, &vb.ub, &save, &vb.innervars); save_env(e, &save, &se); res = intersect_unionall_(t, u, e, R, param, &vb); @@ -3159,7 +3192,7 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e) { e->Runions.depth = 0; e->Runions.more = 0; - memset(e->Runions.stack, 0, sizeof(e->Runions.stack)); + e->Runions.used = 0; jl_value_t **is; JL_GC_PUSHARGS(is, 3); jl_value_t **saved = &is[2]; @@ -3176,11 +3209,8 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e) save_env(e, saved, &se); } while (e->Runions.more) { - if (e->emptiness_only && ii != jl_bottom_type) { - free_env(&se); - JL_GC_POP(); - return ii; - } + if (e->emptiness_only && ii != jl_bottom_type) + break; e->Runions.depth = 0; int set = e->Runions.more - 1; e->Runions.more = 0; @@ -3209,9 +3239,8 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e) } total_iter++; if (niter > 3 || total_iter > 400000) { - free_env(&se); - JL_GC_POP(); - return y; + ii = y; + break; } } free_env(&se);