Skip to content

Commit

Permalink
small optimization to subtyping
Browse files Browse the repository at this point in the history
Zero and copy only the used portion of the union state buffer.
  • Loading branch information
JeffBezanson committed Aug 11, 2021
1 parent 34dc044 commit a9ac293
Showing 1 changed file with 65 additions and 36 deletions.
101 changes: 65 additions & 36 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,19 @@ extern "C" {
// TODO: the stack probably needs to be artificially large because of some
// deeper problem (see #21191) and could be shrunk once that is fixed
typedef struct {
int depth;
int more;
int16_t depth;
int16_t more;
int16_t used;
uint32_t stack[100]; // stack of bits represented as a bit vector
} jl_unionstate_t;

typedef struct {
int16_t depth;
int16_t more;
int16_t used;
void *stack;
} jl_saved_unionstate_t;

// Linked list storing the type variable environment. A new jl_varbinding_t
// is pushed for each UnionAll type we encounter. `lb` and `ub` are updated
// during the computation.
Expand All @@ -68,14 +76,14 @@ typedef struct jl_varbinding_t {
// and we would need to return `intersect(var,other)`. in this case
// we choose to over-estimate the intersection by returning the var.
int8_t constraintkind;
int depth0; // # of invariant constructors nested around the UnionAll type for this var
int8_t intvalued; // must be integer-valued; i.e. occurs as N in Vararg{_,N}
int16_t depth0; // # of invariant constructors nested around the UnionAll type for this var
// when this variable's integer value is compared to that of another,
// it equals `other + offset`. used by vararg length parameters.
int offset;
int16_t offset;
// array of typevars that our bounds depend on, whose UnionAlls need to be
// moved outside ours.
jl_array_t *innervars;
int intvalued; // must be integer-valued; i.e. occurs as N in Vararg{_,N}
struct jl_varbinding_t *prev;
} jl_varbinding_t;

Expand Down Expand Up @@ -129,6 +137,23 @@ static void statestack_set(jl_unionstate_t *st, int i, int val) JL_NOTSAFEPOINT
st->stack[i>>5] &= ~(1u<<(i&31));
}

#define push_unionstate(saved, src) \
do { \
(saved)->depth = (src)->depth; \
(saved)->more = (src)->more; \
(saved)->used = (src)->used; \
(saved)->stack = alloca(((src)->used+7)/8); \
memcpy((saved)->stack, &(src)->stack, ((src)->used+7)/8); \
} while (0);

#define pop_unionstate(dst, saved) \
do { \
(dst)->depth = (saved)->depth; \
(dst)->more = (saved)->more; \
(dst)->used = (saved)->used; \
memcpy(&(dst)->stack, (saved)->stack, ((saved)->used+7)/8); \
} while (0);

typedef struct {
int8_t *buf;
int rdepth;
Expand Down Expand Up @@ -486,6 +511,10 @@ static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv
{
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
do {
if (state->depth >= state->used) {
statestack_set(state, state->used, 0);
state->used++;
}
int ui = statestack_get(state, state->depth);
state->depth++;
if (ui == 0) {
Expand Down Expand Up @@ -514,20 +543,19 @@ static int subtype_ccheck(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
return 1;
if (x == (jl_value_t*)jl_any_type && jl_is_datatype(y))
return 0;
jl_unionstate_t oldLunions = e->Lunions;
jl_unionstate_t oldRunions = e->Runions;
jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions);
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
int sub;
memset(e->Lunions.stack, 0, sizeof(e->Lunions.stack));
memset(e->Runions.stack, 0, sizeof(e->Runions.stack));
e->Lunions.used = e->Runions.used = 0;
e->Runions.depth = 0;
e->Runions.more = 0;
e->Lunions.depth = 0;
e->Lunions.more = 0;

sub = forall_exists_subtype(x, y, e, 0);

e->Runions = oldRunions;
e->Lunions = oldLunions;
pop_unionstate(&e->Runions, &oldRunions);
pop_unionstate(&e->Lunions, &oldLunions);
return sub;
}

Expand Down Expand Up @@ -731,8 +759,8 @@ static jl_unionall_t *unalias_unionall(jl_unionall_t *u, jl_stenv_t *e)
static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8_t R, int param)
{
u = unalias_unionall(u, e);
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, 0, e->vars };
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars };
JL_GC_PUSH4(&u, &vb.lb, &vb.ub, &vb.innervars);
e->vars = &vb;
int ans;
Expand Down Expand Up @@ -1148,6 +1176,10 @@ static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
// union against the variable before trying to take it apart to see if there are any
// variables lurking inside.
jl_unionstate_t *state = &e->Runions;
if (state->depth >= state->used) {
statestack_set(state, state->used, 0);
state->used++;
}
ui = statestack_get(state, state->depth);
state->depth++;
if (ui == 0)
Expand Down Expand Up @@ -1310,21 +1342,21 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
(is_definite_length_tuple_type(x) && is_indefinite_length_tuple_type(y)))
return 0;

jl_unionstate_t oldLunions = e->Lunions;
memset(e->Lunions.stack, 0, sizeof(e->Lunions.stack));
jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions);
e->Lunions.used = 0;
int sub;

if (!jl_has_free_typevars(x) || !jl_has_free_typevars(y)) {
jl_unionstate_t oldRunions = e->Runions;
memset(e->Runions.stack, 0, sizeof(e->Runions.stack));
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
e->Runions.used = 0;
e->Runions.depth = 0;
e->Runions.more = 0;
e->Lunions.depth = 0;
e->Lunions.more = 0;

sub = forall_exists_subtype(x, y, e, 2);

e->Runions = oldRunions;
pop_unionstate(&e->Runions, &oldRunions);
}
else {
int lastset = 0;
Expand All @@ -1342,13 +1374,13 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
}
}

e->Lunions = oldLunions;
pop_unionstate(&e->Lunions, &oldLunions);
return sub && subtype(y, x, e, 0);
}

static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_value_t *saved, jl_savedenv_t *se, int param)
{
memset(e->Runions.stack, 0, sizeof(e->Runions.stack));
e->Runions.used = 0;
int lastset = 0;
while (1) {
e->Runions.depth = 0;
Expand Down Expand Up @@ -1379,7 +1411,7 @@ static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, in
JL_GC_PUSH1(&saved);
save_env(e, &saved, &se);

memset(e->Lunions.stack, 0, sizeof(e->Lunions.stack));
e->Lunions.used = 0;
int lastset = 0;
int sub;
while (1) {
Expand Down Expand Up @@ -1415,6 +1447,7 @@ static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz)
e->emptiness_only = 0;
e->Lunions.depth = 0; e->Runions.depth = 0;
e->Lunions.more = 0; e->Runions.more = 0;
e->Lunions.used = 0; e->Runions.used = 0;
}

// subtyping entry points
Expand Down Expand Up @@ -2084,14 +2117,14 @@ static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e,
if (y == (jl_value_t*)jl_any_type && !jl_is_typevar(x))
return x;

jl_unionstate_t oldRunions = e->Runions;
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
int savedepth = e->invdepth, Rsavedepth = e->Rinvdepth;
// TODO: this doesn't quite make sense
e->invdepth = e->Rinvdepth = d;

jl_value_t *res = intersect_all(x, y, e);

e->Runions = oldRunions;
pop_unionstate(&e->Runions, &oldRunions);
e->invdepth = savedepth;
e->Rinvdepth = Rsavedepth;
return res;
Expand All @@ -2102,10 +2135,10 @@ static jl_value_t *intersect_union(jl_value_t *x, jl_uniontype_t *u, jl_stenv_t
if (param == 2 || (!jl_has_free_typevars(x) && !jl_has_free_typevars((jl_value_t*)u))) {
jl_value_t *a=NULL, *b=NULL;
JL_GC_PUSH2(&a, &b);
jl_unionstate_t oldRunions = e->Runions;
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
a = R ? intersect_all(x, u->a, e) : intersect_all(u->a, x, e);
b = R ? intersect_all(x, u->b, e) : intersect_all(u->b, x, e);
e->Runions = oldRunions;
pop_unionstate(&e->Runions, &oldRunions);
jl_value_t *i = simple_join(a,b);
JL_GC_POP();
return i;
Expand Down Expand Up @@ -2600,8 +2633,8 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
{
jl_value_t *res=NULL, *res2=NULL, *save=NULL, *save2=NULL;
jl_savedenv_t se, se2;
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, 0, e->vars };
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars };
JL_GC_PUSH6(&res, &save2, &vb.lb, &vb.ub, &save, &vb.innervars);
save_env(e, &save, &se);
res = intersect_unionall_(t, u, e, R, param, &vb);
Expand Down Expand Up @@ -3159,7 +3192,7 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
{
e->Runions.depth = 0;
e->Runions.more = 0;
memset(e->Runions.stack, 0, sizeof(e->Runions.stack));
e->Runions.used = 0;
jl_value_t **is;
JL_GC_PUSHARGS(is, 3);
jl_value_t **saved = &is[2];
Expand All @@ -3176,11 +3209,8 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
save_env(e, saved, &se);
}
while (e->Runions.more) {
if (e->emptiness_only && ii != jl_bottom_type) {
free_env(&se);
JL_GC_POP();
return ii;
}
if (e->emptiness_only && ii != jl_bottom_type)
break;
e->Runions.depth = 0;
int set = e->Runions.more - 1;
e->Runions.more = 0;
Expand Down Expand Up @@ -3209,9 +3239,8 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
}
total_iter++;
if (niter > 3 || total_iter > 400000) {
free_env(&se);
JL_GC_POP();
return y;
ii = y;
break;
}
}
free_env(&se);
Expand Down

0 comments on commit a9ac293

Please sign in to comment.