Skip to content

Better atomicity of mode system #2388

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 62 additions & 30 deletions ocaml/typing/mode.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1104,7 +1104,10 @@ type changes = S.changes

let undo_changes = S.undo_changes

let set_append_changes = S.set_append_changes
(* To be filled in by [types.ml] *)
let append_changes : (changes ref -> unit) ref = ref (fun _ -> assert false)

let set_append_changes f = append_changes := f

(** Representing a single object *)
module type Obj = sig
Expand All @@ -1115,11 +1118,31 @@ module type Obj = sig
val obj : const C.obj
end

let equate_from_submode submode m0 m1 =
match submode m0 m1 with
let try_with_log op =
let log' = ref S.empty_changes in
let log = Some log' in
match op ~log with
| Ok _ as x ->
!append_changes log';
x
| Error _ as x ->
S.undo_changes !log';
x
[@@inline]

let with_log op =
let log' = ref S.empty_changes in
let log = Some log' in
let r = op ~log in
!append_changes log';
r
[@@inline]

let equate_from_submode submode_log m0 m1 ~log =
match submode_log m0 m1 ~log with
| Error e -> Error (Left_le_right, e)
| Ok () -> (
match submode m1 m0 with
match submode_log m1 m0 ~log with
| Error e -> Error (Right_le_left, e)
| Ok () -> Ok ())
[@@inline]
Expand Down Expand Up @@ -1159,15 +1182,17 @@ module Common (Obj : Obj) = struct

let newvar_below m = Solver.newvar_below obj m

let submode m0 m1 : (unit, error) result = Solver.submode obj m0 m1
let submode_log a b ~log = Solver.submode obj a b ~log

let submode a b = try_with_log (submode_log a b)

let join l = Solver.join obj l

let meet l = Solver.meet obj l

let submode_exn m0 m1 = assert (submode m0 m1 |> Result.is_ok)

let equate = equate_from_submode submode
let equate a b = try_with_log (equate_from_submode submode_log a b)

let equate_exn m0 m1 = assert (equate m0 m1 |> Result.is_ok)

Expand All @@ -1176,9 +1201,9 @@ module Common (Obj : Obj) = struct
then Solver.print_raw ?verbose obj ppf m
else Solver.print ?verbose obj ppf m

let zap_to_ceil m = Solver.zap_to_ceil obj m
let zap_to_ceil m = with_log (Solver.zap_to_ceil obj m)

let zap_to_floor m = Solver.zap_to_floor obj m
let zap_to_floor m = with_log (Solver.zap_to_floor obj m)

let of_const : type l r. const -> (l * r) t = fun a -> Solver.of_const obj a

Expand Down Expand Up @@ -1361,8 +1386,8 @@ module Comonadic_with_regionality = struct
let legacy = of_const Const.legacy

(* overriding to report the offending axis *)
let submode m0 m1 =
match submode m0 m1 with
let submode_log m0 m1 ~log =
match submode_log m0 m1 ~log with
| Ok () -> Ok ()
| Error { left = reg0, lin0; right = reg1, lin1 } ->
if Regionality.Const.le reg0 reg1
Expand All @@ -1372,8 +1397,10 @@ module Comonadic_with_regionality = struct
else Error (`Linearity { left = lin0; right = lin1 })
else Error (`Regionality { left = reg0; right = reg1 })

let submode a b = try_with_log (submode_log a b)

(* override to report the offending axis *)
let equate = equate_from_submode submode
let equate a b = try_with_log (equate_from_submode submode_log a b)

(** overriding to check per-axis *)
let check_const m =
Expand Down Expand Up @@ -1459,8 +1486,8 @@ module Comonadic_with_locality = struct
let legacy = of_const Const.legacy

(* overriding to report the offending axis *)
let submode m0 m1 =
match submode m0 m1 with
let submode_log m0 m1 ~log =
match submode_log m0 m1 ~log with
| Ok () -> Ok ()
| Error { left = loc0, lin0; right = loc1, lin1 } ->
if Locality.Const.le loc0 loc1
Expand All @@ -1470,8 +1497,10 @@ module Comonadic_with_locality = struct
else Error (`Linearity { left = lin0; right = lin1 })
else Error (`Locality { left = loc0; right = loc1 })

let submode a b = try_with_log (submode_log a b)

(* override to report the offending axis *)
let equate = equate_from_submode submode
let equate a b = try_with_log (equate_from_submode submode_log a b)

(** overriding to check per-axis *)
let check_const m =
Expand Down Expand Up @@ -1537,16 +1566,18 @@ module Monadic = struct
let legacy = of_const Const.legacy

(* overriding to report the offending axis *)
let submode m0 m1 =
match submode m0 m1 with
let submode_log m0 m1 ~log =
match submode_log m0 m1 ~log with
| Ok () -> Ok ()
| Error { left = uni0, (); right = uni1, () } ->
if Uniqueness.Const.le uni0 uni1
then assert false
else Error (`Uniqueness { left = uni0; right = uni1 })

let submode a b = try_with_log (submode_log a b)

(* override to report the offending axis *)
let equate = equate_from_submode submode
let equate a b = try_with_log (equate_from_submode submode_log a b)

(** overriding to check per-axis *)
let check_const m =
Expand Down Expand Up @@ -1636,19 +1667,20 @@ module Value = struct

type equate_error = equate_step * error

(* NB: state mutated when error *)
let submode { monadic = monadic0; comonadic = comonadic0 }
{ monadic = monadic1; comonadic = comonadic1 } =
let submode_log { monadic = monadic0; comonadic = comonadic0 }
{ monadic = monadic1; comonadic = comonadic1 } ~log =
(* comonadic before monadic, so that locality errors dominate
(error message backward compatibility) *)
match Comonadic.submode comonadic0 comonadic1 with
match Comonadic.submode_log comonadic0 comonadic1 ~log with
| Error e -> Error e
| Ok () -> (
match Monadic.submode monadic0 monadic1 with
match Monadic.submode_log monadic0 monadic1 ~log with
| Error e -> Error e
| Ok () -> Ok ())

let equate = equate_from_submode submode
let submode a b = try_with_log (submode_log a b)

let equate a b = try_with_log (equate_from_submode submode_log a b)

let submode_exn m0 m1 =
match submode m0 m1 with
Expand Down Expand Up @@ -1939,18 +1971,18 @@ module Alloc = struct

type equate_error = equate_step * error

(* NB: state mutated when error - should be fine as this always indicates type
error in typecore.ml which triggers backtracking. *)
let submode { monadic = monadic0; comonadic = comonadic0 }
{ monadic = monadic1; comonadic = comonadic1 } =
match Monadic.submode monadic0 monadic1 with
let submode_log { monadic = monadic0; comonadic = comonadic0 }
{ monadic = monadic1; comonadic = comonadic1 } ~log =
match Monadic.submode_log monadic0 monadic1 ~log with
| Error e -> Error e
| Ok () -> (
match Comonadic.submode comonadic0 comonadic1 with
match Comonadic.submode_log comonadic0 comonadic1 ~log with
| Error e -> Error e
| Ok () -> Ok ())

let equate = equate_from_submode submode
let submode a b = try_with_log (submode_log a b)

let equate a b = try_with_log (equate_from_submode submode_log a b)

let submode_exn m0 m1 =
match submode m0 m1 with
Expand Down
Loading
Loading