Skip to content

Commit

Permalink
Use a more relaxed mode for unification in Ctype.subst (#11771) (#73)
Browse files Browse the repository at this point in the history
This adds a unification mode for use when unifying type arguments with
their associated parameter. These unifications shouldn't generate any
equations and can be more relaxed with respect to
`trace_gadt_instances`.
  • Loading branch information
lpw25 authored Dec 5, 2022
1 parent cbd791a commit 2a7e501
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 63 deletions.
42 changes: 42 additions & 0 deletions testsuite/tests/typing-misc/constraints.ml
Original file line number Diff line number Diff line change
Expand Up @@ -329,3 +329,45 @@ end;;
[%%expect{|
Exception: Failure "Default_extension failure".
|}]


(* PR#11771 -- Constraints making expansion affect typeability *)
type foo = Foo
type bar = Bar

type _ tag =
| Foo_tag : foo tag
| Bar_tag : bar tag

type ('a, 'self) obj =
< foo : foo -> 'a ; bar : bar -> 'a; .. > as 'self
[%%expect {|
type foo = Foo
type bar = Bar
type _ tag = Foo_tag : foo tag | Bar_tag : bar tag
type ('a, 'self) obj = 'self
constraint 'self = < bar : bar -> 'a; foo : foo -> 'a; .. >
|}]

let test_obj_no_expansion :
type a b. a tag -> < foo : foo -> b ; bar : bar -> b; .. > -> a -> b =
fun t obj x ->
match t with
| Foo_tag -> obj#foo x
| Bar_tag -> obj#bar x
[%%expect {|
val test_obj_no_expansion :
'a tag -> < bar : bar -> 'b; foo : foo -> 'b; .. > -> 'a -> 'b = <fun>
|}]

let test_obj_with_expansion :
type a b. a tag -> (b, _) obj -> a -> b =
fun t obj x ->
match t with
| Foo_tag -> obj#foo x
| Bar_tag -> obj#bar x
[%%expect {|
val test_obj_with_expansion :
'a tag -> ('b, < bar : bar -> 'b; foo : foo -> 'b; .. >) obj -> 'a -> 'b =
<fun>
|}]
159 changes: 99 additions & 60 deletions typing/ctype.ml
Original file line number Diff line number Diff line change
Expand Up @@ -248,31 +248,81 @@ let none = newty (Ttuple []) (* Clearly ill-formed type *)

(**** unification mode ****)

type unification_mode =
| Expression (* unification in expression *)
| Pattern (* unification in pattern which may add local constraints *)

type equations_generation =
| Forbidden
| Allowed of { equated_types : TypePairs.t }

type unification_mode =
| Expression (* unification in expression *)
| Pattern of
{ equations_generation : equations_generation;
assume_injective : bool;
allow_recursive_equations : bool; }
(* unification in pattern which may add local constraints *)
| Subst
(* unification during type constructor expansion; more
relaxed than [Expression] in some cases. *)

let umode = ref Expression
let equations_generation = ref Forbidden
let assume_injective = ref false
let allow_recursive_equation = ref false

let in_pattern_mode () =
match !umode with
| Expression | Subst -> false
| Pattern _ -> true

let in_subst_mode () =
match !umode with
| Expression | Pattern _ -> false
| Subst -> true

let can_generate_equations () =
match !equations_generation with
| Forbidden -> false
| _ -> true

let set_mode_pattern ~generate ~injective ~allow_recursive f =
Misc.protect_refs
[ Misc.R (umode, Pattern);
Misc.R (equations_generation, generate);
Misc.R (assume_injective, injective);
Misc.R (allow_recursive_equation, allow_recursive);
] f
match !umode with
| Expression | Subst | Pattern { equations_generation = Forbidden } -> false
| Pattern { equations_generation = Allowed _ } -> true

(* Can only be called when generate_equations is true *)
let record_equation t1 t2 =
match !umode with
| Expression | Subst | Pattern { equations_generation = Forbidden } ->
assert false
| Pattern { equations_generation = Allowed { equated_types } } ->
TypePairs.add equated_types (t1, t2)

let can_assume_injective () =
match !umode with
| Expression | Subst -> false
| Pattern { assume_injective } -> assume_injective

let allow_recursive_equations () =
!Clflags.recursive_types
|| match !umode with
| Expression | Subst -> false
| Pattern { allow_recursive_equations } -> allow_recursive_equations

let set_mode_pattern ~allow_recursive_equations ~equated_types f =
let equations_generation = Allowed { equated_types } in
let assume_injective = true in
let new_umode =
Pattern
{ equations_generation;
assume_injective;
allow_recursive_equations }
in
Misc.protect_refs [ Misc.R (umode, new_umode) ] f

let without_assume_injective f =
match !umode with
| Expression | Subst -> f ()
| Pattern r ->
let new_umode = Pattern { r with assume_injective = false } in
Misc.protect_refs [ Misc.R (umode, new_umode) ] f

let without_generating_equations f =
match !umode with
| Expression | Subst -> f ()
| Pattern r ->
let new_umode = Pattern { r with equations_generation = Forbidden } in
Misc.protect_refs [ Misc.R (umode, new_umode) ] f

(*** Checks for type definitions ***)

Expand Down Expand Up @@ -1484,13 +1534,17 @@ let subst env level priv abbrev oty params args body =
abbreviations := abbrev;
let (params', body') = instance_parameterized_type params body in
abbreviations := ref Mnil;
let old_umode = !umode in
umode := Subst;
try
!unify_var' env body0 body';
List.iter2 (!unify_var' env) params' args;
current_level := old_level;
umode := old_umode;
body'
with Unify _ ->
current_level := old_level;
umode := old_umode;
undo_abbrev ();
raise Cannot_subst

Expand Down Expand Up @@ -1820,8 +1874,7 @@ let type_changed = ref false (* trace possible changes to the studied type *)
let merge r b = if b then r := true

let occur env ty0 ty =
let allow_recursive =
!Clflags.recursive_types || !umode = Pattern && !allow_recursive_equation in
let allow_recursive = allow_recursive_equations () in
let old = !type_changed in
try
while
Expand Down Expand Up @@ -1881,8 +1934,7 @@ let rec local_non_recursive_abbrev ~allow_rec strict visited env p ty =
end

let local_non_recursive_abbrev env p ty =
let allow_rec =
!Clflags.recursive_types || !umode = Pattern && !allow_recursive_equation in
let allow_rec = allow_recursive_equations () in
try (* PR#7397: need to check trace_gadt_instances *)
wrap_trace_gadt_instances env
(local_non_recursive_abbrev ~allow_rec false [] env p) ty;
Expand Down Expand Up @@ -2606,11 +2658,9 @@ let unify_alloc_mode_for tr_exn a b =
let rigid_variants = ref false

let unify_eq t1 t2 =
eq_type t1 t2 ||
match !umode with
| Expression -> false
| Pattern ->
TypePairs.mem unify_eq_set (order_type_pair t1 t2)
eq_type t1 t2
|| (in_pattern_mode ()
&& TypePairs.mem unify_eq_set (order_type_pair t1 t2))

let unify1_var env t1 t2 =
assert (is_Tvar t1);
Expand All @@ -2626,22 +2676,15 @@ let unify1_var env t1 t2 =
end;
link_type t1 t2;
true
| exception Unify_trace _ when !umode = Pattern ->
| exception Unify_trace _ when in_pattern_mode () ->
false

(* Can only be called when generate_equations is true *)
let record_equation t1 t2 =
match !equations_generation with
| Forbidden -> assert false
| Allowed { equated_types } ->
TypePairs.add equated_types (t1, t2)

(* Called from unify3 *)
let unify3_var env t1' t2 t2' =
occur_for Unify !env t1' t2;
match occur_univar_for Unify !env t2 with
| () -> link_type t1' t2
| exception Unify_trace _ when !umode = Pattern ->
| exception Unify_trace _ when in_pattern_mode () ->
reify env t1';
reify env t2';
if can_generate_equations () then begin
Expand Down Expand Up @@ -2773,20 +2816,19 @@ and unify3 env t1 t1' t2 t2' =
| (Tfield _, Tfield _) -> (* special case for GADTs *)
unify_fields env t1' t2'
| _ ->
begin match !umode with
| Expression ->
occur_for Unify !env t1' t2;
link_type t1' t2
| Pattern ->
add_type_equality t1' t2'
if in_pattern_mode () then
add_type_equality t1' t2'
else begin
occur_for Unify !env t1' t2;
link_type t1' t2
end;
try
begin match (d1, d2) with
(Tarrow ((l1,a1,r1), t1, u1, c1),
Tarrow ((l2,a2,r2), t2, u2, c2))
when
(l1 = l2 ||
(!Clflags.classic || !umode = Pattern) &&
(!Clflags.classic || in_pattern_mode ()) &&
not (is_optional l1 || is_optional l2)) ->
unify_alloc_mode_for Unify a1 a2;
unify_alloc_mode_for Unify r1 r2;
Expand All @@ -2800,12 +2842,10 @@ and unify3 env t1 t1' t2 t2' =
| (Ttuple tl1, Ttuple tl2) ->
unify_list env tl1 tl2
| (Tconstr (p1, tl1, _), Tconstr (p2, tl2, _)) when Path.same p1 p2 ->
if !umode = Expression || !equations_generation = Forbidden then
if not (can_generate_equations ()) then
unify_list env tl1 tl2
else if !assume_injective then
set_mode_pattern ~generate:!equations_generation ~injective:false
~allow_recursive:!allow_recursive_equation
(fun () -> unify_list env tl1 tl2)
else if can_assume_injective () then
without_assume_injective (fun () -> unify_list env tl1 tl2)
else if in_current_module p1 (* || in_pervasives p1 *)
|| List.exists (expands_to_datatype !env) [t1'; t1; t2]
then
Expand All @@ -2819,8 +2859,7 @@ and unify3 env t1 t1' t2 t2' =
List.iter2
(fun i (t1, t2) ->
if i then unify env t1 t2 else
set_mode_pattern ~generate:Forbidden ~injective:false
~allow_recursive:!allow_recursive_equation
without_generating_equations
begin fun () ->
let snap = snapshot () in
try unify env t1 t2 with Unify_trace _ ->
Expand Down Expand Up @@ -2850,7 +2889,7 @@ and unify3 env t1 t1' t2 t2' =
reify env t1';
record_equation t1' t2';
add_gadt_equation env path t1'
| (Tconstr (_,_,_), _) | (_, Tconstr (_,_,_)) when !umode = Pattern ->
| (Tconstr (_,_,_), _) | (_, Tconstr (_,_,_)) when in_pattern_mode () ->
reify env t1';
reify env t2';
if can_generate_equations () then (
Expand All @@ -2869,7 +2908,7 @@ and unify3 env t1 t1' t2 t2' =
| _ -> ()
end
| (Tvariant row1, Tvariant row2) ->
if !umode = Expression then
if not (in_pattern_mode ()) then
unify_row env row1 row2
else begin
let snap = snapshot () in
Expand Down Expand Up @@ -2908,7 +2947,7 @@ and unify3 env t1 t1' t2 t2' =
unify_package !env (unify_list env)
(get_level t1) p1 fl1 (get_level t2) p2 fl2
with Not_found ->
if !umode = Expression then raise_unexplained_for Unify;
if not (in_pattern_mode ()) then raise_unexplained_for Unify;
List.iter (fun (_n, ty) -> reify env ty) (fl1 @ fl2);
(* if !generate_equations then List.iter2 (mcomp !env) tl1 tl2 *)
end
Expand Down Expand Up @@ -2974,7 +3013,8 @@ and unify_fields env ty1 ty2 = (* Optimization *)
(fun (name, k1, t1, k2, t2) ->
unify_kind k1 k2;
try
if !trace_gadt_instances then begin
if !trace_gadt_instances && not (in_subst_mode ()) then begin
(* in_subst_mode: see PR#11771 *)
update_level_for Unify !env (get_level va) t1;
update_scope_for Unify (get_scope va) t1
end;
Expand Down Expand Up @@ -3066,7 +3106,8 @@ and unify_row env row1 row2 =
(* The following test is not principal... should rather use Tnil *)
let rm = row_more row in
(*if !trace_gadt_instances && rm.desc = Tnil then () else*)
if !trace_gadt_instances then
if !trace_gadt_instances && not (in_subst_mode ()) then
(* in_subst_mode: see PR#11771 *)
update_level_for Unify !env (get_level rm) (newgenty (Tvariant row));
if has_fixed_explanation row then
if eq_type more rm then () else
Expand Down Expand Up @@ -3202,15 +3243,13 @@ let unify env ty1 ty2 =
undo_compress snap;
raise (Unify (expand_to_unification_error !env trace))

let unify_gadt ~equations_level:lev ~allow_recursive (env:Env.t ref) ty1 ty2 =
let unify_gadt ~equations_level:lev ~allow_recursive_equations
(env:Env.t ref) ty1 ty2 =
try
univar_pairs := [];
gadt_equations_level := Some lev;
let equated_types = TypePairs.create 0 in
set_mode_pattern
~generate:(Allowed { equated_types })
~injective:true
~allow_recursive
set_mode_pattern ~allow_recursive_equations ~equated_types
(fun () -> unify env ty1 ty2);
gadt_equations_level := None;
TypePairs.clear unify_eq_set;
Expand Down
2 changes: 1 addition & 1 deletion typing/ctype.mli
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ val extract_concrete_typedecl:
val unify: Env.t -> type_expr -> type_expr -> unit
(* Unify the two types given. Raise [Unify] if not possible. *)
val unify_gadt:
equations_level:int -> allow_recursive:bool ->
equations_level:int -> allow_recursive_equations:bool ->
Env.t ref -> type_expr -> type_expr -> Btype.TypePairs.t
(* Unify the two types given and update the environment with the
local constraints. Raise [Unify] if not possible.
Expand Down
4 changes: 2 additions & 2 deletions typing/typecore.ml
Original file line number Diff line number Diff line change
Expand Up @@ -588,9 +588,9 @@ let nothing_equated = TypePairs.create 0
let unify_pat_types_return_equated_pairs ?(refine = None) loc env ty ty' =
try
match refine with
| Some allow_recursive ->
| Some allow_recursive_equations ->
unify_gadt ~equations_level:(get_gadt_equations_level ())
~allow_recursive env ty ty'
~allow_recursive_equations env ty ty'
| None ->
unify !env ty ty';
nothing_equated
Expand Down

0 comments on commit 2a7e501

Please sign in to comment.