|
| 1 | +open Metatypes |
| 2 | + |
| 3 | +type shift_directive = { shift_amount : int; cutoff : int } |
| 4 | + |
| 5 | +module IntMap = Map.Make (struct |
| 6 | + type t = int |
| 7 | + |
| 8 | + let compare = compare |
| 9 | +end) |
| 10 | + |
| 11 | +type context_shifts = shift_directive IntMap.t |
| 12 | + |
| 13 | +(* Determines the shift directives for the recursive context, based on shifting the type as specified *) |
| 14 | +let rec get_context_shifts (directive : shift_directive) (union : union_type) |
| 15 | + (context : recursive_context) = |
| 16 | + (* Determine the initial recursive variables to shift *) |
| 17 | + let initial_shifts = get_context_shifts_union directive union in |
| 18 | + (* Determine the remaining shifts that occur by shifting the recursive variable definitions *) |
| 19 | + let final_shifts = get_context_shifts_context initial_shifts context in |
| 20 | + final_shifts |
| 21 | + |
| 22 | +and get_context_shifts_union (directive : shift_directive) (union : union_type) |
| 23 | + = |
| 24 | + map_context_shifts get_context_shifts_base directive union |
| 25 | + |
| 26 | +and get_context_shifts_base (directive : shift_directive) |
| 27 | + (base_type : base_type) = |
| 28 | + match base_type with |
| 29 | + (* Recursive type variables indicate that their definitions must be shifted *) |
| 30 | + | RecTypeVar num -> IntMap.singleton num directive |
| 31 | + (* Labels and universal type variables don't indicate any shifts need to be made *) |
| 32 | + | Label _ | UnivTypeVar _ -> IntMap.empty |
| 33 | + (* Intersections may have shifts internally *) |
| 34 | + | Intersection branches -> |
| 35 | + map_context_shifts get_context_shifts_func directive branches |
| 36 | + (* When we cross through a quantifier, we increment the cutoff since the index |
| 37 | + of free variables increases by one as we pass through the quantifier *) |
| 38 | + | UnivQuantification inner_type -> |
| 39 | + get_context_shifts_union |
| 40 | + { directive with cutoff = directive.cutoff + 1 } |
| 41 | + inner_type |
| 42 | + |
| 43 | +and get_context_shifts_func (directive : shift_directive) |
| 44 | + ((arg, return) : unary_function) = |
| 45 | + let arg_shifts = get_context_shifts_union directive arg in |
| 46 | + let return_shifts = get_context_shifts_union directive return in |
| 47 | + join_context_shift_binary arg_shifts return_shifts |
| 48 | + |
| 49 | +(* Determines the context shifts that should occur as a result of of the initial shifts, including the initial shifts *) |
| 50 | +and get_context_shifts_context (initial_shifts : context_shifts) |
| 51 | + (context : recursive_context) = |
| 52 | + get_context_shifts_context_rec initial_shifts initial_shifts context |
| 53 | + |
| 54 | +(* Determines the context shifts that come from a set of new shifts *) |
| 55 | +and get_context_shifts_context_rec (acc_shifts : context_shifts) |
| 56 | + (new_shifts : context_shifts) (context : recursive_context) : context_shifts |
| 57 | + = |
| 58 | + (* Base case: when we no longer have new shifts to perform, we use the shifts we've accumulated so far *) |
| 59 | + if IntMap.is_empty new_shifts then acc_shifts |
| 60 | + else |
| 61 | + (* Determine the shift for each recursive defintion that has a directive *) |
| 62 | + let resulting_shifts = |
| 63 | + join_context_shifts |
| 64 | + (List.map |
| 65 | + (fun (num, directive) -> |
| 66 | + get_context_shifts_rec_def (List.nth context num) directive) |
| 67 | + (IntMap.to_list new_shifts)) |
| 68 | + in |
| 69 | + (* Determine the new shifts that weren't already part of the accumulated shifts *) |
| 70 | + let updated_new_shifts = resolve_new_shifts acc_shifts resulting_shifts in |
| 71 | + (* Determine the new set of accumulated shifts *) |
| 72 | + let updated_acc_shifts = |
| 73 | + join_context_shift_binary acc_shifts updated_new_shifts |
| 74 | + in |
| 75 | + (* Recursively call to get the rest of the shifts taht results from the new shifts *) |
| 76 | + get_context_shifts_context_rec updated_acc_shifts updated_new_shifts context |
| 77 | + |
| 78 | +and get_context_shifts_rec_def ({ flat_union; _ } : recursive_def) |
| 79 | + (directive : shift_directive) = |
| 80 | + get_context_shifts_flat_union directive flat_union |
| 81 | + |
| 82 | +and get_context_shifts_flat_union (directive : shift_directive) |
| 83 | + (flat_union : flat_union_type) = |
| 84 | + map_context_shifts get_context_shifts_flat_base directive flat_union |
| 85 | + |
| 86 | +and get_context_shifts_flat_base (directive : shift_directive) |
| 87 | + (flat_base : flat_base_type) = |
| 88 | + match flat_base with |
| 89 | + (* Labels and type variables do not indicate shifts need to happen *) |
| 90 | + | FLabel _ | FUnivTypeVar _ -> IntMap.empty |
| 91 | + (* Intersection shifts are determined recursively *) |
| 92 | + | FIntersection branches -> |
| 93 | + map_context_shifts get_context_shifts_func directive branches |
| 94 | + (* When we cross through a quantifier, we increment the cutoff since the index |
| 95 | + of free variables increases by one as we pass through the quantifier *) |
| 96 | + | FUnivQuantification inner_type -> |
| 97 | + get_context_shifts_union |
| 98 | + { directive with cutoff = directive.cutoff + 1 } |
| 99 | + inner_type |
| 100 | + |
| 101 | +(* Determines the shifts that are in new_shifts that aren't already in acc_shifts. |
| 102 | + Assumes that shifts to variables will be the same if repeated. Behavior is indeterminate otherwise *) |
| 103 | +and resolve_new_shifts (acc_shifts : context_shifts) |
| 104 | + (new_shifts : context_shifts) = |
| 105 | + IntMap.merge |
| 106 | + (fun _ acc_value new_value -> |
| 107 | + match (acc_value, new_value) with |
| 108 | + (* Only pull out values for keys that are in the new_shifts, but not the acc_shifts *) |
| 109 | + | None, Some _ -> new_value |
| 110 | + | _ -> None) |
| 111 | + acc_shifts new_shifts |
| 112 | + |
| 113 | +(* Applies a function that maps a shift directive and a type to the context |
| 114 | + shifts for that type, across a list of types *) |
| 115 | +and map_context_shifts : |
| 116 | + 'a. |
| 117 | + (shift_directive -> 'a -> context_shifts) -> |
| 118 | + shift_directive -> |
| 119 | + 'a list -> |
| 120 | + context_shifts = |
| 121 | + fun func directive list -> |
| 122 | + let shifts = List.map (func directive) list in |
| 123 | + join_context_shifts shifts |
| 124 | + |
| 125 | +(* Transforms a list of context shift maps into a single context shift map. |
| 126 | + Assumes maps context shifts contain identical shifts for a variable. Recursive variables |
| 127 | + that simultaneously have different shift directives have indeterminate behavior *) |
| 128 | +and join_context_shifts (context_shifts : context_shifts list) : context_shifts |
| 129 | + = |
| 130 | + List.fold_left join_context_shift_binary IntMap.empty context_shifts |
| 131 | + |
| 132 | +and join_context_shift_binary (context_shift_a : context_shifts) |
| 133 | + (context_shift_b : context_shifts) : context_shifts = |
| 134 | + IntMap.merge |
| 135 | + (fun _ left_val right_val -> |
| 136 | + match (left_val, right_val) with |
| 137 | + (* NOTE: we assume that shifts in both maps are identical. Behavior is indeterminate otherwise *) |
| 138 | + | Some x, Some _ -> Some x |
| 139 | + | None, y -> y |
| 140 | + | x, None -> x) |
| 141 | + context_shift_a context_shift_b |
0 commit comments