Skip to content

Commit aee69a4

Browse files
committed
subtype: make union stack size scalable.
1 parent 0c8c20c commit aee69a4

File tree

1 file changed

+121
-61
lines changed

1 file changed

+121
-61
lines changed

src/subtype.c

Lines changed: 121 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -39,20 +39,24 @@ extern "C" {
3939
// Union type decision points are discovered while the algorithm works.
4040
// If a new Union decision is encountered, the `more` flag is set to tell
4141
// the forall/exists loop to grow the stack.
42-
// TODO: the stack probably needs to be artificially large because of some
43-
// deeper problem (see #21191) and could be shrunk once that is fixed
42+
43+
typedef struct jl_bits_stack_t {
44+
uint32_t data[16];
45+
struct jl_bits_stack_t *next;
46+
} jl_bits_stack_t;
47+
4448
typedef struct {
4549
int16_t depth;
4650
int16_t more;
4751
int16_t used;
48-
uint32_t stack[100]; // stack of bits represented as a bit vector
52+
jl_bits_stack_t stack;
4953
} jl_unionstate_t;
5054

5155
typedef struct {
5256
int16_t depth;
5357
int16_t more;
5458
int16_t used;
55-
void *stack;
59+
uint8_t *stack;
5660
} jl_saved_unionstate_t;
5761

5862
// Linked list storing the type variable environment. A new jl_varbinding_t
@@ -131,37 +135,111 @@ static jl_varbinding_t *lookup(jl_stenv_t *e, jl_tvar_t *v) JL_GLOBALLY_ROOTED J
131135
}
132136
#endif
133137

138+
// union-stack tools
139+
134140
static int statestack_get(jl_unionstate_t *st, int i) JL_NOTSAFEPOINT
135141
{
136-
assert(i >= 0 && i < sizeof(st->stack) * 8);
142+
assert(i >= 0 && i <= 32767); // limited by the depth bit.
137143
// get the `i`th bit in an array of 32-bit words
138-
return (st->stack[i>>5] & (1u<<(i&31))) != 0;
144+
jl_bits_stack_t *stack = &st->stack;
145+
while (i >= sizeof(stack->data) * 8) {
146+
// We should have set this bit.
147+
assert(stack->next);
148+
stack = stack->next;
149+
i -= sizeof(stack->data) * 8;
150+
}
151+
return (stack->data[i>>5] & (1u<<(i&31))) != 0;
139152
}
140153

141154
static void statestack_set(jl_unionstate_t *st, int i, int val) JL_NOTSAFEPOINT
142155
{
143-
assert(i >= 0 && i < sizeof(st->stack) * 8);
156+
assert(i >= 0 && i <= 32767); // limited by the depth bit.
157+
jl_bits_stack_t *stack = &st->stack;
158+
while (i >= sizeof(stack->data) * 8) {
159+
if (__unlikely(stack->next == NULL)) {
160+
stack->next = (jl_bits_stack_t *)malloc(sizeof(jl_bits_stack_t));
161+
stack->next->next = NULL;
162+
}
163+
stack = stack->next;
164+
i -= sizeof(stack->data) * 8;
165+
}
144166
if (val)
145-
st->stack[i>>5] |= (1u<<(i&31));
167+
stack->data[i>>5] |= (1u<<(i&31));
146168
else
147-
st->stack[i>>5] &= ~(1u<<(i&31));
169+
stack->data[i>>5] &= ~(1u<<(i&31));
148170
}
149171

150-
#define push_unionstate(saved, src) \
151-
do { \
152-
(saved)->depth = (src)->depth; \
153-
(saved)->more = (src)->more; \
154-
(saved)->used = (src)->used; \
155-
(saved)->stack = alloca(((src)->used+7)/8); \
156-
memcpy((saved)->stack, &(src)->stack, ((src)->used+7)/8); \
172+
#define has_next_union_state(e, R) ((((R) ? &(e)->Runions : &(e)->Lunions)->more) != 0)
173+
174+
static int next_union_state(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
175+
{
176+
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
177+
if (state->more == 0)
178+
return 0;
179+
// reset `used` and let `pick_union_decision` clean the stack.
180+
state->used = state->more;
181+
statestack_set(state, state->used - 1, 1);
182+
return 1;
183+
}
184+
185+
static int pick_union_decision(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
186+
{
187+
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
188+
if (state->depth >= state->used) {
189+
statestack_set(state, state->used, 0);
190+
state->used++;
191+
}
192+
int ui = statestack_get(state, state->depth);
193+
state->depth++;
194+
if (ui == 0)
195+
state->more = state->depth; // memorize that this was the deepest available choice
196+
return ui;
197+
}
198+
199+
static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
200+
{
201+
do {
202+
if (pick_union_decision(e, R))
203+
u = ((jl_uniontype_t*)u)->b;
204+
else
205+
u = ((jl_uniontype_t*)u)->a;
206+
} while (jl_is_uniontype(u));
207+
return u;
208+
}
209+
210+
#define push_unionstate(saved, src) \
211+
do { \
212+
(saved)->depth = (src)->depth; \
213+
(saved)->more = (src)->more; \
214+
(saved)->used = (src)->used; \
215+
jl_bits_stack_t *srcstack = &(src)->stack; \
216+
int pushbits = ((saved)->used+7)/8; \
217+
(saved)->stack = (uint8_t *)alloca(pushbits); \
218+
for (int n = 0; n < pushbits; n += sizeof(srcstack->data)) { \
219+
assert(srcstack != NULL); \
220+
int rest = pushbits - n; \
221+
if (rest > sizeof(srcstack->data)) \
222+
rest = sizeof(srcstack->data); \
223+
memcpy(&(saved)->stack[n], &srcstack->data, rest); \
224+
srcstack = srcstack->next; \
225+
} \
157226
} while (0);
158227

159-
#define pop_unionstate(dst, saved) \
160-
do { \
161-
(dst)->depth = (saved)->depth; \
162-
(dst)->more = (saved)->more; \
163-
(dst)->used = (saved)->used; \
164-
memcpy(&(dst)->stack, (saved)->stack, ((saved)->used+7)/8); \
228+
#define pop_unionstate(dst, saved) \
229+
do { \
230+
(dst)->depth = (saved)->depth; \
231+
(dst)->more = (saved)->more; \
232+
(dst)->used = (saved)->used; \
233+
jl_bits_stack_t *dststack = &(dst)->stack; \
234+
int popbits = ((saved)->used+7)/8; \
235+
for (int n = 0; n < popbits; n += sizeof(dststack->data)) { \
236+
assert(dststack != NULL); \
237+
int rest = popbits - n; \
238+
if (rest > sizeof(dststack->data)) \
239+
rest = sizeof(dststack->data); \
240+
memcpy(&dststack->data, &(saved)->stack[n], rest); \
241+
dststack = dststack->next; \
242+
} \
165243
} while (0);
166244

167245
static int current_env_length(jl_stenv_t *e)
@@ -264,6 +342,18 @@ static void free_env(jl_savedenv_t *se) JL_NOTSAFEPOINT
264342
se->buf = NULL;
265343
}
266344

345+
static void free_stenv(jl_stenv_t *e) JL_NOTSAFEPOINT
346+
{
347+
for (int R = 0; R < 2; R++) {
348+
jl_bits_stack_t *temp = R ? e->Runions.stack.next : e->Lunions.stack.next;
349+
while (temp != NULL) {
350+
jl_bits_stack_t *next = temp->next;
351+
free(temp);
352+
temp = next;
353+
}
354+
}
355+
}
356+
267357
static void restore_env(jl_stenv_t *e, jl_savedenv_t *se, int root) JL_NOTSAFEPOINT
268358
{
269359
jl_value_t **roots = NULL;
@@ -587,44 +677,6 @@ static jl_value_t *simple_meet(jl_value_t *a, jl_value_t *b, int overesi)
587677

588678
static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param);
589679

590-
#define has_next_union_state(e, R) ((((R) ? &(e)->Runions : &(e)->Lunions)->more) != 0)
591-
592-
static int next_union_state(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
593-
{
594-
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
595-
if (state->more == 0)
596-
return 0;
597-
// reset `used` and let `pick_union_decision` clean the stack.
598-
state->used = state->more;
599-
statestack_set(state, state->used - 1, 1);
600-
return 1;
601-
}
602-
603-
static int pick_union_decision(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
604-
{
605-
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
606-
if (state->depth >= state->used) {
607-
statestack_set(state, state->used, 0);
608-
state->used++;
609-
}
610-
int ui = statestack_get(state, state->depth);
611-
state->depth++;
612-
if (ui == 0)
613-
state->more = state->depth; // memorize that this was the deepest available choice
614-
return ui;
615-
}
616-
617-
static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
618-
{
619-
do {
620-
if (pick_union_decision(e, R))
621-
u = ((jl_uniontype_t*)u)->b;
622-
else
623-
u = ((jl_uniontype_t*)u)->a;
624-
} while (jl_is_uniontype(u));
625-
return u;
626-
}
627-
628680
static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int limit_slow);
629681

630682
// subtype for variable bounds consistency check. needs its own forall/exists environment.
@@ -1728,6 +1780,8 @@ static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz)
17281780
e->Lunions.depth = 0; e->Runions.depth = 0;
17291781
e->Lunions.more = 0; e->Runions.more = 0;
17301782
e->Lunions.used = 0; e->Runions.used = 0;
1783+
e->Lunions.stack.next = NULL;
1784+
e->Runions.stack.next = NULL;
17311785
}
17321786

17331787
// subtyping entry points
@@ -2157,6 +2211,7 @@ JL_DLLEXPORT int jl_subtype_env(jl_value_t *x, jl_value_t *y, jl_value_t **env,
21572211
}
21582212
init_stenv(&e, env, envsz);
21592213
int subtype = forall_exists_subtype(x, y, &e, 0);
2214+
free_stenv(&e);
21602215
assert(obvious_subtype == 3 || obvious_subtype == subtype || jl_has_free_typevars(x) || jl_has_free_typevars(y));
21612216
#ifndef NDEBUG
21622217
if (obvious_subtype == 0 || (obvious_subtype == 1 && envsz == 0))
@@ -2249,6 +2304,7 @@ JL_DLLEXPORT int jl_types_equal(jl_value_t *a, jl_value_t *b)
22492304
{
22502305
init_stenv(&e, NULL, 0);
22512306
int subtype = forall_exists_subtype(a, b, &e, 0);
2307+
free_stenv(&e);
22522308
assert(subtype_ab == 3 || subtype_ab == subtype || jl_has_free_typevars(a) || jl_has_free_typevars(b));
22532309
#ifndef NDEBUG
22542310
if (subtype_ab != 0 && subtype_ab != 1) // ensures that running in a debugger doesn't change the result
@@ -2265,6 +2321,7 @@ JL_DLLEXPORT int jl_types_equal(jl_value_t *a, jl_value_t *b)
22652321
{
22662322
init_stenv(&e, NULL, 0);
22672323
int subtype = forall_exists_subtype(b, a, &e, 0);
2324+
free_stenv(&e);
22682325
assert(subtype_ba == 3 || subtype_ba == subtype || jl_has_free_typevars(a) || jl_has_free_typevars(b));
22692326
#ifndef NDEBUG
22702327
if (subtype_ba != 0 && subtype_ba != 1) // ensures that running in a debugger doesn't change the result
@@ -4230,7 +4287,9 @@ static jl_value_t *intersect_types(jl_value_t *x, jl_value_t *y, int emptiness_o
42304287
init_stenv(&e, NULL, 0);
42314288
e.intersection = e.ignore_free = 1;
42324289
e.emptiness_only = emptiness_only;
4233-
return intersect_all(x, y, &e);
4290+
jl_value_t *ans = intersect_all(x, y, &e);
4291+
free_stenv(&e);
4292+
return ans;
42344293
}
42354294

42364295
JL_DLLEXPORT jl_value_t *jl_intersect_types(jl_value_t *x, jl_value_t *y)
@@ -4407,6 +4466,7 @@ jl_value_t *jl_type_intersection_env_s(jl_value_t *a, jl_value_t *b, jl_svec_t *
44074466
memset(env, 0, szb*sizeof(void*));
44084467
e.envsz = szb;
44094468
*ans = intersect_all(a, b, &e);
4469+
free_stenv(&e);
44104470
if (*ans == jl_bottom_type) goto bot;
44114471
// TODO: code dealing with method signatures is not able to handle unions, so if
44124472
// `a` and `b` are both tuples, we need to be careful and may not return a union,

0 commit comments

Comments
 (0)