Skip to content

Commit 5468434

Browse files
committed
Add simple functions and tests for inductive arithmetic
1 parent f2d1963 commit 5468434

12 files changed

+342
-29
lines changed

src/structured/termOperations/substituteUnivVar.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ and substitute_univ_var_type_rec (variable_num : int)
7474
=
7575
(* Get the in_type in the context of the with_type so we can safely substitute
7676
in the with_type while any recursive type variables in the with_type still reference the same types *)
77+
(* TODO: replace this call with a call to get_unified_type_context_pair *)
7778
let recontextualized_type = get_type_in_context in_type with_type.context in
7879
(* Substitute the universal type variables in the recontextualized type context *)
7980
let new_context = substitute_univ_var_context variable_num with_type recontextualized_type.context in

src/structured/termOperations/typing.ml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ and get_type_rec (term : term) (type_context : type_context_map) (level : int)
5555
match acc with
5656
| None -> None
5757
| Some (acc_union_type, acc_recursive_context) ->
58+
(* TODO: investigate replacing this call with a call to get_unified_type_context *)
5859
let new_arg_type =
5960
get_type_in_context arg_branch_type acc_recursive_context
6061
in
@@ -95,6 +96,7 @@ and get_type_rec (term : term) (type_context : type_context_map) (level : int)
9596
Option.map
9697
(fun inner_type ->
9798
let recontextualized_inner =
99+
(* TODO: investigate replacing this with a call to get_unified_type_context_pair *)
98100
get_type_in_context inner_type recursive_context
99101
in
100102
build_structured_type

src/structured/typeOperations/context.ml

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ and shift_rec_type_vars_base (amount : int) (base_type : base_type) =
1111
| RecTypeVar n -> RecTypeVar (n + amount)
1212
| Intersection functions ->
1313
Intersection (List.map (shift_rec_type_vars_func amount) functions)
14-
| UnivQuantification t -> UnivQuantification (shift_rec_type_vars_union amount t)
14+
| UnivQuantification t ->
15+
UnivQuantification (shift_rec_type_vars_union amount t)
1516

1617
and shift_rec_type_vars_func (amount : int) ((arg, return) : unary_function) =
1718
(shift_rec_type_vars_union amount arg, shift_rec_type_vars_union amount return)
@@ -20,13 +21,17 @@ and shift_rec_type_vars_context (amount : int) (context : recursive_context) =
2021
List.map (shift_rec_type_vars_def amount) context
2122

2223
and shift_rec_type_vars_def (amount : int) (recursive_def : recursive_def) =
23-
let shifted_union = List.map (fun flat_base ->
24-
match flat_base with
25-
| FLabel _ | FUnivTypeVar _ -> flat_base
26-
| FIntersection functions ->
27-
FIntersection (List.map (shift_rec_type_vars_func amount) functions)
28-
| FUnivQuantification t -> FUnivQuantification (shift_rec_type_vars_union amount t)
29-
) recursive_def.flat_union in
24+
let shifted_union =
25+
List.map
26+
(fun flat_base ->
27+
match flat_base with
28+
| FLabel _ | FUnivTypeVar _ -> flat_base
29+
| FIntersection functions ->
30+
FIntersection (List.map (shift_rec_type_vars_func amount) functions)
31+
| FUnivQuantification t ->
32+
FUnivQuantification (shift_rec_type_vars_union amount t))
33+
recursive_def.flat_union
34+
in
3035
build_recursive_def recursive_def.kind shifted_union
3136

3237
let get_type_in_context (t : structured_type)
@@ -36,4 +41,25 @@ let get_type_in_context (t : structured_type)
3641
let new_union =
3742
shift_rec_type_vars_union (List.length recursive_context) t.union
3843
in
39-
build_structured_type new_union (recursive_context @ new_context)
44+
build_structured_type new_union (recursive_context @ new_context)
45+
46+
(* Converts a pair of structured types into a pair of union types that share a context *)
47+
let get_unified_type_context_pair (typea: structured_type) (typeb: structured_type) =
48+
let recontextualized_typeb = get_type_in_context typeb typea.context in
49+
let new_typeb = recontextualized_typeb.union in
50+
((typea.union, new_typeb), recontextualized_typeb.context)
51+
52+
(* Converts a list of structured types into a list of union types and a common context they all share,
53+
shifting recursive type variables in the union as appropriate *)
54+
let get_unified_type_context (types : structured_type list) =
55+
let new_unions_rev, new_context =
56+
List.fold_left
57+
(fun (acc_union, acc_context) next_type ->
58+
let recontextualized_next_union = get_type_in_context next_type acc_context in
59+
let new_acc_union = recontextualized_next_union :: acc_union in
60+
let new_acc_context = recontextualized_next_union.context in
61+
(new_acc_union, new_acc_context))
62+
([], []) types in
63+
(* We must reverse the list of unions since we fold left but want to keep the types in the right order *)
64+
let new_unions = List.rev new_unions_rev in
65+
new_unions, new_context

src/structured/typeOperations/union.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ let get_type_union (types : structured_type list) : structured_type =
1212
(fun (acc_union_type, acc_recursive_context) next_type ->
1313
(* Get the next type in the accumulated context to obtain the joined context
1414
and the next type in the context of that new accumulated context *)
15+
(* TODO: investigate replacing this with a call to get_unified_type_context *)
1516
let new_type = get_type_in_context next_type acc_recursive_context in
1617
(* Append this union to the accumulated union, use the new context from above *)
1718
(acc_union_type @ new_type.union, new_type.context))

src/structuredArithmetic.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ let decrement =
5757
(seven.stype, six.term);
5858
])
5959

60-
let fix_binary_num_op = fix three_bit_type.union unary_num_op.union
61-
let fix_unary_num_op = fix three_bit_type.union three_bit_type.union
60+
let fix_binary_num_op = fix three_bit_type unary_num_op
61+
let fix_unary_num_op = fix three_bit_type three_bit_type
6262

6363
let add =
6464
get_typed_term_unsafe

src/structuredHelpers.ml

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ open Structured.TermTypes
33
open Structured.TypeOperations.Create
44
open Structured.TermOperations.Typing
55
open Structured.TypeOperations.Union
6+
open TypeOperations.Context
67

78
type typed_term = { term : term; stype : structured_type }
89

@@ -46,39 +47,56 @@ let get_flat_union_type (union_types : structured_type list) : flat_union_type =
4647

4748
(* Constructs the Z-combinator for a function of a given type, a fixed-point
4849
combinator for call-by-value semantics *)
49-
let build_fix (arg_type : union_type) (return_type : union_type) =
50-
let func_type = (func_to_structured_type (arg_type, return_type)).union in
51-
let fix_context =
52-
build_recursive_context
53-
[ (Coinductive, [ FIntersection [ ([ RecTypeVar 0 ], func_type) ] ]) ]
50+
let build_fix (arg_type : structured_type) (return_type : structured_type) =
51+
(* First, construct a function type from the arg type to the return type, taking
52+
care to properly join the contexts of the two types *)
53+
let (new_arg_type, new_return_type), shared_context =
54+
get_unified_type_context_pair arg_type return_type
5455
in
56+
let func_type =
57+
build_structured_type
58+
[ Intersection [ (new_arg_type, new_return_type) ] ]
59+
shared_context
60+
in
61+
(* Next, build the recursive definition that we'll add to the end of the joined context *)
62+
let rec_var_num = List.length shared_context in
63+
let fix_rec_def =
64+
build_recursive_def Coinductive
65+
[ FIntersection [ ([ RecTypeVar rec_var_num ], func_type.union) ] ]
66+
in
67+
(* Then add that definition to the end of the context so it has the number we assigned it *)
68+
let new_shared_context = List.append shared_context [ fix_rec_def ] in
5569
let fix =
5670
get_typed_term_unsafe
5771
(Abstraction
5872
[
59-
( func_to_structured_type (func_type, func_type),
73+
( build_structured_type
74+
[ Intersection [ (func_type.union, func_type.union) ] ]
75+
new_shared_context,
6076
Application
6177
( Abstraction
6278
[
63-
( build_structured_type [ RecTypeVar 0 ] fix_context,
79+
( build_structured_type [ RecTypeVar rec_var_num ]
80+
new_shared_context,
6481
Application
6582
( Variable 1,
6683
Abstraction
6784
[
68-
( union_to_structured_type arg_type,
85+
( build_structured_type new_arg_type new_shared_context,
6986
Application
7087
( Application (Variable 1, Variable 1),
7188
Variable 0 ) );
7289
] ) );
7390
],
7491
Abstraction
7592
[
76-
( build_structured_type [ RecTypeVar 0 ] fix_context,
93+
( build_structured_type [ RecTypeVar rec_var_num ]
94+
new_shared_context,
7795
Application
7896
( Variable 1,
7997
Abstraction
8098
[
81-
( union_to_structured_type arg_type,
99+
( build_structured_type new_arg_type new_shared_context,
82100
Application
83101
( Application (Variable 1, Variable 1),
84102
Variable 0 ) );
@@ -89,6 +107,7 @@ let build_fix (arg_type : union_type) (return_type : union_type) =
89107
fix
90108

91109
(* Fixes a provided abstraction with the given arg and return type *)
92-
let fix (arg_type : union_type) (return_type : union_type) (term : term) =
110+
let fix (arg_type : structured_type) (return_type : structured_type)
111+
(term : term) =
93112
let fix_term = build_fix arg_type return_type in
94113
Application (fix_term.term, term)

src/structuredMixed.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ let is_zero =
3030
false_lambda.term );
3131
])
3232

33-
let fix_even_odd = fix is_even_odd_label.union num_to_bool.union
33+
let fix_even_odd = fix is_even_odd_label num_to_bool
3434

3535
let is_even_odd =
3636
get_typed_term_unsafe

src/structuredPoly.ml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,18 @@ let tail =
181181
Application (Variable 0, next_label.term) );
182182
(empty_list.stype, none_label.term)
183183
]))
184+
185+
(* List functions we should implement:
186+
* Length
187+
* nth
188+
* reverse
189+
* concat
190+
* append (add a single element to the end)
191+
* flatten
192+
* equal
193+
* map
194+
* filter
195+
* fold_left/fold_right
196+
* find (return element and/or index)
197+
* forall/exists
198+
*)

src/structuredRecursive.ml

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ open Structured.TermTypes
33
open Structured.TypeOperations.Create
44
open TypeOperations.Union
55
open StructuredHelpers
6+
open StructuredBool
67

78
let name = get_typed_term_unsafe (Const "Name")
89
let val_lambda = get_typed_term_unsafe (Const "Val")
@@ -192,3 +193,151 @@ let neg_infinity =
192193
[ (Coinductive, get_flat_union_type [ generate_pred_rec_step 1 ]) ])
193194

194195
let infinity = get_type_union [ pos_infinity; neg_infinity ]
196+
197+
let unary_numerical_op =
198+
build_structured_type
199+
[ Intersection [ (ind_integer.union, ind_integer.union) ] ]
200+
ind_integer.context
201+
202+
let binary_numerical_op =
203+
build_structured_type
204+
[
205+
Intersection
206+
[
207+
( ind_integer.union,
208+
[ Intersection [ (ind_integer.union, ind_integer.union) ] ] );
209+
];
210+
]
211+
ind_integer.context
212+
213+
let num_to_bool_op =
214+
build_structured_type
215+
[ Intersection [ (ind_integer.union, bool_type.union) ] ]
216+
ind_integer.context
217+
218+
let binary_num_to_bool_op =
219+
build_structured_type
220+
[
221+
Intersection
222+
[
223+
( ind_integer.union,
224+
[ Intersection [ (ind_integer.union, bool_type.union) ] ] );
225+
];
226+
]
227+
ind_integer.context
228+
229+
(* Increments an inductive number by one *)
230+
let increment =
231+
get_typed_term_unsafe
232+
(Abstraction
233+
[
234+
(ind_negative_number, Application (Variable 0, val_lambda.term));
235+
( get_type_union [ zero.stype; ind_positive_number ],
236+
Abstraction
237+
[ (name.stype, succ.term); (val_lambda.stype, Variable 1) ] );
238+
])
239+
240+
(* Decrements an inductive number by one *)
241+
let decrement =
242+
get_typed_term_unsafe
243+
(Abstraction
244+
[
245+
( get_type_union [ zero.stype; ind_negative_number ],
246+
Abstraction
247+
[ (name.stype, pred.term); (val_lambda.stype, Variable 1) ] );
248+
(ind_positive_number, Application (Variable 0, val_lambda.term));
249+
])
250+
251+
(* Determines if a value is even or odd, leveraging the subtyping system *)
252+
let is_even =
253+
get_typed_term_unsafe
254+
(Abstraction
255+
[
256+
(ind_even_integer, true_lambda.term);
257+
(ind_odd_integer, false_lambda.term);
258+
])
259+
260+
let fix_binary_num_to_bool = fix ind_integer num_to_bool_op
261+
let fix_binary_num_op = fix ind_integer unary_numerical_op
262+
263+
let is_equal =
264+
get_typed_term_unsafe
265+
(fix_binary_num_to_bool
266+
(Abstraction
267+
[
268+
( binary_num_to_bool_op,
269+
Abstraction
270+
[
271+
( zero.stype,
272+
Abstraction
273+
[
274+
( get_type_union
275+
[ ind_positive_number; ind_negative_number ],
276+
false_lambda.term );
277+
(zero.stype, true_lambda.term);
278+
] );
279+
( ind_positive_number,
280+
Abstraction
281+
[
282+
( get_type_union [ zero.stype; ind_negative_number ],
283+
false_lambda.term );
284+
( ind_positive_number,
285+
Application
286+
( Application
287+
( Variable 2,
288+
Application (decrement.term, Variable 1) ),
289+
Application (decrement.term, Variable 0) ) );
290+
] );
291+
( ind_negative_number,
292+
Abstraction
293+
[
294+
( get_type_union [ zero.stype; ind_positive_number ],
295+
false_lambda.term );
296+
( ind_negative_number,
297+
Application
298+
( Application
299+
( Variable 2,
300+
Application (increment.term, Variable 1) ),
301+
Application (increment.term, Variable 0) ) );
302+
] );
303+
] );
304+
]))
305+
306+
let add =
307+
get_typed_term_unsafe
308+
(fix_binary_num_op
309+
(Abstraction
310+
[
311+
( binary_numerical_op,
312+
Abstraction
313+
[
314+
(zero.stype, Abstraction [ (ind_integer, Variable 0) ]);
315+
( ind_negative_number,
316+
Abstraction
317+
[
318+
( ind_integer,
319+
Application
320+
( Application
321+
( Variable 2,
322+
Application (increment.term, Variable 1) ),
323+
Application (decrement.term, Variable 0) ) );
324+
] );
325+
( ind_positive_number,
326+
Abstraction
327+
[
328+
( ind_integer,
329+
Application
330+
( Application
331+
( Variable 2,
332+
Application (decrement.term, Variable 1) ),
333+
Application (increment.term, Variable 0) ) );
334+
] );
335+
] );
336+
]))
337+
338+
(* Later, consider also implementing these functions *)
339+
(* subtract *)
340+
(* fibonnaci *)
341+
(* negate *)
342+
(* multiply *)
343+
(* divide *)

src/structuredUnions.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ let decrement_three_bit =
122122
(zero.stype, seven.term);
123123
])
124124

125-
let fix_binary_num_op = fix three_bit_num.union unary_num_type.union
125+
let fix_binary_num_op = fix three_bit_num unary_num_type
126126

127127
let add_three_bit =
128128
get_typed_term_unsafe

0 commit comments

Comments
 (0)