Skip to content

Commit 19fceae

Browse files
committed
Refactor abstraction typing to prevent bugs related to shifting recursive type variables
1 parent ba90e7b commit 19fceae

File tree

3 files changed

+75
-83
lines changed

3 files changed

+75
-83
lines changed

src/structured/termOperations/eval.ml

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,21 @@ and eval_rec (term : term) (env : environment) : value =
3131
| VUnivQuantifier _ ->
3232
raise (Invalid_argument "Cannot apply a universal quantifier"))
3333
(* Perform implicit substitution with universal quantifier application for simplicity *)
34-
| UnivApplication (app_term, app_type) ->
35-
match eval_rec app_term env with
36-
| VUnivQuantifier inner_term ->
37-
let substituted_inner_term = substitute_univ_var_term app_type inner_term in
38-
eval_rec substituted_inner_term env
39-
| VConst _ -> raise (Invalid_argument "Cannot perform universal application on a constant")
40-
| Closure _ -> raise (Invalid_argument "Cannot perform universal application on a abstraction")
34+
| UnivApplication (app_term, app_type) -> (
35+
match eval_rec app_term env with
36+
| VUnivQuantifier inner_term ->
37+
let substituted_inner_term =
38+
substitute_univ_var_term app_type inner_term
39+
in
40+
eval_rec substituted_inner_term env
41+
| VConst _ ->
42+
raise
43+
(Invalid_argument
44+
"Cannot perform universal application on a constant")
45+
| Closure _ ->
46+
raise
47+
(Invalid_argument
48+
"Cannot perform universal application on a abstraction"))
4149

4250
(* Determines the branch of the abstraction to execute, based on the type of the argument *)
4351
and resolve_branch branches argument =

src/structured/termOperations/typing.ml

Lines changed: 40 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,22 @@ module TypeContextMap = Map.Make (struct
1515
let compare = compare
1616
end)
1717

18-
type type_context_map = union_type TypeContextMap.t
18+
type type_context_map = structured_type TypeContextMap.t
1919

2020
(** [type_lambda_term term] determines the type of a term, if it is well-typed *)
21-
let rec get_type (term : term) = get_type_rec term TypeContextMap.empty (-1) []
21+
let rec get_type (term : term) = get_type_rec term TypeContextMap.empty (-1)
2222

2323
(* TODO: is this type context just an environment? Could I simplify by prepending to the front of the list to avoid the level? *)
24-
and get_type_rec (term : term) (type_context : type_context_map) (level : int)
25-
(recursive_context : recursive_context) : structured_type option =
24+
and get_type_rec (term : term) (type_context : type_context_map) (level : int) :
25+
structured_type option =
2626
match term with
2727
(* Constants always have label types *)
28-
| Const name -> Some (build_structured_type [ Label name ] recursive_context)
28+
(* TODO: use the base_to_structured function for this. Need to move it into TypeOperation, I think *)
29+
| Const name -> Some (build_structured_type [ Label name ] [])
2930
(* Use the helper function to determine if an application is well-typed *)
3031
| Application (t1, t2) ->
31-
let left_type = get_type_rec t1 type_context level recursive_context in
32-
let right_type = get_type_rec t2 type_context level recursive_context in
32+
let left_type = get_type_rec t1 type_context level in
33+
let right_type = get_type_rec t2 type_context level in
3334
flat_map_opt2 get_application_type left_type right_type
3435
(* Abstractions are well-typed if their argument types don't match
3536
The return types of the body can be inferred recursively from the argument type *)
@@ -45,78 +46,44 @@ and get_type_rec (term : term) (type_context : type_context_map) (level : int)
4546
in
4647
if not disjoint_args then None
4748
else
48-
(* TODO: should we fold right instead? The direction shouldn't matter and we need to append elements to the end of the list *)
49-
(* The approach here is to always append to the end of the recursive context. So we add the argument's recursive context to our
50-
current context, pass that down recursively, and whatever comes back should have just appended to the recursive context, so we
51-
can use that context. We fold over all the branches to accumulate a single recursive context and intersection type *)
52-
let intersection_type_opt =
53-
List.fold_left
54-
(fun acc (arg_branch_type, branch_body) ->
55-
match acc with
56-
| None -> None
57-
| Some (acc_union_type, acc_recursive_context) ->
58-
(* TODO: investigate replacing this call with a call to get_unified_type_context *)
59-
let new_arg_type =
60-
get_type_in_context arg_branch_type acc_recursive_context
61-
in
62-
let body_type_opt =
63-
get_type_rec branch_body
64-
(TypeContextMap.add (level + 1) new_arg_type.union
65-
type_context)
66-
(level + 1) new_arg_type.context
67-
in
68-
Option.map
69-
(fun body_type ->
70-
( acc_union_type
71-
@ [ (new_arg_type.union, body_type.union) ],
72-
body_type.context ))
73-
body_type_opt)
74-
(Some ([], recursive_context))
49+
let arg_types = List.map (fun (arg_type, _) -> arg_type) definitions in
50+
let body_opt_types =
51+
List.map
52+
(fun (arg_type, body) ->
53+
let new_type_context =
54+
TypeContextMap.add (level + 1) arg_type type_context
55+
in
56+
get_type_rec body new_type_context (level + 1))
7557
definitions
7658
in
59+
let body_types_opt = opt_list_to_list_opt body_opt_types in
7760
Option.map
78-
(fun (intersection_type, recursive_context) ->
79-
build_structured_type
80-
[ Intersection intersection_type ]
81-
recursive_context)
82-
intersection_type_opt
83-
(* The type of a variable is variable is based on the type of the argument in the abstraction defining it *)
84-
| Variable var_num ->
85-
let union_type_opt =
86-
TypeContextMap.find_opt (level - var_num) type_context
87-
in
88-
Option.map
89-
(fun union_type -> build_structured_type union_type recursive_context)
90-
union_type_opt
61+
(fun body_types -> unify_function_types arg_types body_types)
62+
body_types_opt
63+
(* The type of a variable is based on the type of the argument in the abstraction defining it *)
64+
| Variable var_num -> TypeContextMap.find_opt (level - var_num) type_context
9165
(* Determine the type within the quantifier, then merge the recursive contexts and build the appropriate union type *)
9266
| UnivQuantifier inner_term ->
93-
let inner_type_opt =
94-
get_type_rec inner_term type_context level recursive_context
95-
in
67+
let inner_type_opt = get_type_rec inner_term type_context level in
9668
Option.map
9769
(fun inner_type ->
98-
let recontextualized_inner =
99-
(* TODO: investigate replacing this with a call to get_unified_type_context_pair *)
100-
get_type_in_context inner_type recursive_context
101-
in
10270
build_structured_type
103-
[ UnivQuantification recontextualized_inner.union ]
104-
recontextualized_inner.context)
71+
[ UnivQuantification inner_type.union ]
72+
inner_type.context)
10573
inner_type_opt
10674
| UnivApplication (inner_term, inner_type) ->
107-
let inner_term_type_opt =
108-
get_type_rec inner_term type_context level recursive_context
109-
in
110-
Option.join (Option.map
111-
(fun inner_term_type ->
112-
get_univ_application_type inner_term_type inner_type)
113-
inner_term_type_opt)
75+
let inner_term_type_opt = get_type_rec inner_term type_context level in
76+
Option.join
77+
(Option.map
78+
(fun inner_term_type ->
79+
get_univ_application_type inner_term_type inner_type)
80+
inner_term_type_opt)
11481

11582
(** [get_application_type func_type arg_type] determines the resulting type of
11683
applying a term of type [arg_type] to a term of type [func_type], if
11784
the function can be applied to the argument *)
118-
and get_application_type (func : structured_type)
119-
(arg : structured_type) : structured_type option =
85+
and get_application_type (func : structured_type) (arg : structured_type) :
86+
structured_type option =
12087
(* Flatten the func type so only labels and intersection types remain *)
12188
let func_flat = flatten_union func.union func.context in
12289
(* The argument should be applicable to any function in the union, so acquire the type of applying the arg to each option *)
@@ -160,25 +127,23 @@ and get_application_option_type
160127
[] functions)
161128

162129
and get_univ_application_type (quantifier : structured_type)
163-
(type_arg : structured_type): structured_type option =
130+
(type_arg : structured_type) : structured_type option =
164131
(* Flatten the func type to get rid of recursive types *)
165132
let quantifier_flat = flatten_union quantifier.union quantifier.context in
166133
(* The type argument is applicable to any universal quantification in the union, so determine the types resulting
167134
from applying the type argument to each universal quantification in the union *)
168135
let return_opt_types =
169136
List.map
170137
(fun quant_option ->
171-
get_univ_application_option_type (quant_option, quantifier.context) type_arg)
138+
get_univ_application_option_type
139+
(quant_option, quantifier.context)
140+
type_arg)
172141
quantifier_flat
173142
in
174143
(* Aggregate the return types - if any of them were none, the application is not well-typed *)
175144
let return_types_opt = opt_list_to_list_opt return_opt_types in
176145
(* Combine all of the structured types, merging both the unions and and contexts *)
177-
Option.map (
178-
fun return_types -> (
179-
get_type_union return_types
180-
)
181-
) return_types_opt
146+
Option.map (fun return_types -> get_type_union return_types) return_types_opt
182147

183148
and get_univ_application_option_type
184149
((func_option, context1) : flat_base_type * recursive_context)
@@ -191,6 +156,6 @@ and get_univ_application_option_type
191156
(* If we had bounded quantification, we'd need to make sure the type argument provided is valid *)
192157
(* But for now, we just substitution in the inner type. The function handles shifting for us *)
193158
| FUnivQuantification inner_union_type ->
194-
(* Construct the complete inner type using the context *)
195-
let inner_type = build_structured_type inner_union_type context1 in
196-
Some (substitute_univ_var_type type_arg inner_type)
159+
(* Construct the complete inner type using the context *)
160+
let inner_type = build_structured_type inner_union_type context1 in
161+
Some (substitute_univ_var_type type_arg inner_type)

src/structured/typeOperations/context.ml

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,29 @@ let get_unified_type_context (types : structured_type list) =
5656
List.fold_left
5757
(fun (acc_union, acc_context) next_type ->
5858
let recontextualized_next_union = get_type_in_context next_type acc_context in
59-
let new_acc_union = recontextualized_next_union :: acc_union in
59+
let new_acc_union = recontextualized_next_union.union :: acc_union in
6060
let new_acc_context = recontextualized_next_union.context in
6161
(new_acc_union, new_acc_context))
6262
([], []) types in
6363
(* We must reverse the list of unions since we fold left but want to keep the types in the right order *)
6464
let new_unions = List.rev new_unions_rev in
6565
new_unions, new_context
66+
67+
(* TODO: consider writing more dedicated logic for this rather than the showving intermediate into intersection *)
68+
(* Takes a list of arg types and their corresponding body types, and joined them into
69+
a single structured type for the intersection of the functions *)
70+
let unify_function_types (arg_types: structured_type list) (body_types: structured_type list) =
71+
(* First, build individual unary function types for each arg/body pair *)
72+
let unary_types = List.map2 (fun arg_type body_type ->
73+
let (new_arg_type, new_body_type), new_context = get_unified_type_context_pair arg_type body_type in
74+
build_structured_type [ Intersection [(new_arg_type, new_body_type )]] new_context
75+
) arg_types body_types in
76+
(* Then, rectonextualize all of them so we can prepare to join them into a single type *)
77+
let new_unary_unions, new_context = get_unified_type_context unary_types in
78+
(* Then destructure all of the unary types to build a single intersection type *)
79+
let unary_list = List.fold_left (fun acc_func_types next_union ->
80+
match next_union with
81+
| [ Intersection [ next_unary ]] -> next_unary::acc_func_types
82+
| _ -> raise (Failure "there was a problem destructuring the unary function types")
83+
) [] new_unary_unions in
84+
build_structured_type [ Intersection unary_list ] new_context

0 commit comments

Comments
 (0)