Skip to content

Commit eff4dbb

Browse files
committed
Correct universal variable shifting and substitution for nested quantifiers in recursive contexts
1 parent 62a143c commit eff4dbb

File tree

4 files changed

+410
-78
lines changed

4 files changed

+410
-78
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
open Metatypes
2+
3+
type shift_directive = { shift_amount : int; cutoff : int }
4+
5+
module IntMap = Map.Make (struct
6+
type t = int
7+
8+
let compare = compare
9+
end)
10+
11+
type context_shifts = shift_directive IntMap.t
12+
13+
(* Determines the shift directives for the recursive context, based on shifting the type as specified *)
14+
let rec get_context_shifts (directive : shift_directive) (union : union_type)
15+
(context : recursive_context) =
16+
(* Determine the initial recursive variables to shift *)
17+
let initial_shifts = get_context_shifts_union directive union in
18+
(* Determine the remaining shifts that occur by shifting the recursive variable definitions *)
19+
let final_shifts = get_context_shifts_context initial_shifts context in
20+
final_shifts
21+
22+
and get_context_shifts_union (directive : shift_directive) (union : union_type)
23+
=
24+
map_context_shifts get_context_shifts_base directive union
25+
26+
and get_context_shifts_base (directive : shift_directive)
27+
(base_type : base_type) =
28+
match base_type with
29+
(* Recursive type variables indicate that their definitions must be shifted *)
30+
| RecTypeVar num -> IntMap.singleton num directive
31+
(* Labels and universal type variables don't indicate any shifts need to be made *)
32+
| Label _ | UnivTypeVar _ -> IntMap.empty
33+
(* Intersections may have shifts internally *)
34+
| Intersection branches ->
35+
map_context_shifts get_context_shifts_func directive branches
36+
(* When we cross through a quantifier, we increment the cutoff since the index
37+
of free variables increases by one as we pass through the quantifier *)
38+
| UnivQuantification inner_type ->
39+
get_context_shifts_union
40+
{ directive with cutoff = directive.cutoff + 1 }
41+
inner_type
42+
43+
and get_context_shifts_func (directive : shift_directive)
44+
((arg, return) : unary_function) =
45+
let arg_shifts = get_context_shifts_union directive arg in
46+
let return_shifts = get_context_shifts_union directive return in
47+
join_context_shift_binary arg_shifts return_shifts
48+
49+
(* Determines the context shifts that should occur as a result of of the initial shifts, including the initial shifts *)
50+
and get_context_shifts_context (initial_shifts : context_shifts)
51+
(context : recursive_context) =
52+
get_context_shifts_context_rec initial_shifts initial_shifts context
53+
54+
(* Determines the context shifts that come from a set of new shifts *)
55+
and get_context_shifts_context_rec (acc_shifts : context_shifts)
56+
(new_shifts : context_shifts) (context : recursive_context) : context_shifts
57+
=
58+
(* Base case: when we no longer have new shifts to perform, we use the shifts we've accumulated so far *)
59+
if IntMap.is_empty new_shifts then acc_shifts
60+
else
61+
(* Determine the shift for each recursive defintion that has a directive *)
62+
let resulting_shifts =
63+
join_context_shifts
64+
(List.map
65+
(fun (num, directive) ->
66+
get_context_shifts_rec_def (List.nth context num) directive)
67+
(IntMap.to_list new_shifts))
68+
in
69+
(* Determine the new shifts that weren't already part of the accumulated shifts *)
70+
let updated_new_shifts = resolve_new_shifts acc_shifts resulting_shifts in
71+
(* Determine the new set of accumulated shifts *)
72+
let updated_acc_shifts =
73+
join_context_shift_binary acc_shifts updated_new_shifts
74+
in
75+
(* Recursively call to get the rest of the shifts taht results from the new shifts *)
76+
get_context_shifts_context_rec updated_acc_shifts updated_new_shifts context
77+
78+
and get_context_shifts_rec_def ({ flat_union; _ } : recursive_def)
79+
(directive : shift_directive) =
80+
get_context_shifts_flat_union directive flat_union
81+
82+
and get_context_shifts_flat_union (directive : shift_directive)
83+
(flat_union : flat_union_type) =
84+
map_context_shifts get_context_shifts_flat_base directive flat_union
85+
86+
and get_context_shifts_flat_base (directive : shift_directive)
87+
(flat_base : flat_base_type) =
88+
match flat_base with
89+
(* Labels and type variables do not indicate shifts need to happen *)
90+
| FLabel _ | FUnivTypeVar _ -> IntMap.empty
91+
(* Intersection shifts are determined recursively *)
92+
| FIntersection branches ->
93+
map_context_shifts get_context_shifts_func directive branches
94+
(* When we cross through a quantifier, we increment the cutoff since the index
95+
of free variables increases by one as we pass through the quantifier *)
96+
| FUnivQuantification inner_type ->
97+
get_context_shifts_union
98+
{ directive with cutoff = directive.cutoff + 1 }
99+
inner_type
100+
101+
(* Determines the shifts that are in new_shifts that aren't already in acc_shifts.
102+
Assumes that shifts to variables will be the same if repeated. Behavior is indeterminate otherwise *)
103+
and resolve_new_shifts (acc_shifts : context_shifts)
104+
(new_shifts : context_shifts) =
105+
IntMap.merge
106+
(fun _ acc_value new_value ->
107+
match (acc_value, new_value) with
108+
(* Only pull out values for keys that are in the new_shifts, but not the acc_shifts *)
109+
| None, Some _ -> new_value
110+
| _ -> None)
111+
acc_shifts new_shifts
112+
113+
(* Applies a function that maps a shift directive and a type to the context
114+
shifts for that type, across a list of types *)
115+
and map_context_shifts :
116+
'a.
117+
(shift_directive -> 'a -> context_shifts) ->
118+
shift_directive ->
119+
'a list ->
120+
context_shifts =
121+
fun func directive list ->
122+
let shifts = List.map (func directive) list in
123+
join_context_shifts shifts
124+
125+
(* Transforms a list of context shift maps into a single context shift map.
126+
Assumes maps context shifts contain identical shifts for a variable. Recursive variables
127+
that simultaneously have different shift directives have indeterminate behavior *)
128+
and join_context_shifts (context_shifts : context_shifts list) : context_shifts
129+
=
130+
List.fold_left join_context_shift_binary IntMap.empty context_shifts
131+
132+
and join_context_shift_binary (context_shift_a : context_shifts)
133+
(context_shift_b : context_shifts) : context_shifts =
134+
IntMap.merge
135+
(fun _ left_val right_val ->
136+
match (left_val, right_val) with
137+
(* NOTE: we assume that shifts in both maps are identical. Behavior is indeterminate otherwise *)
138+
| Some x, Some _ -> Some x
139+
| None, y -> y
140+
| x, None -> x)
141+
context_shift_a context_shift_b
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
open Metatypes
2+
open ShiftUnivVar
3+
4+
type substitute_directive = { variable_num : int; with_type : structured_type }
5+
6+
module IntMap = Map.Make (struct
7+
type t = int
8+
9+
let compare = compare
10+
end)
11+
12+
type context_subs = substitute_directive IntMap.t
13+
14+
(* Determines the substitution directives for the recursive context, based on
15+
the substitutions that are applied to the union type *)
16+
let rec get_context_substitutions (directive : substitute_directive)
17+
(union : union_type) (context : recursive_context) =
18+
(* Determine the initial recursive variables to substitute *)
19+
let initial_subs = get_context_subs_union directive union in
20+
(* Determine the remaining substitutions that occur by shifting the recursive
21+
variable definitions *)
22+
let final_subs = get_context_subs_context initial_subs context in
23+
final_subs
24+
25+
and get_context_subs_union (directive : substitute_directive)
26+
(union : union_type) =
27+
map_context_subs get_context_subs_base directive union
28+
29+
and get_context_subs_base (directive : substitute_directive)
30+
(base_type : base_type) =
31+
match base_type with
32+
| RecTypeVar num -> IntMap.singleton num directive
33+
| Label _ | UnivTypeVar _ -> IntMap.empty
34+
| Intersection branches ->
35+
map_context_subs get_context_subs_func directive branches
36+
| UnivQuantification inner_type ->
37+
let new_var_num = directive.variable_num + 1 in
38+
let new_with_type = shift_univ_var_type directive.with_type 1 in
39+
get_context_subs_union
40+
{ variable_num = new_var_num; with_type = new_with_type }
41+
inner_type
42+
43+
and get_context_subs_func (directive : substitute_directive)
44+
((arg, return) : unary_function) =
45+
let arg_subs = get_context_subs_union directive arg in
46+
let return_subs = get_context_subs_union directive return in
47+
join_context_sub_binary arg_subs return_subs
48+
49+
(* Determines the context substitution that should occur as a result of the
50+
initial substitutions *)
51+
and get_context_subs_context (initial_subs : context_subs)
52+
(context : recursive_context) =
53+
get_context_subs_context_rec initial_subs initial_subs context
54+
55+
and get_context_subs_context_rec (acc_subs : context_subs)
56+
(new_subs : context_subs) (context : recursive_context) : context_subs =
57+
if IntMap.is_empty new_subs then acc_subs
58+
else
59+
let resulting_subs =
60+
join_context_subs
61+
(List.map
62+
(fun (num, directive) ->
63+
get_context_subs_rec_def (List.nth context num) directive)
64+
(IntMap.to_list new_subs))
65+
in
66+
let updated_new_subs = resolve_new_subs acc_subs resulting_subs in
67+
let updated_acc_subs = join_context_sub_binary acc_subs updated_new_subs in
68+
get_context_subs_context_rec updated_acc_subs updated_new_subs context
69+
70+
and get_context_subs_rec_def ({ flat_union; _ } : recursive_def)
71+
(directive : substitute_directive) =
72+
get_context_subs_flat_union directive flat_union
73+
74+
and get_context_subs_flat_union (directive : substitute_directive)
75+
(flat_union : flat_union_type) =
76+
map_context_subs get_context_subs_flat_base directive flat_union
77+
78+
and get_context_subs_flat_base (directive : substitute_directive)
79+
(flat_base : flat_base_type) =
80+
match flat_base with
81+
| FLabel _ | FUnivTypeVar _ -> IntMap.empty
82+
| FIntersection branches ->
83+
map_context_subs get_context_subs_func directive branches
84+
| FUnivQuantification inner_type ->
85+
let new_var_num = directive.variable_num + 1 in
86+
let new_with_type = shift_univ_var_type directive.with_type 1 in
87+
get_context_subs_union
88+
{ variable_num = new_var_num; with_type = new_with_type }
89+
inner_type
90+
91+
and resolve_new_subs (acc_subs : context_subs) (new_subs : context_subs) =
92+
IntMap.merge
93+
(fun _ acc_value new_value ->
94+
match (acc_value, new_value) with None, Some _ -> new_value | _ -> None)
95+
acc_subs new_subs
96+
97+
and map_context_subs :
98+
'a.
99+
(substitute_directive -> 'a -> context_subs) ->
100+
substitute_directive ->
101+
'a list ->
102+
context_subs =
103+
fun func directive list ->
104+
let substitutions = List.map (func directive) list in
105+
join_context_subs substitutions
106+
107+
and join_context_subs (context_subs : context_subs list) : context_subs =
108+
List.fold_left join_context_sub_binary IntMap.empty context_subs
109+
110+
and join_context_sub_binary (context_sub_a : context_subs)
111+
(context_sub_b : context_subs) : context_subs =
112+
IntMap.merge
113+
(fun _ left_val right_val ->
114+
match (left_val, right_val) with
115+
| Some x, Some _ -> Some x
116+
| None, y -> y
117+
| x, None -> x)
118+
context_sub_a context_sub_b

src/structured/termOperations/shiftUnivVar.ml

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ open Metatypes
22
open TermTypes
33
open Common.Helpers
44
open TypeOperations.Create
5+
open GetContextShifts
56

67
(** Utilities for shifting universal quantification variables represented by de Bruijn indices *)
78

@@ -51,24 +52,37 @@ and shift_univ_var_term_rec (shift_amount : int) (cutoff : int) (term : term) =
5152

5253
and shift_univ_var_type_rec (shift_amount : int) (cutoff : int)
5354
(stype : structured_type) =
54-
(* Shift the recursive context and union separately, then combine *)
55-
let shifted_context =
56-
shift_univ_var_context shift_amount cutoff stype.context
57-
in
55+
(* Shift the universal type variables in the union type *)
5856
let shifted_union = shift_univ_var_union shift_amount cutoff stype.union in
57+
(* Determine the shifts that need to be made to the recursive context *)
58+
let context_shifts =
59+
get_context_shifts { shift_amount; cutoff } stype.union stype.context
60+
in
61+
(* Shift the context accordingly *)
62+
let shifted_context = shift_univ_var_context stype.context context_shifts in
5963
let shifted_type = build_structured_type shifted_union shifted_context in
6064
shifted_type
6165

62-
and shift_univ_var_context (shift_amount : int) (cutoff : int)
63-
(context : recursive_context) =
64-
List.map (shift_univ_var_context_def shift_amount cutoff) context
66+
and shift_univ_var_context (context : recursive_context)
67+
(context_shifts : context_shifts) =
68+
(* Shift each recursive definition with the appropriate directive (if any) *)
69+
(* TODO: shift based on the shift directives, rather than for each element in the context to simplify *)
70+
List.mapi
71+
(fun idx context_def ->
72+
shift_univ_var_context_def context_def
73+
(IntMap.find_opt idx context_shifts))
74+
context
6575

66-
and shift_univ_var_context_def (shift_amount : int) (cutoff : int)
67-
({ kind; flat_union } : recursive_def) =
68-
{
69-
kind;
70-
flat_union = shift_univ_var_flat_union shift_amount cutoff flat_union;
71-
}
76+
and shift_univ_var_context_def ({ kind; flat_union } : recursive_def)
77+
(shift_directive_opt : shift_directive option) =
78+
(* If the shift directive is None, then skip, otherwise shift the definition according to the directive *)
79+
if Option.is_none shift_directive_opt then { kind; flat_union }
80+
else
81+
let { shift_amount; cutoff } = Option.get shift_directive_opt in
82+
{
83+
kind;
84+
flat_union = shift_univ_var_flat_union shift_amount cutoff flat_union;
85+
}
7286

7387
and shift_univ_var_flat_union (shift_amount : int) (cutoff : int)
7488
(flat_union : flat_union_type) =
@@ -102,7 +116,7 @@ and shift_univ_var_union (shift_amount : int) (cutoff : int)
102116

103117
and shift_univ_var_base (shift_amount : int) (cutoff : int) (base : base_type) =
104118
match base with
105-
(* Labels and recursive type variables don't need to be shifted (context is shifted separately) *)
119+
(* Labels and recursive type variables don't need to be shifted (recursive shifting happens in other step) *)
106120
| Label _ | RecTypeVar _ -> base
107121
(* Intersections are shifted recursively *)
108122
| Intersection branches ->

0 commit comments

Comments
 (0)