Skip to content

Commit

Permalink
Don't try to be clever with let x = M in tail x
Browse files Browse the repository at this point in the history
There's nothing actually wrong with `let x = M in tail x`, so don't try to
reduce it to just `M`. This works in that narrow case, but the code that was
doing this transformation didn't notice if the body of the tail is more than
just `x`. Since the transformation doesn't actually gain anything, better to
be rid of it than to make things more complicated trying to get it right.

Also updated a few comments and added an invariant check.
  • Loading branch information
lukemaurer committed May 10, 2023
1 parent 6fa1ce6 commit 8a717aa
Showing 1 changed file with 52 additions and 33 deletions.
85 changes: 52 additions & 33 deletions middle_end/flambda/lift_code.ml
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,37 @@ let rebuild_let (defs : def list) (body : Flambda.t) =
Flambda.Tail body)
body defs

(* Given something like [y = N; x = M], where we intend to produce [let x = M in
let y = N in y], try and make [let x = M in N] instead by splitting out the
innermost definition. If this isn't possible, instead just return the
variable and the whole definition list. *)
let rec split_defs defs var =
(* Given something like [[x = M; y = N]], where we intend to produce [let y = N in
let x = M in x], try and make [let y = N in M] instead. If this isn't possible,
instead just return the input unchanged.
More concretely, this function views the given defs and variable as a term
[D[x]] where [D] the context represented by the defs. For example, in the above,
we think of the inputs [[x = M; y = N]] and [x] as representing
[let y = N in let x = M in x]. Then we'll rewrite this term as [let y = N in M]
by returning [[y = N]] and [M]. In some cases, no such rewrite is
possible, such as with [[tail; x = M; y = N]] and [x], representing
[let y = N in let x = M in tail x]. In this case, we simply return the defs as
given along with [x] (as an expression).
Note that it's tempting to play games like turning [let x = M in tail x] into
simply [M]. This is valid in that particular case, but it doesn't actually
win anything and it produces defs that are dangerous to use for anything but
wrapping exactly the returned expression. In particular, the defs may be
unbalanced, leaving a region open. *)
let split_defs defs var : def list * Flambda.expr =
let module W = Flambda.With_free_variables in
match defs with
| Immutable (var', named) :: defs' when Variable.equal var var' -> begin
match W.contents named with
| Expr expr -> defs', expr
| _ -> defs, Flambda.Var var
| _ -> defs, Var var
end
| Immutable (var', _) :: _ ->
Misc.fatal_errorf "Expected binding for %a@ but found %a"
Variable.print var Variable.print var'
| Tail :: defs' ->
(* Rewrite [let x = M in tail x] as simply [M] *)
split_defs defs' var
| Tail :: _ ->
defs, Var var
| (Mutable _ | Region) :: _ | [] ->
Misc.fatal_errorf "Expected binding for %a"
Variable.print var
Expand All @@ -65,16 +78,24 @@ let rebuild_expr defs var =
let defs, body = split_defs defs var in
rebuild_let defs body

let region_delta defs =
List.fold_left
(fun delta def ->
match def with
| Tail -> delta - 1
| Region -> delta + 1
| Immutable _ | Mutable _ -> delta)
0 defs

let defs_open_region defs =
region_delta defs > 0

let defs_close_region defs =
let rec more_tails_than_regions defs tails regions =
match defs with
| [] -> tails > regions
| Tail :: defs -> more_tails_than_regions defs (tails + 1) regions
| Region :: defs -> more_tails_than_regions defs tails (regions + 1)
| (Immutable _ | Mutable _) :: defs ->
more_tails_than_regions defs tails regions
in
more_tails_than_regions defs 0 0
region_delta defs < 0

let check_defs defs =
if !Clflags.flambda_invariant_checks then
assert (not (defs_open_region defs))

let rec tail_expr_in_expr0 (expr : Flambda.t) ~depth =
match expr with
Expand Down Expand Up @@ -169,12 +190,12 @@ and extract_region acc dest body =
we can directly assign [dest] to it instead *)
match split_defs acc_expr inner_dest with
| acc_expr, body when liftable_region_body body ->
(* The accumulator must remain balanced between [Region] and [Tail], since
it defines a scope into which [extract_let_expr] will move arbitrary
computations - if there is a [Region] but no [Tail], this means we're
moving those computations into a different region. It may be that
[acc_expr] already has a [Tail] (because we lifted it out of [body]), but
otherwise we need to add it. *)
(* The accumulator must remain balanced between [Region] and [Tail] (see
[check_defs]), since it defines a scope into which [extract_let_expr]
will move arbitrary computations - if there is a [Region] but no [Tail],
this means we're moving those computations into a different region. It
may be that [acc_expr] already has a [Tail] (because we lifted it out of
[body]), but otherwise we need to add it. *)
let need_tail = not (defs_close_region acc_expr) in
List.concat
[ if need_tail then [ Tail ] else [];
Expand All @@ -191,15 +212,8 @@ and extract_region acc dest body =
and extract_tail_call acc dest (apply : Flambda.apply) =
let module W = Flambda.With_free_variables in
(* Rewrite a close-at-apply call as a normal call in a [Tail] so that we can
float the [Tail]. Note that it will still be a tail call in the normal
sense (it replaces the old stack frame, etc.), it just no longer has the
additional semantics of ending the region first. *)
let apply =
(* We can safely assume [Rc_normal] makes sense here because the original
application was marked [Rc_close_at_apply], so it must be intended to be
a tail call and marking it [Rc_nontail] would be silly. *)
{ apply with reg_close = Rc_normal }
in
float the [Tail] *)
let apply = { apply with reg_close = Rc_normal } in
Immutable (dest, W.expr (W.of_expr (Apply apply))) :: Tail :: acc

and extract_tail_send acc dest (send : Flambda.send) =
Expand All @@ -210,6 +224,7 @@ and extract_tail_send acc dest (send : Flambda.send) =

and extract acc dest expr =
let module W = Flambda.With_free_variables in
check_defs acc;
match (W.contents expr : Flambda.t) with
| Let let_expr ->
extract_let_expr acc dest let_expr
Expand Down Expand Up @@ -243,6 +258,10 @@ and extract acc dest expr =
let rec lift_lets_expr (expr:Flambda.t) ~toplevel : Flambda.t =
match expr with
| Let let_expr ->
(* For uniformity, wrap everything in another [let] binding, which
[rebuild_expr] will try to eliminate. Sometimes we can't eliminate it
easily (see comments on [split_defs]), but it's harmless and not worth
the complexity to avoid it. *)
let dest = Variable.create Internal_variable_names.lifted_let in
let defs = extract_let_expr [] dest let_expr in
let rev_defs = List.rev_map (lift_lets_def ~toplevel) defs in
Expand Down

0 comments on commit 8a717aa

Please sign in to comment.