From 4d155c63819ef8527b3eaf5ccb2f90fd4c802e23 Mon Sep 17 00:00:00 2001 From: Mark Shinwell Date: Thu, 21 Dec 2023 11:31:32 +0000 Subject: [PATCH] flambda-backend: Treat Prim_poly primitives as Alloc_heap in Lambda (#2190) (cherry picked from commit 003afd0b7f2eeeabaa2d621745565c144d5a43a2) --- lambda/lambda.ml | 9 +++-- lambda/printlambda.ml | 36 +++++++++++--------- lambda/printlambda.mli | 1 + lambda/translprim.ml | 42 ++++++++++++++++++++++-- testsuite/tests/typing-local/external.ml | 19 +++++++++++ 5 files changed, 87 insertions(+), 20 deletions(-) create mode 100644 testsuite/tests/typing-local/external.ml diff --git a/lambda/lambda.ml b/lambda/lambda.ml index e17ff440f7d..8a8dfc934d9 100644 --- a/lambda/lambda.ml +++ b/lambda/lambda.ml @@ -1448,16 +1448,19 @@ let alloc_mode_of_primitive_description (p : Primitive.description) = if p.prim_alloc then Some alloc_heap else None else match p.prim_native_repr_res with - | (Prim_local | Prim_poly), _ -> + | Prim_local, _ -> (* For primitives that might allocate locally, [p.prim_alloc] just says whether [caml_c_call] is required, without telling us anything about local allocation. (However if [p.prim_alloc = false] we do actually know that the primitive does not allocate on the heap.) *) Some alloc_local - | Prim_global, _ -> + | (Prim_global | Prim_poly), _ -> (* For primitives that definitely do not allocate locally, [p.prim_alloc = false] actually tells us that the primitive does - not allocate at all. *) + not allocate at all. + + No external call that is [Prim_poly] may allocate locally. + *) if p.prim_alloc then Some alloc_heap else None (* Changes to this function may also require changes in Flambda 2 (e.g. diff --git a/lambda/printlambda.ml b/lambda/printlambda.ml index e9016cfd589..5f739e83f43 100644 --- a/lambda/printlambda.ml +++ b/lambda/printlambda.ml @@ -84,10 +84,15 @@ let array_set_kind ppf k = | Pintarray_set -> fprintf ppf "int" | Pfloatarray_set -> fprintf ppf "float" -let alloc_mode = function +let alloc_mode_if_local = function | Alloc_heap -> "" | Alloc_local -> "local" +let alloc_mode ppf alloc_mode = + match alloc_mode with + | Alloc_heap -> fprintf ppf "heap" + | Alloc_local -> fprintf ppf "local" + let boxed_integer_name = function | Pnativeint -> "nativeint" | Pint32 -> "int32" @@ -143,7 +148,7 @@ let rec layout is_top ppf layout_ = let layout ppf layout_ = layout true ppf layout_ let return_kind ppf (mode, kind) = - let smode = alloc_mode mode in + let smode = alloc_mode_if_local mode in match kind with | Pvalue Pgenval when is_heap_mode mode -> () | Pvalue Pgenval -> fprintf ppf ": %s@ " smode @@ -275,31 +280,31 @@ let primitive ppf = function | Pgetpredef id -> fprintf ppf "getpredef %a!" Ident.print id | Pmakeblock(tag, Immutable, shape, mode) -> fprintf ppf "make%sblock %i%a" - (alloc_mode mode) tag block_shape shape + (alloc_mode_if_local mode) tag block_shape shape | Pmakeblock(tag, Immutable_unique, shape, mode) -> fprintf ppf "make%sblock_unique %i%a" - (alloc_mode mode) tag block_shape shape + (alloc_mode_if_local mode) tag block_shape shape | Pmakeblock(tag, Mutable, shape, mode) -> fprintf ppf "make%smutable %i%a" - (alloc_mode mode) tag block_shape shape + (alloc_mode_if_local mode) tag block_shape shape | Pmakefloatblock (Immutable, mode) -> fprintf ppf "make%sfloatblock Immutable" - (alloc_mode mode) + (alloc_mode_if_local mode) | Pmakefloatblock (Immutable_unique, mode) -> fprintf ppf "make%sfloatblock Immutable_unique" - (alloc_mode mode) + (alloc_mode_if_local mode) | Pmakefloatblock (Mutable, mode) -> fprintf ppf "make%sfloatblock Mutable" - (alloc_mode mode) + (alloc_mode_if_local mode) | Pmakeufloatblock (Immutable, mode) -> fprintf ppf "make%sufloatblock Immutable" - (alloc_mode mode) + (alloc_mode_if_local mode) | Pmakeufloatblock (Immutable_unique, mode) -> fprintf ppf "make%sufloatblock Immutable_unique" - (alloc_mode mode) + (alloc_mode_if_local mode) | Pmakeufloatblock (Mutable, mode) -> fprintf ppf "make%sufloatblock Mutable" - (alloc_mode mode) + (alloc_mode_if_local mode) | Pfield (n, ptr, sem) -> let instr = match ptr, sem with @@ -340,7 +345,7 @@ let primitive ppf = function fprintf ppf "setfield_%s%s_computed" instr init | Pfloatfield (n, sem, mode) -> fprintf ppf "floatfield%a%s %i" - field_read_semantics sem (alloc_mode mode) n + field_read_semantics sem (alloc_mode_if_local mode) n | Pufloatfield (n, sem) -> fprintf ppf "ufloatfield%a %i" field_read_semantics sem n @@ -419,11 +424,12 @@ let primitive ppf = function | Parraylength k -> fprintf ppf "array.length[%s]" (array_kind k) | Pmakearray (k, Mutable, mode) -> - fprintf ppf "make%sarray[%s]" (alloc_mode mode) (array_kind k) + fprintf ppf "make%sarray[%s]" (alloc_mode_if_local mode) (array_kind k) | Pmakearray (k, Immutable, mode) -> - fprintf ppf "make%sarray_imm[%s]" (alloc_mode mode) (array_kind k) + fprintf ppf "make%sarray_imm[%s]" (alloc_mode_if_local mode) (array_kind k) | Pmakearray (k, Immutable_unique, mode) -> - fprintf ppf "make%sarray_unique[%s]" (alloc_mode mode) (array_kind k) + fprintf ppf "make%sarray_unique[%s]" (alloc_mode_if_local mode) + (array_kind k) | Pduparray (k, Mutable) -> fprintf ppf "duparray[%s]" (array_kind k) | Pduparray (k, Immutable) -> fprintf ppf "duparray_imm[%s]" (array_kind k) | Pduparray (k, Immutable_unique) -> diff --git a/lambda/printlambda.mli b/lambda/printlambda.mli index 9f56afb32ff..c982456ea8e 100644 --- a/lambda/printlambda.mli +++ b/lambda/printlambda.mli @@ -36,3 +36,4 @@ val print_bigarray : string -> bool -> Lambda.bigarray_kind -> formatter -> Lambda.bigarray_layout -> unit val check_attribute : formatter -> check_attribute -> unit +val alloc_mode : formatter -> alloc_mode -> unit diff --git a/lambda/translprim.ml b/lambda/translprim.ml index 93365c0e4bd..efcb3bf4cd2 100644 --- a/lambda/translprim.ml +++ b/lambda/translprim.ml @@ -915,7 +915,11 @@ let lambda_of_prim prim_name prim loc args arg_exps = let check_primitive_arity loc p = let mode = match p.prim_native_repr_res with - | Prim_global, _ | Prim_poly, _ -> Some Mode.Locality.global + | Prim_global, _ | Prim_poly, _ -> + (* We assume all primitives are compiled to have the same arity for + different modes and types, so just pick one of the modes in the + [Prim_poly] case. *) + Some Mode.Locality.global | Prim_local, _ -> Some Mode.Locality.local in let prim = lookup_primitive loc mode Rc_normal p in @@ -987,8 +991,42 @@ let transl_primitive loc p env ty ~poly_mode path = loc in let body = lambda_of_prim p.prim_name prim loc args None in + let alloc_mode = to_locality p.prim_native_repr_res in + let () = + (* CR mshinwell: Write a version of [primitive_may_allocate] that + works on the [prim] type. *) + match body with + | Lprim (prim, _, _) -> + (match Lambda.primitive_may_allocate prim with + | None -> + (* We don't check anything in this case; if the primitive doesn't + allocate, then after [Lambda] it will be translated to a term + not involving any region variables, meaning there would be + no concern about potentially unbound region variables. *) + () + | Some lambda_alloc_mode -> + (* In this case we add a check to ensure the middle end has + the correct information as to whether a region was inserted + at this point. *) + match alloc_mode, lambda_alloc_mode with + | Alloc_heap, Alloc_heap + | Alloc_local, Alloc_local -> () + | Alloc_local, Alloc_heap -> + (* This case is ok: the Lambda-derived information is more + precise. A region will be inserted, likely unused, and + deleted by the middle end. *) + () + | Alloc_heap, Alloc_local -> + Misc.fatal_errorf "Alloc mode incompatibility for:@ %a@ \ + (from to_locality, %a; from primitive_may_allocate, %a)" + Printlambda.lambda body + Printlambda.alloc_mode alloc_mode + Printlambda.alloc_mode lambda_alloc_mode + ) + | _ -> () + in let region = - match to_locality p.prim_native_repr_res with + match alloc_mode with | Alloc_heap -> true | Alloc_local -> false in diff --git a/testsuite/tests/typing-local/external.ml b/testsuite/tests/typing-local/external.ml new file mode 100644 index 00000000000..c1411bdcce2 --- /dev/null +++ b/testsuite/tests/typing-local/external.ml @@ -0,0 +1,19 @@ +(* TEST + * flambda2 + ** native +*) + +module M : sig + val bits_of_float : float -> int64 +end = struct + external bits_of_float + : (float[@local_opt]) + -> (int64[@local_opt]) + = "caml_int64_bits_of_float" "caml_int64_bits_of_float_unboxed" + [@@unboxed] [@@noalloc] +end + +let go_m f = + let i = M.bits_of_float f in + assert (i = 4L); + ()