Skip to content

Fix alloc modes and call kinds for overapplications #902

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions middle_end/flambda2/from_lambda/closure_conversion.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1823,12 +1823,10 @@ let wrap_over_application acc env full_call (apply : IR.apply) over_args
| Rc_normal | Rc_close_at_apply -> Apply.Position.Normal
| Rc_nontail -> Apply.Position.Nontail
in
let alloc_mode =
if contains_no_escaping_local_allocs
then Alloc_mode.For_types.heap
else Alloc_mode.For_types.unknown ()
let call_kind =
Call_kind.indirect_function_call_unknown_arity
(Alloc_mode.For_types.from_lambda apply.mode)
in
let call_kind = Call_kind.indirect_function_call_unknown_arity alloc_mode in
let continuation =
match needs_region with
| None -> apply_return_continuation
Expand Down Expand Up @@ -1967,8 +1965,13 @@ let close_apply acc env (apply : IR.apply) : Expr_with_acc.t =
~contains_no_escaping_local_allocs
| Over_app (args, remaining_args) ->
let full_args_call apply_continuation ~region acc =
let mode =
if contains_no_escaping_local_allocs
then Lambda.alloc_heap
else Lambda.alloc_local
in
close_exact_or_unknown_apply acc env
{ apply with args; continuation = apply_continuation }
{ apply with args; continuation = apply_continuation; mode }
(Some approx) ~replace_region:(Some region)
in
wrap_over_application acc env full_args_call apply remaining_args
Expand Down
48 changes: 22 additions & 26 deletions middle_end/flambda2/simplify/simplify_apply_expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,14 @@ let record_free_names_of_apply_as_used0 apply data_flow =
let record_free_names_of_apply_as_used dacc apply =
DA.map_data_flow dacc ~f:(record_free_names_of_apply_as_used0 apply)

let simplify_direct_tuple_application ~simplify_expr dacc apply
~params_arity:param_arity ~result_arity ~apply_alloc_mode
~contains_no_escaping_local_allocs ~current_region ~down_to_up =
let simplify_direct_tuple_application ~simplify_expr dacc apply ~result_arity
~apply_alloc_mode ~current_region ~callee's_code_id ~callee's_code_metadata
~down_to_up =
let dbg = Apply.dbg apply in
let n = Flambda_arity.With_subkinds.cardinal param_arity in
let n =
Flambda_arity.With_subkinds.cardinal
(Code_metadata.params_arity callee's_code_metadata)
in
(* Split the tuple argument from other potential over application arguments *)
let tuple, over_application_args =
match Apply.args apply with
Expand Down Expand Up @@ -80,9 +83,9 @@ let simplify_direct_tuple_application ~simplify_expr dacc apply
(* [apply] already got a correct relative_history and
[split_direct_over_application] infers the relative history from the
one on [apply] so there's nothing to do here. *)
Simplify_common.split_direct_over_application apply ~param_arity
~result_arity ~apply_alloc_mode ~contains_no_escaping_local_allocs
~current_region
Simplify_common.split_direct_over_application apply ~result_arity
~apply_alloc_mode ~current_region ~callee's_code_id
~callee's_code_metadata
in
(* Insert the projections and simplify the new expression, to allow field
projections to be simplified, and over-application/full_application
Expand Down Expand Up @@ -579,14 +582,14 @@ let simplify_direct_partial_application ~simplify_expr dacc apply
in
simplify_expr dacc expr ~down_to_up

let simplify_direct_over_application ~simplify_expr dacc apply ~param_arity
~result_arity ~down_to_up ~coming_from_indirect ~apply_alloc_mode
~contains_no_escaping_local_allocs ~current_region =
let simplify_direct_over_application ~simplify_expr dacc apply ~result_arity
~down_to_up ~coming_from_indirect ~apply_alloc_mode ~current_region
~callee's_code_id ~callee's_code_metadata =
fail_if_probe apply;
let expr =
Simplify_common.split_direct_over_application apply ~param_arity
~result_arity ~apply_alloc_mode ~contains_no_escaping_local_allocs
~current_region
Simplify_common.split_direct_over_application apply ~result_arity
~apply_alloc_mode ~current_region ~callee's_code_id
~callee's_code_metadata
in
let down_to_up dacc ~rebuild =
let rebuild uacc ~after_rebuild =
Expand Down Expand Up @@ -669,12 +672,9 @@ let simplify_direct_function_call ~simplify_expr dacc apply
tuple argument, irrespective of what [Code.params_arity] says. *)
if must_be_detupled
then
simplify_direct_tuple_application ~simplify_expr dacc apply ~params_arity
~result_arity ~apply_alloc_mode
~contains_no_escaping_local_allocs:
(Code_metadata.contains_no_escaping_local_allocs
callee's_code_metadata)
~current_region ~down_to_up
simplify_direct_tuple_application ~simplify_expr dacc apply ~result_arity
~apply_alloc_mode ~current_region ~callee's_code_id
~callee's_code_metadata ~down_to_up
else
let args = Apply.args apply in
let provided_num_args = List.length args in
Expand All @@ -686,13 +686,9 @@ let simplify_direct_function_call ~simplify_expr dacc apply
~coming_from_indirect ~callee's_code_metadata
else if provided_num_args > num_params
then
simplify_direct_over_application ~simplify_expr dacc apply
~param_arity:params_arity ~result_arity ~down_to_up
~coming_from_indirect ~apply_alloc_mode
~contains_no_escaping_local_allocs:
(Code_metadata.contains_no_escaping_local_allocs
callee's_code_metadata)
~current_region
simplify_direct_over_application ~simplify_expr dacc apply ~result_arity
~down_to_up ~coming_from_indirect ~apply_alloc_mode ~current_region
~callee's_code_id ~callee's_code_metadata
else if provided_num_args > 0 && provided_num_args < num_params
then
simplify_direct_partial_application ~simplify_expr dacc apply
Expand Down
43 changes: 30 additions & 13 deletions middle_end/flambda2/simplify/simplify_common.ml
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,20 @@ let project_tuple ~dbg ~size ~field tuple =
let prim = P.Binary (Block_load (bak, mutability), tuple, index) in
Named.create_prim prim dbg

let split_direct_over_application apply ~param_arity ~result_arity
~(apply_alloc_mode : Alloc_mode.For_types.t)
~contains_no_escaping_local_allocs ~current_region =
let arity = Flambda_arity.With_subkinds.cardinal param_arity in
let split_direct_over_application apply ~result_arity
~(apply_alloc_mode : Alloc_mode.For_types.t) ~current_region
~callee's_code_id ~callee's_code_metadata =
let arity =
Flambda_arity.With_subkinds.cardinal
(Code_metadata.params_arity callee's_code_metadata)
in
let args = Apply.args apply in
assert (arity < List.length args);
let first_args, remaining_args = Misc.Stdlib.List.split_at arity args in
let func_var = Variable.create "full_apply" in
let contains_no_escaping_local_allocs =
Code_metadata.contains_no_escaping_local_allocs callee's_code_metadata
in
let needs_region =
(* If the function being called might do a local allocation that escapes,
then we need a region for such function's return value, unless the
Expand All @@ -119,11 +125,6 @@ let split_direct_over_application apply ~param_arity ~result_arity
| None -> current_region
| Some (region, _) -> region
in
let alloc_mode =
if contains_no_escaping_local_allocs
then Alloc_mode.For_types.heap
else Alloc_mode.For_types.unknown ()
in
let continuation =
(* If there is no need for a new region, then the second (over)
application jumps directly to the return continuation of the original
Expand All @@ -137,7 +138,8 @@ let split_direct_over_application apply ~param_arity ~result_arity
Apply.create ~callee:(Simple.var func_var) ~continuation
(Apply.exn_continuation apply)
~args:remaining_args
~call_kind:(Call_kind.indirect_function_call_unknown_arity alloc_mode)
~call_kind:
(Call_kind.indirect_function_call_unknown_arity apply_alloc_mode)
(Apply.dbg apply) ~inlined:(Apply.inlined apply)
~inlining_state:(Apply.inlining_state apply)
~probe_name:(Apply.probe_name apply) ~position:(Apply.position apply)
Expand Down Expand Up @@ -214,9 +216,24 @@ let split_direct_over_application apply ~param_arity ~result_arity
~is_exn_handler:false
in
let full_apply =
Apply.with_continuation_callee_and_args apply
(Return after_full_application) ~callee:(Apply.callee apply)
~args:first_args ~region:current_region
let alloc_mode =
if contains_no_escaping_local_allocs
then Alloc_mode.For_types.heap
else Alloc_mode.For_types.unknown ()
in
Apply.create ~callee:(Apply.callee apply)
~continuation:(Return after_full_application)
(Apply.exn_continuation apply)
~args:first_args
~call_kind:
(Call_kind.direct_function_call callee's_code_id
~return_arity:(Code_metadata.result_arity callee's_code_metadata)
alloc_mode)
(Apply.dbg apply) ~inlined:(Apply.inlined apply)
~inlining_state:(Apply.inlining_state apply)
~probe_name:(Apply.probe_name apply) ~position:(Apply.position apply)
~relative_history:(Apply.relative_history apply)
~region:current_region
in
let both_applications =
Let_cont.create_non_recursive after_full_application
Expand Down
4 changes: 2 additions & 2 deletions middle_end/flambda2/simplify/simplify_common.mli
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ val project_tuple :
application of the leftover arguments. *)
val split_direct_over_application :
Apply_expr.t ->
param_arity:Flambda_arity.With_subkinds.t ->
result_arity:Flambda_arity.With_subkinds.t ->
apply_alloc_mode:Alloc_mode.For_types.t ->
contains_no_escaping_local_allocs:bool ->
current_region:Variable.t ->
callee's_code_id:Code_id.t ->
callee's_code_metadata:Code_metadata.t ->
Expr.t

type apply_cont_context =
Expand Down
5 changes: 0 additions & 5 deletions middle_end/flambda2/terms/apply_expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -315,11 +315,6 @@ let with_call_kind t call_kind =

let with_args t args = { t with args }

let with_continuation_callee_and_args t continuation ~callee ~args ~region =
let t = { t with continuation; callee; args; region } in
invariant t;
t

let inlining_arguments t = inlining_state t |> Inlining_state.arguments

let probe_name t = t.probe_name
Expand Down
9 changes: 0 additions & 9 deletions middle_end/flambda2/terms/apply_expr.mli
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,6 @@ val with_args : t -> Simple.t list -> t
(** Change the call kind of an application. *)
val with_call_kind : t -> Call_kind.t -> t

(** Change the continuation, callee and arguments of an application. *)
val with_continuation_callee_and_args :
t ->
Result_continuation.t ->
callee:Simple.t ->
args:Simple.t list ->
region:Variable.t ->
t

val inlining_state : t -> Inlining_state.t

val inlining_arguments : t -> Inlining_arguments.t
Expand Down
32 changes: 32 additions & 0 deletions ocaml/testsuite/tests/typing-local/pr902.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
(* TEST
* stack-allocation
** native
*)

(* PR902 (return mode on second application expression in a split
overapplication) *)

external local_stack_offset : unit -> int = "caml_local_stack_offset"
external opaque_identity : ('a[@local_opt]) -> ('a[@local_opt]) = "%opaque"
external is_local : local_ 'a -> bool = "caml_obj_is_local"

let f2 p () = p

let f1 () x : (unit -> local_ (int * int)) =
(* This local allocation should end up in the caller's region, because
we should have got here via one of the caml_applyL functions. If the
return mode of the second application in the expansion of the
overapplication below is wrongly Heap, then caml_apply will be used
instead, which will open its own region for this allocation. *)
let p = local_ (x, x) in
local_ ((opaque_identity f2) p) [@nontail]

let[@inline never] to_be_overapplied () () = Sys.opaque_identity f1

let () =
let start_offset = local_stack_offset () in
let p = to_be_overapplied () () () 42 () in
let end_offset = local_stack_offset () in
assert (is_local p);
let ok = end_offset - start_offset = 64 in
Printf.printf "PR902: %s\n" (if ok then "ok" else "FAIL")
1 change: 1 addition & 0 deletions ocaml/testsuite/tests/typing-local/pr902.reference
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
PR902: ok