Skip to content

Commit

Permalink
flambda-backend: Fix case where local functions optimisation breaks t…
Browse files Browse the repository at this point in the history
…ail calls (#2360)

* Fix case where local functions optimisation breaks tail calls

* comment

* fix printing

* fix after rebase

* make fmt

* Add comment about nesting of Lexclave

---------

Co-authored-by: Nathanaëlle Courant <nathanaelle.courant@ocamlpro.com>
  • Loading branch information
lpw25 and Ekdohibs authored May 15, 2024
1 parent 2b7bbc0 commit 22911e5
Show file tree
Hide file tree
Showing 11 changed files with 183 additions and 55 deletions.
2 changes: 1 addition & 1 deletion bytecomp/bytegen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,7 @@ let rec comp_expr stack_info env exp sz cont =
let nargs = List.length args - 1 in
comp_args stack_info env args sz
(comp_primitive stack_info p (sz + nargs - 1) args :: cont)
| Lstaticcatch (body, (i, vars) , handler, _) ->
| Lstaticcatch (body, (i, vars) , handler, _, _) ->
let vars = List.map fst vars in
let nvars = List.length vars in
let branch1, cont1 = make_branch cont in
Expand Down
26 changes: 16 additions & 10 deletions lambda/lambda.ml
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,10 @@ type lparam = {
mode : alloc_mode
}

type pop_region =
| Popped_region
| Same_region

type lambda =
Lvar of Ident.t
| Lmutvar of Ident.t
Expand All @@ -706,7 +710,9 @@ type lambda =
| Lstringswitch of
lambda * (string * lambda) list * lambda option * scoped_location * layout
| Lstaticraise of static_label * lambda list
| Lstaticcatch of lambda * (static_label * (Ident.t * layout) list) * lambda * layout
| Lstaticcatch of
lambda * (static_label * (Ident.t * layout) list) * lambda
* pop_region * layout
| Ltrywith of lambda * Ident.t * lambda * layout
| Lifthenelse of lambda * lambda * lambda * layout
| Lsequence of lambda * lambda
Expand Down Expand Up @@ -971,8 +977,8 @@ let make_key e =
Loc_unknown,kind)
| Lstaticraise (i,es) ->
Lstaticraise (i,tr_recs env es)
| Lstaticcatch (e1,xs,e2, kind) ->
Lstaticcatch (tr_rec env e1,xs,tr_rec env e2, kind)
| Lstaticcatch (e1,xs,e2, r, kind) ->
Lstaticcatch (tr_rec env e1,xs,tr_rec env e2, r, kind)
| Ltrywith (e1,x,e2,kind) ->
Ltrywith (tr_rec env e1,x,tr_rec env e2,kind)
| Lifthenelse (cond,ifso,ifnot,kind) ->
Expand Down Expand Up @@ -1064,7 +1070,7 @@ let shallow_iter ~tail ~non_tail:f = function
iter_opt tail default
| Lstaticraise (_,args) ->
List.iter f args
| Lstaticcatch(e1, _, e2, _kind) ->
| Lstaticcatch(e1, _, e2, _, _kind) ->
tail e1; tail e2
| Ltrywith(e1, _, e2,_) ->
f e1; tail e2
Expand Down Expand Up @@ -1137,7 +1143,7 @@ let rec free_variables = function
end
| Lstaticraise (_,args) ->
free_variables_list Ident.Set.empty args
| Lstaticcatch(body, (_, params), handler, _kind) ->
| Lstaticcatch(body, (_, params), handler, _, _kind) ->
Ident.Set.union
(Ident.Set.diff
(free_variables handler)
Expand Down Expand Up @@ -1359,10 +1365,10 @@ let build_substs update_env ?(freshen_bound_variables = false) s =
subst_opt s l default,
loc,kind)
| Lstaticraise (i,args) -> Lstaticraise (i, subst_list s l args)
| Lstaticcatch(body, (id, params), handler, kind) ->
| Lstaticcatch(body, (id, params), handler, r, kind) ->
let params, l' = bind_many params l in
Lstaticcatch(subst s l body, (id, params),
subst s l' handler, kind)
subst s l' handler, r, kind)
| Ltrywith(body, exn, handler,kind) ->
let exn, l' = bind exn l in
Ltrywith(subst s l body, exn, subst s l' handler,kind)
Expand Down Expand Up @@ -1506,8 +1512,8 @@ let shallow_map ~tail ~non_tail:f = function
loc, layout)
| Lstaticraise (i, args) ->
Lstaticraise (i, List.map f args)
| Lstaticcatch (body, id, handler, layout) ->
Lstaticcatch (tail body, id, tail handler, layout)
| Lstaticcatch (body, id, handler, r, layout) ->
Lstaticcatch (tail body, id, tail handler, r, layout)
| Ltrywith (e1, v, e2, layout) ->
Ltrywith (f e1, v, tail e2, layout)
| Lifthenelse (e1, e2, e3, layout) ->
Expand Down Expand Up @@ -1974,7 +1980,7 @@ let compute_expr_layout free_vars_kind lam =
| Lprim(p, _, _) ->
primitive_result_layout p
| Lswitch(_, _, _, kind) | Lstringswitch(_, _, _, _, kind)
| Lstaticcatch(_, _, _, kind) | Ltrywith(_, _, _, kind)
| Lstaticcatch(_, _, _, _, kind) | Ltrywith(_, _, _, kind)
| Lifthenelse(_, _, _, kind) | Lregion (_, kind) ->
kind
| Lstaticraise (_, _) ->
Expand Down
22 changes: 21 additions & 1 deletion lambda/lambda.mli
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,10 @@ type lparam = {

type scoped_location = Debuginfo.Scoped_location.t

type pop_region =
| Popped_region
| Same_region

type lambda =
Lvar of Ident.t
| Lmutvar of Ident.t
Expand All @@ -598,7 +602,21 @@ type lambda =
| Lstringswitch of
lambda * (string * lambda) list * lambda option * scoped_location * layout
| Lstaticraise of static_label * lambda list
| Lstaticcatch of lambda * (static_label * (Ident.t * layout) list) * lambda * layout
(* Concerning [Lstaticcatch], the regions that are open in the handler must be
a subset of those open at the point of the [Lstaticraise] that jumps to it,
as we can't reopen closed regions. All regions that were open at the point of
the [Lstaticraise] but not in the handler will be closed just before the [Lstaticraise].
However, to be able to express the fact
that the [Lstaticraise] might be under a [Lexclave], the [pop_region] flag
is used to specify what regions are considered open in the handler. If it
is [Same_region], it means that the same regions as those existing at the
point of the [Lstaticraise] are considered open in the handler; if it is [Popped_region],
it means that we consider the top region at the point of the [Lstaticcatch] to not be
considered open inside the handler. *)
| Lstaticcatch of
lambda * (static_label * (Ident.t * layout) list) * lambda
* pop_region * layout
| Ltrywith of lambda * Ident.t * lambda * layout
(* Lifthenelse (e, t, f, layout) evaluates t if e evaluates to 0, and evaluates f if
e evaluates to any other value; layout must be the layout of [t] and [f] *)
Expand All @@ -612,6 +630,8 @@ type lambda =
| Levent of lambda * lambda_event
| Lifused of Ident.t * lambda
| Lregion of lambda * layout
(* [Lexclave] closes the newest region opened.
Note that [Lexclave] nesting is currently unsupported. *)
| Lexclave of lambda

and rec_binding = {
Expand Down
26 changes: 16 additions & 10 deletions lambda/matching.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1054,7 +1054,7 @@ let make_catch kind d k =
| Lstaticraise (_, []) -> k d
| _ ->
let e = next_raise_count () in
Lstaticcatch (k (make_exit e), (e, []), d, kind)
Lstaticcatch (k (make_exit e), (e, []), d, Same_region, kind)

(* Introduce a catch, if worth it, delayed version *)
let rec as_simple_exit = function
Expand All @@ -1078,7 +1078,7 @@ let make_catch_delayed kind handler =
handler
else
body
| _ -> Lstaticcatch (body, (i, []), handler, kind) )
| _ -> Lstaticcatch (body, (i, []), handler, Same_region, kind) )
)

let raw_action l =
Expand Down Expand Up @@ -3302,7 +3302,9 @@ let compile_orhandlers value_kind compile_fun lambda1 total1 ctx to_catch =
(* Whilst the handler is [lambda_unit] it is actually unused and only added
to produce well-formed code. In reality this expression returns a
[value_kind]. *)
do_rec (Lstaticcatch (r, (i, vars), lambda_unit, value_kind)) total_r rem
do_rec
(Lstaticcatch (r, (i, vars), lambda_unit, Same_region, value_kind))
total_r rem
| handler_i, total_i ->
begin match raw_action r with
| Lstaticraise (j, args) ->
Expand All @@ -3315,7 +3317,8 @@ let compile_orhandlers value_kind compile_fun lambda1 total1 ctx to_catch =
do_rec r total_r rem
| _ ->
do_rec
(Lstaticcatch (r, (i, vars), handler_i, value_kind))
(Lstaticcatch
(r, (i, vars), handler_i, Same_region, value_kind))
(Jumps.union (Jumps.remove i total_r)
(Jumps.map (Context.rshift_num (ncols mat)) total_i))
rem
Expand Down Expand Up @@ -3409,15 +3412,16 @@ let rec comp_match_handlers value_kind comp_fun partial ctx first_match next_mat
match comp_fun partial ctx_i pm with
| li, total_i ->
c_rec
(Lstaticcatch (body, (i, []), li, value_kind))
(Lstaticcatch (body, (i, []), li, Same_region, value_kind))
(Jumps.union total_i total_rem)
rem
| exception Unused ->
(* Whilst the handler is [lambda_unit] it is actually unused and only added
to produce well-formed code. In reality this expression returns a
[value_kind]. *)
c_rec
(Lstaticcatch (body, (i, []), lambda_unit, value_kind))
(Lstaticcatch
(body, (i, []), lambda_unit, Same_region, value_kind))
total_rem rem
end
)
Expand Down Expand Up @@ -3741,7 +3745,8 @@ let check_total ~scopes value_kind loc ~failer total lambda i =
lambda
else
Lstaticcatch (lambda, (i, []),
failure_handler ~scopes loc ~failer (), value_kind)
failure_handler ~scopes loc ~failer (),
Same_region, value_kind)

let toplevel_handler ~scopes ~return_layout loc ~failer partial args cases compile_fun =
match partial with
Expand Down Expand Up @@ -3854,8 +3859,8 @@ let rec map_return f = function
| Lsequence (l1, l2) -> Lsequence (l1, map_return f l2)
| Levent (l, ev) -> Levent (map_return f l, ev)
| Ltrywith (l1, id, l2, k) -> Ltrywith (map_return f l1, id, map_return f l2, k)
| Lstaticcatch (l1, b, l2, k) ->
Lstaticcatch (map_return f l1, b, map_return f l2, k)
| Lstaticcatch (l1, b, l2, r, k) ->
Lstaticcatch (map_return f l1, b, map_return f l2, r, k)
| Lswitch (s, sw, loc, k) ->
let map_cases cases =
List.map (fun (i, l) -> (i, map_return f l)) cases
Expand Down Expand Up @@ -3965,7 +3970,8 @@ let for_let ~scopes ~arg_sort ~return_layout loc param pat body =
param
in
if !opt then
Lstaticcatch (bind, (nraise, ids_with_kinds), body, return_layout)
Lstaticcatch
(bind, (nraise, ids_with_kinds), body, Same_region,return_layout)
else
simple_for_let ~scopes ~arg_sort ~return_layout loc param pat body

Expand Down
11 changes: 8 additions & 3 deletions lambda/printlambda.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1112,16 +1112,21 @@ let rec lam ppf = function
let lams ppf largs =
List.iter (fun l -> fprintf ppf "@ %a" lam l) largs in
fprintf ppf "@[<2>(exit@ %d%a)@]" i lams ls;
| Lstaticcatch(lbody, (i, vars), lhandler, _kind) ->
fprintf ppf "@[<2>(catch@ %a@;<1 -1>with (%d%a)@ %a)@]"
| Lstaticcatch(lbody, (i, vars), lhandler, r, _kind) ->
let excl =
match r with
| Popped_region -> " exclave"
| Same_region -> ""
in
fprintf ppf "@[<2>(catch@ %a@;<1 -1>with (%d%a)%s@ %a)@]"
lam lbody i
(fun ppf vars ->
List.iter
(fun (x, k) -> fprintf ppf " %a%a" Ident.print x layout k)
vars
)
vars
lam lhandler
excl lam lhandler
| Ltrywith(lbody, param, lhandler, _kind) ->
fprintf ppf "@[<2>(try@ %a@;<1 -1>with %a@ %a)@]"
lam lbody Ident.print param lam lhandler
Expand Down
Loading

0 comments on commit 22911e5

Please sign in to comment.