Skip to content

Commit

Permalink
Specialize array_get primitives for iarrays (ocaml-flambda#3082)
Browse files Browse the repository at this point in the history
* specialize array_get for iarrays

* format

* add lambda tests

* improve expect diff
  • Loading branch information
TheNumbat authored Sep 30, 2024
1 parent 5a46c46 commit 7f34929
Show file tree
Hide file tree
Showing 11 changed files with 285 additions and 61 deletions.
38 changes: 21 additions & 17 deletions middle_end/flambda2/from_lambda/lambda_to_flambda_primitives.ml
Original file line number Diff line number Diff line change
Expand Up @@ -834,25 +834,27 @@ let check_array_access ~dbg ~array array_kind ~index ~index_kind ~size_int
~size_int)
~dbg

let array_load_unsafe ~array ~index (array_ref_kind : Array_ref_kind.t)
~current_region : H.expr_primitive =
let array_load_unsafe ~array ~index ~(mut : Lambda.mutable_flag)
(array_ref_kind : Array_ref_kind.t) ~current_region : H.expr_primitive =
let mut : Mutability.t =
match mut with
| Immutable | Immutable_unique -> Immutable
| Mutable -> Mutable
in
match array_ref_kind with
| Immediates -> Binary (Array_load (Immediates, Scalar, Mutable), array, index)
| Values -> Binary (Array_load (Values, Scalar, Mutable), array, index)
| Immediates -> Binary (Array_load (Immediates, Scalar, mut), array, index)
| Values -> Binary (Array_load (Values, Scalar, mut), array, index)
| Naked_floats_to_be_boxed mode ->
box_float mode
(Binary (Array_load (Naked_floats, Scalar, Mutable), array, index))
(Binary (Array_load (Naked_floats, Scalar, mut), array, index))
~current_region
| Naked_floats ->
Binary (Array_load (Naked_floats, Scalar, Mutable), array, index)
| Naked_floats -> Binary (Array_load (Naked_floats, Scalar, mut), array, index)
| Naked_float32s ->
Binary (Array_load (Naked_float32s, Scalar, Mutable), array, index)
| Naked_int32s ->
Binary (Array_load (Naked_int32s, Scalar, Mutable), array, index)
| Naked_int64s ->
Binary (Array_load (Naked_int64s, Scalar, Mutable), array, index)
Binary (Array_load (Naked_float32s, Scalar, mut), array, index)
| Naked_int32s -> Binary (Array_load (Naked_int32s, Scalar, mut), array, index)
| Naked_int64s -> Binary (Array_load (Naked_int64s, Scalar, mut), array, index)
| Naked_nativeints ->
Binary (Array_load (Naked_nativeints, Scalar, Mutable), array, index)
Binary (Array_load (Naked_nativeints, Scalar, mut), array, index)

let array_set_unsafe ~array ~index ~new_value
(array_set_kind : Array_set_kind.t) : H.expr_primitive =
Expand Down Expand Up @@ -1584,19 +1586,19 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
| Pmodbint { size = Pnativeint; is_safe = Safe; mode }, [[arg1]; [arg2]] ->
[ checked_arith_op ~dbg (Some Pnativeint) Mod (Some mode) arg1 arg2
~current_region ]
| Parrayrefu (array_ref_kind, index_kind), [[array]; [index]] ->
| Parrayrefu (array_ref_kind, index_kind, mut), [[array]; [index]] ->
(* For this and the following cases we will end up relying on the backend to
CSE the two accesses to the array's header word in the [Pgenarray]
case. *)
[ match_on_array_ref_kind ~array array_ref_kind
(array_load_unsafe ~array
(array_load_unsafe ~array ~mut
~index:(convert_index_to_tagged_int ~index ~index_kind)
~current_region) ]
| Parrayrefs (array_ref_kind, index_kind), [[array]; [index]] ->
| Parrayrefs (array_ref_kind, index_kind, mut), [[array]; [index]] ->
let array_kind = convert_array_ref_kind_for_length array_ref_kind in
[ check_array_access ~dbg ~array array_kind ~index ~index_kind ~size_int
(match_on_array_ref_kind ~array array_ref_kind
(array_load_unsafe ~array
(array_load_unsafe ~array ~mut
~index:(convert_index_to_tagged_int ~index ~index_kind)
~current_region)) ]
| Parraysetu (array_set_kind, index_kind), [[array]; [index]; [new_value]] ->
Expand Down Expand Up @@ -1976,11 +1978,13 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
( ( Pgenarray_ref _ | Paddrarray_ref | Pintarray_ref
| Pfloatarray_ref _ | Punboxedfloatarray_ref _
| Punboxedintarray_ref _ ),
_,
_ )
| Parrayrefs
( ( Pgenarray_ref _ | Paddrarray_ref | Pintarray_ref
| Pfloatarray_ref _ | Punboxedfloatarray_ref _
| Punboxedintarray_ref _ ),
_,
_ )
| Pcompare_ints | Pcompare_floats _ | Pcompare_bints _ | Patomic_exchange
| Patomic_fetch_add ),
Expand Down
16 changes: 8 additions & 8 deletions ocaml/bytecomp/bytegen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -505,15 +505,15 @@ let comp_primitive stack_info p sz args =
(* In bytecode, nothing is ever actually stack-allocated, so we ignore the
array modes (allocation for [Parrayref{s,u}], modification for
[Parrayset{s,u}]). *)
| Parrayrefs (Pgenarray_ref _, index_kind)
| Parrayrefs (Pgenarray_ref _, index_kind, _)
| Parrayrefs ((Paddrarray_ref | Pintarray_ref | Pfloatarray_ref _
| Punboxedfloatarray_ref (Pfloat64 | Pfloat32) | Punboxedintarray_ref _),
(Punboxed_int_index _ as index_kind)) ->
(Punboxed_int_index _ as index_kind), _) ->
Kccall(indexing_primitive index_kind "caml_array_get", 2)
| Parrayrefs ((Punboxedfloatarray_ref Pfloat64 | Pfloatarray_ref _), Ptagged_int_index) ->
| Parrayrefs ((Punboxedfloatarray_ref Pfloat64 | Pfloatarray_ref _), Ptagged_int_index, _) ->
Kccall("caml_floatarray_get", 2)
| Parrayrefs ((Punboxedfloatarray_ref Pfloat32 | Punboxedintarray_ref _
| Paddrarray_ref | Pintarray_ref), Ptagged_int_index) ->
| Paddrarray_ref | Pintarray_ref), Ptagged_int_index, _) ->
Kccall("caml_array_get_addr", 2)
| Parraysets (Pgenarray_set _, index_kind)
| Parraysets ((Paddrarray_set _ | Pintarray_set | Pfloatarray_set
Expand All @@ -526,15 +526,15 @@ let comp_primitive stack_info p sz args =
| Parraysets ((Punboxedfloatarray_set Pfloat32 | Punboxedintarray_set _
| Paddrarray_set _ | Pintarray_set), Ptagged_int_index) ->
Kccall("caml_array_set_addr", 3)
| Parrayrefu (Pgenarray_ref _, index_kind)
| Parrayrefu (Pgenarray_ref _, index_kind, _)
| Parrayrefu ((Paddrarray_ref | Pintarray_ref | Pfloatarray_ref _
| Punboxedfloatarray_ref (Pfloat64 | Pfloat32) | Punboxedintarray_ref _),
(Punboxed_int_index _ as index_kind)) ->
(Punboxed_int_index _ as index_kind), _) ->
Kccall(indexing_primitive index_kind "caml_array_unsafe_get", 2)
| Parrayrefu ((Punboxedfloatarray_ref Pfloat64 | Pfloatarray_ref _), Ptagged_int_index) ->
| Parrayrefu ((Punboxedfloatarray_ref Pfloat64 | Pfloatarray_ref _), Ptagged_int_index, _) ->
Kccall("caml_floatarray_unsafe_get", 2)
| Parrayrefu ((Punboxedfloatarray_ref Pfloat32 | Punboxedintarray_ref _
| Paddrarray_ref | Pintarray_ref), Ptagged_int_index) -> Kgetvectitem
| Paddrarray_ref | Pintarray_ref), Ptagged_int_index, _) -> Kgetvectitem
| Parraysetu (Pgenarray_set _, index_kind)
| Parraysetu ((Paddrarray_set _ | Pintarray_set | Pfloatarray_set
| Punboxedfloatarray_set (Pfloat64 | Pfloat32) | Punboxedintarray_set _),
Expand Down
16 changes: 8 additions & 8 deletions ocaml/lambda/lambda.ml
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,9 @@ type primitive =
| Pmakearray of array_kind * mutable_flag * alloc_mode
| Pduparray of array_kind * mutable_flag
| Parraylength of array_kind
| Parrayrefu of array_ref_kind * array_index_kind
| Parrayrefu of array_ref_kind * array_index_kind * mutable_flag
| Parraysetu of array_set_kind * array_index_kind
| Parrayrefs of array_ref_kind * array_index_kind
| Parrayrefs of array_ref_kind * array_index_kind * mutable_flag
| Parraysets of array_set_kind * array_index_kind
(* Test if the argument is a block or an immediate integer *)
| Pisint of { variant_only : bool }
Expand Down Expand Up @@ -1757,11 +1757,11 @@ let primitive_may_allocate : primitive -> alloc_mode option = function
| Parraylength _ -> None
| Parraysetu _ | Parraysets _
| Parrayrefu ((Paddrarray_ref | Pintarray_ref
| Punboxedfloatarray_ref _ | Punboxedintarray_ref _), _)
| Punboxedfloatarray_ref _ | Punboxedintarray_ref _), _, _)
| Parrayrefs ((Paddrarray_ref | Pintarray_ref
| Punboxedfloatarray_ref _ | Punboxedintarray_ref _), _) -> None
| Parrayrefu ((Pgenarray_ref m | Pfloatarray_ref m), _)
| Parrayrefs ((Pgenarray_ref m | Pfloatarray_ref m), _) -> Some m
| Punboxedfloatarray_ref _ | Punboxedintarray_ref _), _, _) -> None
| Parrayrefu ((Pgenarray_ref m | Pfloatarray_ref m), _, _)
| Parrayrefs ((Pgenarray_ref m | Pfloatarray_ref m), _, _) -> Some m
| Pisint _ | Pisout -> None
| Pintofbint _ -> None
| Pbintofint (_,m)
Expand Down Expand Up @@ -1961,7 +1961,7 @@ let primitive_result_layout (p : primitive) =
| Pstring_load_16 _ | Pbytes_load_16 _ | Pbigstring_load_16 _
| Pprobe_is_enabled _ | Pbswap16
-> layout_int
| Parrayrefu (array_ref_kind, _) | Parrayrefs (array_ref_kind, _) ->
| Parrayrefu (array_ref_kind, _, _) | Parrayrefs (array_ref_kind, _, _) ->
array_ref_kind_result_layout array_ref_kind
| Pbintofint (bi, _) | Pcvtbint (_,bi,_)
| Pnegbint (bi, _) | Paddbint (bi, _) | Psubbint (bi, _)
Expand All @@ -1983,7 +1983,7 @@ let primitive_result_layout (p : primitive) =
| Pstring_load_128 _ | Pbytes_load_128 _
| Pbigstring_load_128 { boxed = true; _ } ->
layout_boxed_vector (Pvec128 Int8x16)
| Pbigstring_load_32 { boxed = false; _ }
| Pbigstring_load_32 { boxed = false; _ }
| Pstring_load_32 { boxed = false; _ }
| Pbytes_load_32 { boxed = false; _ } -> layout_unboxed_int Pint32
| Pbigstring_load_f32 { boxed = false; _ }
Expand Down
4 changes: 2 additions & 2 deletions ocaml/lambda/lambda.mli
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,9 @@ type primitive =
The arguments of [Pduparray] give the kind and mutability of the
array being *produced* by the duplication. *)
| Parraylength of array_kind
| Parrayrefu of array_ref_kind * array_index_kind
| Parrayrefu of array_ref_kind * array_index_kind * mutable_flag
| Parraysetu of array_set_kind * array_index_kind
| Parrayrefs of array_ref_kind * array_index_kind
| Parrayrefs of array_ref_kind * array_index_kind * mutable_flag
| Parraysets of array_set_kind * array_index_kind
(* Test if the argument is a block or an immediate integer *)
| Pisint of { variant_only : bool }
Expand Down
3 changes: 2 additions & 1 deletion ocaml/lambda/matching.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2381,8 +2381,9 @@ let get_expr_args_array ~scopes kind head (arg, _mut, _sort, _layout) rem =
array pattern, once that's available *)
let ref_kind = Lambda.(array_ref_kind alloc_heap kind) in
let result_layout = array_ref_kind_result_layout ref_kind in
let mut = if Types.is_mutable am then Mutable else Immutable in
( Lprim
(Parrayrefu (ref_kind, Ptagged_int_index),
(Parrayrefu (ref_kind, Ptagged_int_index, mut),
[ arg; Lconst (Const_base (Const_int pos)) ],
loc),
(if Types.is_mutable am then StrictOpt else Alias),
Expand Down
18 changes: 12 additions & 6 deletions ocaml/lambda/printlambda.ml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ let array_kind = function
| Punboxedintarray Pint64 -> "unboxed_int64"
| Punboxedintarray Pnativeint -> "unboxed_nativeint"

let array_mut = function
| Mutable -> "array"
| Immutable | Immutable_unique -> "iarray"

let array_ref_kind ppf k =
let pp_mode ppf = function
| Alloc_heap -> ()
Expand Down Expand Up @@ -582,15 +586,17 @@ let primitive ppf = function
| Pduparray (k, Immutable) -> fprintf ppf "duparray_imm[%s]" (array_kind k)
| Pduparray (k, Immutable_unique) ->
fprintf ppf "duparray_unique[%s]" (array_kind k)
| Parrayrefu (rk, idx) -> fprintf ppf "array.unsafe_get[%a indexed by %a]"
array_ref_kind rk
array_index_kind idx
| Parrayrefu (rk, idx, mut) -> fprintf ppf "%s.unsafe_get[%a indexed by %a]"
(array_mut mut)
array_ref_kind rk
array_index_kind idx
| Parraysetu (sk, idx) -> fprintf ppf "array.unsafe_set[%a indexed by %a]"
array_set_kind sk
array_index_kind idx
| Parrayrefs (rk, idx) -> fprintf ppf "array.get[%a indexed by %a]"
array_ref_kind rk
array_index_kind idx
| Parrayrefs (rk, idx, mut) -> fprintf ppf "%s.get[%a indexed by %a]"
(array_mut mut)
array_ref_kind rk
array_index_kind idx
| Parraysets (sk, idx) -> fprintf ppf "array.set[%a indexed by %a]"
array_set_kind sk
array_index_kind idx
Expand Down
6 changes: 5 additions & 1 deletion ocaml/lambda/transl_array_comprehension.ml
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,9 @@ let iterator ~transl_exp ~scopes ~loc :
Typeopt.array_type_kind ~elt_sort:None iter_arr_exp.exp_env
iter_arr_exp.exp_loc iter_arr_exp.exp_type
in
let iter_arr_mut =
Typeopt.array_type_mut iter_arr_exp.exp_env iter_arr_exp.exp_type
in
let iter_len =
(* Extra let-binding if we're not in the fixed-size array case; the
middle-end will simplify this for us *)
Expand All @@ -498,7 +501,8 @@ let iterator ~transl_exp ~scopes ~loc :
(Lprim
( Parrayrefu
( Lambda.(array_ref_kind alloc_heap iter_arr_kind),
Ptagged_int_index ),
Ptagged_int_index,
iter_arr_mut ),
[iter_arr.var; Lvar iter_ix],
loc ))
pattern body
Expand Down
38 changes: 20 additions & 18 deletions ocaml/lambda/translprim.ml
Original file line number Diff line number Diff line change
Expand Up @@ -454,78 +454,78 @@ let lookup_primitive loc ~poly_mode ~poly_sort pos p =
| "%array_length" -> Primitive ((Parraylength gen_array_kind), 1)
| "%array_safe_get" ->
Primitive
((Parrayrefs (gen_array_ref_kind mode, Ptagged_int_index)), 2)
((Parrayrefs (gen_array_ref_kind mode, Ptagged_int_index, Mutable)), 2)
| "%array_safe_set" ->
Primitive
(Parraysets (gen_array_set_kind (get_first_arg_mode ()), Ptagged_int_index),
3)
| "%array_unsafe_get" ->
Primitive
(Parrayrefu (gen_array_ref_kind mode, Ptagged_int_index), 2)
(Parrayrefu (gen_array_ref_kind mode, Ptagged_int_index, Mutable), 2)
| "%array_unsafe_set" ->
Primitive
((Parraysetu (gen_array_set_kind (get_first_arg_mode ()), Ptagged_int_index)),
3)
| "%array_safe_get_indexed_by_int64#" ->
Primitive
((Parrayrefs (gen_array_ref_kind mode, Punboxed_int_index Pint64)), 2)
((Parrayrefs (gen_array_ref_kind mode, Punboxed_int_index Pint64, Mutable)), 2)
| "%array_safe_set_indexed_by_int64#" ->
Primitive
(Parraysets
(gen_array_set_kind (get_first_arg_mode ()), Punboxed_int_index Pint64),
3)
| "%array_unsafe_get_indexed_by_int64#" ->
Primitive
(Parrayrefu (gen_array_ref_kind mode, Punboxed_int_index Pint64), 2)
(Parrayrefu (gen_array_ref_kind mode, Punboxed_int_index Pint64, Mutable), 2)
| "%array_unsafe_set_indexed_by_int64#" ->
Primitive
((Parraysetu
(gen_array_set_kind (get_first_arg_mode ()), Punboxed_int_index Pint64)),
3)
| "%array_safe_get_indexed_by_int32#" ->
Primitive
((Parrayrefs (gen_array_ref_kind mode, Punboxed_int_index Pint32)), 2)
((Parrayrefs (gen_array_ref_kind mode, Punboxed_int_index Pint32, Mutable)), 2)
| "%array_safe_set_indexed_by_int32#" ->
Primitive
(Parraysets
(gen_array_set_kind (get_first_arg_mode ()), Punboxed_int_index Pint32),
3)
| "%array_unsafe_get_indexed_by_int32#" ->
Primitive
(Parrayrefu (gen_array_ref_kind mode, Punboxed_int_index Pint32), 2)
(Parrayrefu (gen_array_ref_kind mode, Punboxed_int_index Pint32, Mutable), 2)
| "%array_unsafe_set_indexed_by_int32#" ->
Primitive
((Parraysetu
(gen_array_set_kind (get_first_arg_mode ()), Punboxed_int_index Pint32)),
3)
| "%array_safe_get_indexed_by_nativeint#" ->
Primitive
((Parrayrefs (gen_array_ref_kind mode, Punboxed_int_index Pnativeint)), 2)
((Parrayrefs (gen_array_ref_kind mode, Punboxed_int_index Pnativeint, Mutable)), 2)
| "%array_safe_set_indexed_by_nativeint#" ->
Primitive
(Parraysets
(gen_array_set_kind (get_first_arg_mode ()), Punboxed_int_index Pnativeint),
3)
| "%array_unsafe_get_indexed_by_nativeint#" ->
Primitive
(Parrayrefu (gen_array_ref_kind mode, Punboxed_int_index Pnativeint), 2)
(Parrayrefu (gen_array_ref_kind mode, Punboxed_int_index Pnativeint, Mutable), 2)
| "%array_unsafe_set_indexed_by_nativeint#" ->
Primitive
((Parraysetu
(gen_array_set_kind (get_first_arg_mode ()), Punboxed_int_index Pnativeint)),
3)
| "%obj_size" -> Primitive ((Parraylength Pgenarray), 1)
| "%obj_field" -> Primitive ((Parrayrefu (Pgenarray_ref mode, Ptagged_int_index)), 2)
| "%obj_field" -> Primitive ((Parrayrefu (Pgenarray_ref mode, Ptagged_int_index, Mutable)), 2)
| "%obj_set_field" ->
Primitive
((Parraysetu (Pgenarray_set (get_first_arg_mode ()), Ptagged_int_index)), 3)
| "%floatarray_length" -> Primitive ((Parraylength Pfloatarray), 1)
| "%floatarray_safe_get" ->
Primitive ((Parrayrefs (Pfloatarray_ref mode, Ptagged_int_index)), 2)
Primitive ((Parrayrefs (Pfloatarray_ref mode, Ptagged_int_index, Mutable)), 2)
| "%floatarray_safe_set" ->
Primitive (Parraysets (Pfloatarray_set, Ptagged_int_index), 3)
| "%floatarray_unsafe_get" ->
Primitive ((Parrayrefu (Pfloatarray_ref mode, Ptagged_int_index)), 2)
Primitive ((Parrayrefu (Pfloatarray_ref mode, Ptagged_int_index, Mutable)), 2)
| "%floatarray_unsafe_set" ->
Primitive ((Parraysetu (Pfloatarray_set, Ptagged_int_index)), 3)
| "%obj_is_int" -> Primitive (Pisint { variant_only = false }, 1)
Expand Down Expand Up @@ -1012,13 +1012,14 @@ let specialize_primitive env loc ty ~has_constant_constructor prim =
if t = array_type then None
else Some (Primitive (Parraylength array_type, arity))
end
| Primitive (Parrayrefu (rt, index_kind), arity), p1 :: _ -> begin
| Primitive (Parrayrefu (rt, index_kind, mut), arity), p1 :: _ -> begin
let loc = to_location loc in
let array_ref_type =
glb_array_ref_type loc rt (array_type_kind ~elt_sort:None env loc p1)
in
if rt = array_ref_type then None
else Some (Primitive (Parrayrefu (array_ref_type, index_kind), arity))
let array_mut = array_type_mut env p1 in
if rt = array_ref_type && mut = array_mut then None
else Some (Primitive (Parrayrefu (array_ref_type, index_kind, array_mut), arity))
end
| Primitive (Parraysetu (st, index_kind), arity), p1 :: _ -> begin
let loc = to_location loc in
Expand All @@ -1028,13 +1029,14 @@ let specialize_primitive env loc ty ~has_constant_constructor prim =
if st = array_set_type then None
else Some (Primitive (Parraysetu (array_set_type, index_kind), arity))
end
| Primitive (Parrayrefs (rt, index_kind), arity), p1 :: _ -> begin
| Primitive (Parrayrefs (rt, index_kind, mut), arity), p1 :: _ -> begin
let loc = to_location loc in
let array_ref_type =
glb_array_ref_type loc rt (array_type_kind ~elt_sort:None env loc p1)
in
if rt = array_ref_type then None
else Some (Primitive (Parrayrefs (array_ref_type, index_kind), arity))
let array_mut = array_type_mut env p1 in
if rt = array_ref_type && mut = array_mut then None
else Some (Primitive (Parrayrefs (array_ref_type, index_kind, array_mut), arity))
end
| Primitive (Parraysets (st, index_kind), arity), p1 :: _ -> begin
let loc = to_location loc in
Expand Down Expand Up @@ -1528,7 +1530,7 @@ let lambda_primitive_needs_event_after = function
| Pmulfloat (_, _) | Pdivfloat (_, _)
| Pstringrefs | Pbytesrefs
| Pbytessets | Pmakearray (Pgenarray, _, _) | Pduparray _
| Parrayrefu ((Pgenarray_ref _ | Pfloatarray_ref _), _)
| Parrayrefu ((Pgenarray_ref _ | Pfloatarray_ref _), _, _)
| Parrayrefs _ | Parraysets _ | Pbintofint _ | Pcvtbint _ | Pnegbint _
| Paddbint _ | Psubbint _ | Pmulbint _ | Pdivbint _ | Pmodbint _ | Pandbint _
| Porbint _ | Pxorbint _ | Plslbint _ | Plsrbint _ | Pasrbint _
Expand Down
Loading

0 comments on commit 7f34929

Please sign in to comment.