Skip to content

Commit

Permalink
small optimization to subtyping (#41672)
Browse files Browse the repository at this point in the history
Zero and copy only the used portion of the union state buffer.

(cherry picked from commit 0258553)
  • Loading branch information
JeffBezanson authored and KristofferC committed Aug 25, 2021
1 parent 83bf082 commit d71bd32
Showing 1 changed file with 65 additions and 36 deletions.
101 changes: 65 additions & 36 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,19 @@ extern "C" {
// TODO: the stack probably needs to be artificially large because of some
// deeper problem (see #21191) and could be shrunk once that is fixed
typedef struct {
int depth;
int more;
int16_t depth;
int16_t more;
int16_t used;
uint32_t stack[100]; // stack of bits represented as a bit vector
} jl_unionstate_t;

typedef struct {
int16_t depth;
int16_t more;
int16_t used;
void *stack;
} jl_saved_unionstate_t;

// Linked list storing the type variable environment. A new jl_varbinding_t
// is pushed for each UnionAll type we encounter. `lb` and `ub` are updated
// during the computation.
Expand All @@ -68,14 +76,14 @@ typedef struct jl_varbinding_t {
// and we would need to return `intersect(var,other)`. in this case
// we choose to over-estimate the intersection by returning the var.
int8_t constraintkind;
int depth0; // # of invariant constructors nested around the UnionAll type for this var
int8_t intvalued; // must be integer-valued; i.e. occurs as N in Vararg{_,N}
int16_t depth0; // # of invariant constructors nested around the UnionAll type for this var
// when this variable's integer value is compared to that of another,
// it equals `other + offset`. used by vararg length parameters.
int offset;
int16_t offset;
// array of typevars that our bounds depend on, whose UnionAlls need to be
// moved outside ours.
jl_array_t *innervars;
int intvalued; // must be integer-valued; i.e. occurs as N in Vararg{_,N}
struct jl_varbinding_t *prev;
} jl_varbinding_t;

Expand Down Expand Up @@ -129,6 +137,23 @@ static void statestack_set(jl_unionstate_t *st, int i, int val) JL_NOTSAFEPOINT
st->stack[i>>5] &= ~(1u<<(i&31));
}

#define push_unionstate(saved, src) \
do { \
(saved)->depth = (src)->depth; \
(saved)->more = (src)->more; \
(saved)->used = (src)->used; \
(saved)->stack = alloca(((src)->used+7)/8); \
memcpy((saved)->stack, &(src)->stack, ((src)->used+7)/8); \
} while (0);

#define pop_unionstate(dst, saved) \
do { \
(dst)->depth = (saved)->depth; \
(dst)->more = (saved)->more; \
(dst)->used = (saved)->used; \
memcpy(&(dst)->stack, (saved)->stack, ((saved)->used+7)/8); \
} while (0);

typedef struct {
int8_t *buf;
int rdepth;
Expand Down Expand Up @@ -486,6 +511,10 @@ static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv
{
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
do {
if (state->depth >= state->used) {
statestack_set(state, state->used, 0);
state->used++;
}
int ui = statestack_get(state, state->depth);
state->depth++;
if (ui == 0) {
Expand Down Expand Up @@ -514,20 +543,19 @@ static int subtype_ccheck(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
return 1;
if (x == (jl_value_t*)jl_any_type && jl_is_datatype(y))
return 0;
jl_unionstate_t oldLunions = e->Lunions;
jl_unionstate_t oldRunions = e->Runions;
jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions);
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
int sub;
memset(e->Lunions.stack, 0, sizeof(e->Lunions.stack));
memset(e->Runions.stack, 0, sizeof(e->Runions.stack));
e->Lunions.used = e->Runions.used = 0;
e->Runions.depth = 0;
e->Runions.more = 0;
e->Lunions.depth = 0;
e->Lunions.more = 0;

sub = forall_exists_subtype(x, y, e, 0);

e->Runions = oldRunions;
e->Lunions = oldLunions;
pop_unionstate(&e->Runions, &oldRunions);
pop_unionstate(&e->Lunions, &oldLunions);
return sub;
}

Expand Down Expand Up @@ -731,8 +759,8 @@ static jl_unionall_t *unalias_unionall(jl_unionall_t *u, jl_stenv_t *e)
static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8_t R, int param)
{
u = unalias_unionall(u, e);
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, 0, e->vars };
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars };
JL_GC_PUSH4(&u, &vb.lb, &vb.ub, &vb.innervars);
e->vars = &vb;
int ans;
Expand Down Expand Up @@ -1148,6 +1176,10 @@ static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
// union against the variable before trying to take it apart to see if there are any
// variables lurking inside.
jl_unionstate_t *state = &e->Runions;
if (state->depth >= state->used) {
statestack_set(state, state->used, 0);
state->used++;
}
ui = statestack_get(state, state->depth);
state->depth++;
if (ui == 0)
Expand Down Expand Up @@ -1310,21 +1342,21 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
(is_definite_length_tuple_type(x) && is_indefinite_length_tuple_type(y)))
return 0;

jl_unionstate_t oldLunions = e->Lunions;
memset(e->Lunions.stack, 0, sizeof(e->Lunions.stack));
jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions);
e->Lunions.used = 0;
int sub;

if (!jl_has_free_typevars(x) || !jl_has_free_typevars(y)) {
jl_unionstate_t oldRunions = e->Runions;
memset(e->Runions.stack, 0, sizeof(e->Runions.stack));
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
e->Runions.used = 0;
e->Runions.depth = 0;
e->Runions.more = 0;
e->Lunions.depth = 0;
e->Lunions.more = 0;

sub = forall_exists_subtype(x, y, e, 2);

e->Runions = oldRunions;
pop_unionstate(&e->Runions, &oldRunions);
}
else {
int lastset = 0;
Expand All @@ -1342,13 +1374,13 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
}
}

e->Lunions = oldLunions;
pop_unionstate(&e->Lunions, &oldLunions);
return sub && subtype(y, x, e, 0);
}

static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_value_t *saved, jl_savedenv_t *se, int param)
{
memset(e->Runions.stack, 0, sizeof(e->Runions.stack));
e->Runions.used = 0;
int lastset = 0;
while (1) {
e->Runions.depth = 0;
Expand Down Expand Up @@ -1379,7 +1411,7 @@ static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, in
JL_GC_PUSH1(&saved);
save_env(e, &saved, &se);

memset(e->Lunions.stack, 0, sizeof(e->Lunions.stack));
e->Lunions.used = 0;
int lastset = 0;
int sub;
while (1) {
Expand Down Expand Up @@ -1415,6 +1447,7 @@ static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz)
e->emptiness_only = 0;
e->Lunions.depth = 0; e->Runions.depth = 0;
e->Lunions.more = 0; e->Runions.more = 0;
e->Lunions.used = 0; e->Runions.used = 0;
}

// subtyping entry points
Expand Down Expand Up @@ -2084,14 +2117,14 @@ static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e,
if (y == (jl_value_t*)jl_any_type && !jl_is_typevar(x))
return x;

jl_unionstate_t oldRunions = e->Runions;
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
int savedepth = e->invdepth, Rsavedepth = e->Rinvdepth;
// TODO: this doesn't quite make sense
e->invdepth = e->Rinvdepth = d;

jl_value_t *res = intersect_all(x, y, e);

e->Runions = oldRunions;
pop_unionstate(&e->Runions, &oldRunions);
e->invdepth = savedepth;
e->Rinvdepth = Rsavedepth;
return res;
Expand All @@ -2102,10 +2135,10 @@ static jl_value_t *intersect_union(jl_value_t *x, jl_uniontype_t *u, jl_stenv_t
if (param == 2 || (!jl_has_free_typevars(x) && !jl_has_free_typevars((jl_value_t*)u))) {
jl_value_t *a=NULL, *b=NULL;
JL_GC_PUSH2(&a, &b);
jl_unionstate_t oldRunions = e->Runions;
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
a = R ? intersect_all(x, u->a, e) : intersect_all(u->a, x, e);
b = R ? intersect_all(x, u->b, e) : intersect_all(u->b, x, e);
e->Runions = oldRunions;
pop_unionstate(&e->Runions, &oldRunions);
jl_value_t *i = simple_join(a,b);
JL_GC_POP();
return i;
Expand Down Expand Up @@ -2600,8 +2633,8 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
{
jl_value_t *res=NULL, *res2=NULL, *save=NULL, *save2=NULL;
jl_savedenv_t se, se2;
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, 0, e->vars };
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars };
JL_GC_PUSH6(&res, &save2, &vb.lb, &vb.ub, &save, &vb.innervars);
save_env(e, &save, &se);
res = intersect_unionall_(t, u, e, R, param, &vb);
Expand Down Expand Up @@ -3159,7 +3192,7 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
{
e->Runions.depth = 0;
e->Runions.more = 0;
memset(e->Runions.stack, 0, sizeof(e->Runions.stack));
e->Runions.used = 0;
jl_value_t **is;
JL_GC_PUSHARGS(is, 3);
jl_value_t **saved = &is[2];
Expand All @@ -3176,11 +3209,8 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
save_env(e, saved, &se);
}
while (e->Runions.more) {
if (e->emptiness_only && ii != jl_bottom_type) {
free_env(&se);
JL_GC_POP();
return ii;
}
if (e->emptiness_only && ii != jl_bottom_type)
break;
e->Runions.depth = 0;
int set = e->Runions.more - 1;
e->Runions.more = 0;
Expand Down Expand Up @@ -3209,9 +3239,8 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
}
total_iter++;
if (niter > 3 || total_iter > 400000) {
free_env(&se);
JL_GC_POP();
return y;
ii = y;
break;
}
}
free_env(&se);
Expand Down

0 comments on commit d71bd32

Please sign in to comment.