Skip to content

Fix case where local functions optimisation breaks tail calls #2360

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions middle_end/flambda2/from_lambda/lambda_to_flambda.ml
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,6 @@ let compile_staticfail acc env ccenv ~(continuation : Continuation.t) ~args :
allocation region"
Continuation.print continuation;
let rec add_end_regions acc ~region_stack_now =
(* This can maybe only be exercised right now using "match with exception",
since that causes jumps out of try-regions (but not normal regions). *)
(* CR pchambart: This closes all the regions between region_stack_now and
region_stack_at_handler, but closing only the last one should be
sufficient. *)
Expand Down Expand Up @@ -184,7 +182,7 @@ let rec try_to_find_location (lam : L.lambda) =
| Llet (_, _, _, lam, _)
| Lmutlet (_, _, lam, _)
| Lifthenelse (lam, _, _, _)
| Lstaticcatch (lam, _, _, _)
| Lstaticcatch (lam, _, _, _, _)
| Lstaticraise (_, lam :: _)
| Lwhile { wh_cond = lam; _ }
| Lsequence (lam, _)
Expand Down Expand Up @@ -347,6 +345,7 @@ let rec_catch_for_while_loop env cond body =
Lsequence (body, Lstaticraise (cont, [])),
Lconst (Const_base (Const_int 0)),
Lambda.layout_unit ) ),
Same_region,
Lambda.layout_unit )
in
env, lam
Expand Down Expand Up @@ -398,6 +397,7 @@ let rec_catch_for_for_loop env loc ident start stop
Lstaticraise (cont, [next_value_of_counter]),
L.lambda_unit,
Lambda.layout_unit ) ),
Same_region,
Lambda.layout_unit ),
L.lambda_unit,
Lambda.layout_unit ) ) )
Expand All @@ -424,7 +424,8 @@ let let_cont_nonrecursive_with_extra_params acc env ccenv ~is_exn_handler
=
let cont = Continuation.create () in
let { Env.body_env; handler_env; extra_params } =
Env.add_continuation env cont ~push_to_try_stack:is_exn_handler Nonrecursive
Env.add_continuation env cont ~push_to_try_stack:is_exn_handler
~pop_region:false Nonrecursive
in
let handler_env, params_rev =
List.fold_left
Expand Down Expand Up @@ -1036,12 +1037,16 @@ let rec cps acc env ccenv (lam : L.lambda) (k : cps_continuation)
compile_staticfail acc env ccenv ~continuation
~args:(List.flatten args @ extra_args))
k_exn
| Lstaticcatch (body, (static_exn, args), handler, layout) ->
| Lstaticcatch (body, (static_exn, args), handler, r, layout) ->
maybe_insert_let_cont "staticcatch_result" layout k acc env ccenv
(fun acc env ccenv k ->
let pop_region =
match r with Popped_region -> true | Same_region -> false
in
let continuation = Continuation.create () in
let { Env.body_env; handler_env; extra_params } =
Env.add_static_exn_continuation env static_exn continuation
Env.add_static_exn_continuation env static_exn ~pop_region
continuation
in
let recursive : Asttypes.rec_flag =
if Env.is_static_exn_recursive env static_exn
Expand Down
24 changes: 16 additions & 8 deletions middle_end/flambda2/from_lambda/lambda_to_flambda_env.ml
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,18 @@ type add_continuation_result =
extra_params : (Ident.t * Flambda_kind.With_subkind.t) list
}

let add_continuation t cont ~push_to_try_stack (recursive : Asttypes.rec_flag) =
let add_continuation t cont ~push_to_try_stack ~pop_region
(recursive : Asttypes.rec_flag) =
let region_stack =
if pop_region
then
match t.region_stack with
| [] -> Misc.fatal_error "Cannot pop region, region stack is empty"
| _ :: region_stack -> region_stack
else t.region_stack
in
let region_stack_in_cont_scope =
Continuation.Map.add cont t.region_stack t.region_stack_in_cont_scope
Continuation.Map.add cont region_stack t.region_stack_in_cont_scope
in
let body_env =
let mutables_needed_by_continuations =
Expand Down Expand Up @@ -171,7 +180,8 @@ let add_continuation t cont ~push_to_try_stack (recursive : Asttypes.rec_flag) =
{ handler_env with
current_values_of_mutables_in_scope;
unboxed_product_components_in_scope;
region_stack_in_cont_scope
region_stack_in_cont_scope;
region_stack
}
in
let extra_params_for_unboxed_products =
Expand All @@ -184,23 +194,21 @@ let add_continuation t cont ~push_to_try_stack (recursive : Asttypes.rec_flag) =
in
{ body_env; handler_env; extra_params }

let add_static_exn_continuation t static_exn cont =
let add_static_exn_continuation t static_exn ~pop_region cont =
let t =
{ t with
try_stack_at_handler =
Continuation.Map.add cont t.try_stack t.try_stack_at_handler;
static_exn_continuation =
Numeric_types.Int.Map.add static_exn cont t.static_exn_continuation;
region_stack_in_cont_scope =
Continuation.Map.add cont t.region_stack t.region_stack_in_cont_scope
Numeric_types.Int.Map.add static_exn cont t.static_exn_continuation
}
in
let recursive : Asttypes.rec_flag =
if Numeric_types.Int.Set.mem static_exn t.recursive_static_catches
then Recursive
else Nonrecursive
in
add_continuation t cont ~push_to_try_stack:false recursive
add_continuation t cont ~push_to_try_stack:false ~pop_region recursive

let get_static_exn_continuation t static_exn =
match Numeric_types.Int.Map.find static_exn t.static_exn_continuation with
Expand Down
3 changes: 2 additions & 1 deletion middle_end/flambda2/from_lambda/lambda_to_flambda_env.mli
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,12 @@ val add_continuation :
t ->
Continuation.t ->
push_to_try_stack:bool ->
pop_region:bool ->
Asttypes.rec_flag ->
add_continuation_result

val add_static_exn_continuation :
t -> int -> Continuation.t -> add_continuation_result
t -> int -> pop_region:bool -> Continuation.t -> add_continuation_result

val get_static_exn_continuation : t -> int -> Continuation.t

Expand Down
2 changes: 1 addition & 1 deletion ocaml/bytecomp/bytegen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,7 @@ let rec comp_expr stack_info env exp sz cont =
let nargs = List.length args - 1 in
comp_args stack_info env args sz
(comp_primitive stack_info p (sz + nargs - 1) args :: cont)
| Lstaticcatch (body, (i, vars) , handler, _) ->
| Lstaticcatch (body, (i, vars) , handler, _, _) ->
let vars = List.map fst vars in
let nvars = List.length vars in
let branch1, cont1 = make_branch cont in
Expand Down
26 changes: 16 additions & 10 deletions ocaml/lambda/lambda.ml
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,10 @@ type lparam = {
mode : alloc_mode
}

type pop_region =
| Popped_region
| Same_region

type lambda =
Lvar of Ident.t
| Lmutvar of Ident.t
Expand All @@ -706,7 +710,9 @@ type lambda =
| Lstringswitch of
lambda * (string * lambda) list * lambda option * scoped_location * layout
| Lstaticraise of static_label * lambda list
| Lstaticcatch of lambda * (static_label * (Ident.t * layout) list) * lambda * layout
| Lstaticcatch of
lambda * (static_label * (Ident.t * layout) list) * lambda
* pop_region * layout
| Ltrywith of lambda * Ident.t * lambda * layout
| Lifthenelse of lambda * lambda * lambda * layout
| Lsequence of lambda * lambda
Expand Down Expand Up @@ -971,8 +977,8 @@ let make_key e =
Loc_unknown,kind)
| Lstaticraise (i,es) ->
Lstaticraise (i,tr_recs env es)
| Lstaticcatch (e1,xs,e2, kind) ->
Lstaticcatch (tr_rec env e1,xs,tr_rec env e2, kind)
| Lstaticcatch (e1,xs,e2, r, kind) ->
Lstaticcatch (tr_rec env e1,xs,tr_rec env e2, r, kind)
| Ltrywith (e1,x,e2,kind) ->
Ltrywith (tr_rec env e1,x,tr_rec env e2,kind)
| Lifthenelse (cond,ifso,ifnot,kind) ->
Expand Down Expand Up @@ -1064,7 +1070,7 @@ let shallow_iter ~tail ~non_tail:f = function
iter_opt tail default
| Lstaticraise (_,args) ->
List.iter f args
| Lstaticcatch(e1, _, e2, _kind) ->
| Lstaticcatch(e1, _, e2, _, _kind) ->
tail e1; tail e2
| Ltrywith(e1, _, e2,_) ->
f e1; tail e2
Expand Down Expand Up @@ -1137,7 +1143,7 @@ let rec free_variables = function
end
| Lstaticraise (_,args) ->
free_variables_list Ident.Set.empty args
| Lstaticcatch(body, (_, params), handler, _kind) ->
| Lstaticcatch(body, (_, params), handler, _, _kind) ->
Ident.Set.union
(Ident.Set.diff
(free_variables handler)
Expand Down Expand Up @@ -1359,10 +1365,10 @@ let build_substs update_env ?(freshen_bound_variables = false) s =
subst_opt s l default,
loc,kind)
| Lstaticraise (i,args) -> Lstaticraise (i, subst_list s l args)
| Lstaticcatch(body, (id, params), handler, kind) ->
| Lstaticcatch(body, (id, params), handler, r, kind) ->
let params, l' = bind_many params l in
Lstaticcatch(subst s l body, (id, params),
subst s l' handler, kind)
subst s l' handler, r, kind)
| Ltrywith(body, exn, handler,kind) ->
let exn, l' = bind exn l in
Ltrywith(subst s l body, exn, subst s l' handler,kind)
Expand Down Expand Up @@ -1506,8 +1512,8 @@ let shallow_map ~tail ~non_tail:f = function
loc, layout)
| Lstaticraise (i, args) ->
Lstaticraise (i, List.map f args)
| Lstaticcatch (body, id, handler, layout) ->
Lstaticcatch (tail body, id, tail handler, layout)
| Lstaticcatch (body, id, handler, r, layout) ->
Lstaticcatch (tail body, id, tail handler, r, layout)
| Ltrywith (e1, v, e2, layout) ->
Ltrywith (f e1, v, tail e2, layout)
| Lifthenelse (e1, e2, e3, layout) ->
Expand Down Expand Up @@ -1974,7 +1980,7 @@ let compute_expr_layout free_vars_kind lam =
| Lprim(p, _, _) ->
primitive_result_layout p
| Lswitch(_, _, _, kind) | Lstringswitch(_, _, _, _, kind)
| Lstaticcatch(_, _, _, kind) | Ltrywith(_, _, _, kind)
| Lstaticcatch(_, _, _, _, kind) | Ltrywith(_, _, _, kind)
| Lifthenelse(_, _, _, kind) | Lregion (_, kind) ->
kind
| Lstaticraise (_, _) ->
Expand Down
22 changes: 21 additions & 1 deletion ocaml/lambda/lambda.mli
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,10 @@ type lparam = {

type scoped_location = Debuginfo.Scoped_location.t

type pop_region =
| Popped_region
| Same_region

type lambda =
Lvar of Ident.t
| Lmutvar of Ident.t
Expand All @@ -598,7 +602,21 @@ type lambda =
| Lstringswitch of
lambda * (string * lambda) list * lambda option * scoped_location * layout
| Lstaticraise of static_label * lambda list
| Lstaticcatch of lambda * (static_label * (Ident.t * layout) list) * lambda * layout
(* Concerning [Lstaticcatch], the regions that are open in the handler must be
a subset of those open at the point of the [Lstaticraise] that jumps to it,
as we can't reopen closed regions. All regions that were open at the point of
the [Lstaticraise] but not in the handler will be closed just before the [Lstaticraise].

However, to be able to express the fact
that the [Lstaticraise] might be under a [Lexclave], the [pop_region] flag
is used to specify what regions are considered open in the handler. If it
is [Same_region], it means that the same regions as those existing at the
point of the [Lstaticraise] are considered open in the handler; if it is [Popped_region],
it means that we consider the top region at the point of the [Lstaticcatch] to not be
considered open inside the handler. *)
| Lstaticcatch of
lambda * (static_label * (Ident.t * layout) list) * lambda
* pop_region * layout
| Ltrywith of lambda * Ident.t * lambda * layout
(* Lifthenelse (e, t, f, layout) evaluates t if e evaluates to 0, and evaluates f if
e evaluates to any other value; layout must be the layout of [t] and [f] *)
Expand All @@ -612,6 +630,8 @@ type lambda =
| Levent of lambda * lambda_event
| Lifused of Ident.t * lambda
| Lregion of lambda * layout
(* [Lexclave] closes the newest region opened.
Note that [Lexclave] nesting is currently unsupported. *)
| Lexclave of lambda

and rec_binding = {
Expand Down
26 changes: 16 additions & 10 deletions ocaml/lambda/matching.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1054,7 +1054,7 @@ let make_catch kind d k =
| Lstaticraise (_, []) -> k d
| _ ->
let e = next_raise_count () in
Lstaticcatch (k (make_exit e), (e, []), d, kind)
Lstaticcatch (k (make_exit e), (e, []), d, Same_region, kind)

(* Introduce a catch, if worth it, delayed version *)
let rec as_simple_exit = function
Expand All @@ -1078,7 +1078,7 @@ let make_catch_delayed kind handler =
handler
else
body
| _ -> Lstaticcatch (body, (i, []), handler, kind) )
| _ -> Lstaticcatch (body, (i, []), handler, Same_region, kind) )
)

let raw_action l =
Expand Down Expand Up @@ -3302,7 +3302,9 @@ let compile_orhandlers value_kind compile_fun lambda1 total1 ctx to_catch =
(* Whilst the handler is [lambda_unit] it is actually unused and only added
to produce well-formed code. In reality this expression returns a
[value_kind]. *)
do_rec (Lstaticcatch (r, (i, vars), lambda_unit, value_kind)) total_r rem
do_rec
(Lstaticcatch (r, (i, vars), lambda_unit, Same_region, value_kind))
total_r rem
| handler_i, total_i ->
begin match raw_action r with
| Lstaticraise (j, args) ->
Expand All @@ -3315,7 +3317,8 @@ let compile_orhandlers value_kind compile_fun lambda1 total1 ctx to_catch =
do_rec r total_r rem
| _ ->
do_rec
(Lstaticcatch (r, (i, vars), handler_i, value_kind))
(Lstaticcatch
(r, (i, vars), handler_i, Same_region, value_kind))
(Jumps.union (Jumps.remove i total_r)
(Jumps.map (Context.rshift_num (ncols mat)) total_i))
rem
Expand Down Expand Up @@ -3409,15 +3412,16 @@ let rec comp_match_handlers value_kind comp_fun partial ctx first_match next_mat
match comp_fun partial ctx_i pm with
| li, total_i ->
c_rec
(Lstaticcatch (body, (i, []), li, value_kind))
(Lstaticcatch (body, (i, []), li, Same_region, value_kind))
(Jumps.union total_i total_rem)
rem
| exception Unused ->
(* Whilst the handler is [lambda_unit] it is actually unused and only added
to produce well-formed code. In reality this expression returns a
[value_kind]. *)
c_rec
(Lstaticcatch (body, (i, []), lambda_unit, value_kind))
(Lstaticcatch
(body, (i, []), lambda_unit, Same_region, value_kind))
total_rem rem
end
)
Expand Down Expand Up @@ -3741,7 +3745,8 @@ let check_total ~scopes value_kind loc ~failer total lambda i =
lambda
else
Lstaticcatch (lambda, (i, []),
failure_handler ~scopes loc ~failer (), value_kind)
failure_handler ~scopes loc ~failer (),
Same_region, value_kind)

let toplevel_handler ~scopes ~return_layout loc ~failer partial args cases compile_fun =
match partial with
Expand Down Expand Up @@ -3854,8 +3859,8 @@ let rec map_return f = function
| Lsequence (l1, l2) -> Lsequence (l1, map_return f l2)
| Levent (l, ev) -> Levent (map_return f l, ev)
| Ltrywith (l1, id, l2, k) -> Ltrywith (map_return f l1, id, map_return f l2, k)
| Lstaticcatch (l1, b, l2, k) ->
Lstaticcatch (map_return f l1, b, map_return f l2, k)
| Lstaticcatch (l1, b, l2, r, k) ->
Lstaticcatch (map_return f l1, b, map_return f l2, r, k)
| Lswitch (s, sw, loc, k) ->
let map_cases cases =
List.map (fun (i, l) -> (i, map_return f l)) cases
Expand Down Expand Up @@ -3965,7 +3970,8 @@ let for_let ~scopes ~arg_sort ~return_layout loc param pat body =
param
in
if !opt then
Lstaticcatch (bind, (nraise, ids_with_kinds), body, return_layout)
Lstaticcatch
(bind, (nraise, ids_with_kinds), body, Same_region,return_layout)
else
simple_for_let ~scopes ~arg_sort ~return_layout loc param pat body

Expand Down
11 changes: 8 additions & 3 deletions ocaml/lambda/printlambda.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1112,16 +1112,21 @@ let rec lam ppf = function
let lams ppf largs =
List.iter (fun l -> fprintf ppf "@ %a" lam l) largs in
fprintf ppf "@[<2>(exit@ %d%a)@]" i lams ls;
| Lstaticcatch(lbody, (i, vars), lhandler, _kind) ->
fprintf ppf "@[<2>(catch@ %a@;<1 -1>with (%d%a)@ %a)@]"
| Lstaticcatch(lbody, (i, vars), lhandler, r, _kind) ->
let excl =
match r with
| Popped_region -> " exclave"
| Same_region -> ""
in
fprintf ppf "@[<2>(catch@ %a@;<1 -1>with (%d%a)%s@ %a)@]"
lam lbody i
(fun ppf vars ->
List.iter
(fun (x, k) -> fprintf ppf " %a%a" Ident.print x layout k)
vars
)
vars
lam lhandler
excl lam lhandler
| Ltrywith(lbody, param, lhandler, _kind) ->
fprintf ppf "@[<2>(try@ %a@;<1 -1>with %a@ %a)@]"
lam lbody Ident.print param lam lhandler
Expand Down
Loading
Loading