Skip to content

Commit

Permalink
Rewrite let x = M in tail x as M
Browse files Browse the repository at this point in the history
This should get rid of some annoying instances of extra variables introduced
just to lift, and in particular stop making tail calls into non-tail calls.
  • Loading branch information
lukemaurer committed Apr 18, 2023
1 parent 6126e9e commit dccb971
Showing 1 changed file with 19 additions and 25 deletions.
44 changes: 19 additions & 25 deletions middle_end/flambda/lift_code.ml
Original file line number Diff line number Diff line change
Expand Up @@ -39,27 +39,23 @@ let rebuild_let (defs : def list) (body : Flambda.t) =
Flambda.Tail body)
body defs

(* Given something like [x = M; y = N], where we intend to produce
[let x = M in let y = N in y], try and make [let x = M in N] instead by
splitting out the innermost definition. If this isn't possible, say in the
case [x = M; Region r; y = N; Tail r], instead just return the variable and
the whole definition list. *)
let split_defs defs var =
let use_var () =
defs, Flambda.Var var
in
(* Given something like [y = N; x = M], where we intend to produce [let x = M in
let y = N in y], try and make [let x = M in N] instead by splitting out the
innermost definition. If this isn't possible, instead just return the
variable and the whole definition list. *)
let rec split_defs defs var =
match defs with
| Immutable (var', named) :: defs' when Variable.equal var var' -> begin
match Flambda.With_free_variables.contents named with
| Expr expr -> defs', expr
| _ -> use_var ()
| _ -> defs, Flambda.Var var
end
| Immutable (var', _) :: _ ->
Misc.fatal_errorf "Expected binding for %a@ but found %a"
Variable.print var Variable.print var'
| Tail :: _ ->
(* The typechecker ensures that the constructed value can escape safely *)
use_var ()
| Tail :: defs' ->
(* Rewrite [let x = M in tail x] as simply M *)
split_defs defs' var
| (Mutable _ | Region) :: _ | [] ->
Misc.fatal_errorf "Expected binding for %a"
Variable.print var
Expand Down Expand Up @@ -172,7 +168,7 @@ and extract_region acc dest body =
| Expr expr -> tail_expr_in_expr expr
| _ -> false
in
begin if cannot_lift then
if cannot_lift then
(* There's a tail expression that we can't lift out, so we can't do anything
but bundle everything back up in a region. *)
let expr =
Expand All @@ -181,23 +177,22 @@ and extract_region acc dest body =
in
Immutable(dest, W.expr expr) :: acc
else
(* If possible, recover the expression that gets assigned to [inner_dest] so
we can directly assign [dest] to it instead *)
let acc_expr, body = split_defs acc_expr inner_dest in
(* The accumulator must remain balanced between [Region] and [Tail], since
it defines a scope into which [extract_let_expr] will move arbitrary
computations - if there is a [Region] but no [Tail], this means we're
moving those computations into a different region. It may be that
[acc_expr] already has a [Tail] (because we lifted it out of [body]), but
otherwise we need to add it. *)
let need_tail = not (defs_close_region acc_expr) in
(* If possible, recover the expression that gets assigned to [inner_dest] so
we can directly assign [dest] to it instead *)
let acc_expr, body = split_defs acc_expr inner_dest in
List.concat
[ if need_tail then [ Tail ] else [];
[ Immutable (dest, W.expr (W.of_expr body)) ];
acc_expr;
[ Region ];
acc ]
end

and extract_tail_call acc dest (apply : Flambda.apply) =
let module W = Flambda.With_free_variables in
Expand Down Expand Up @@ -244,19 +239,18 @@ and extract acc dest expr =
Immutable (dest, W.expr expr) :: acc

let rec lift_lets_expr (expr:Flambda.t) ~toplevel : Flambda.t =
let module W = Flambda.With_free_variables in
match expr with
| Let let_expr ->
let dest = Variable.create Internal_variable_names.lifted_let in
let defs = extract_let_expr [] dest let_expr in
let rev_defs = List.rev_map (lift_lets_def ~toplevel) defs in
rebuild_expr (List.rev rev_defs) dest
let defs = List.map (lift_lets_def ~toplevel) defs in
rebuild_expr defs dest
| Let_mutable let_mut ->
let dest = Variable.create Internal_variable_names.lifted_let in
let defs =
extract_let_mutable [] dest let_mut
in
let rev_defs = List.rev_map (lift_lets_def ~toplevel) defs in
rebuild_expr (List.rev rev_defs) dest
let defs = extract_let_mutable [] dest let_mut in
let defs = List.map (lift_lets_def ~toplevel) defs in
rebuild_expr defs dest
| e ->
Flambda_iterators.map_subexpressions
(lift_lets_expr ~toplevel)
Expand Down

0 comments on commit dccb971

Please sign in to comment.