128-bit load/store primitives for GC'd arrays #2247

Merged
merged 14 commits into from
Feb 28, 2024
45 changes: 45 additions & 0 deletions middle_end/flambda2/docs/simd.md
@@ -0,0 +1,45 @@

# SIMD in flambda2

## Vector types

<!-- CR mslater: more docs -->

## Intrinsics

<!-- CR mslater: more docs -->

## Load/Store

Unlike intrinsics, SIMD loads and stores are represented as flambda2-visible primitives.

- `string`, `bytes`: `caml_{string,bytes}_{getu128,setu128}{u}` map to `MOVUPD`, an unaligned 128-bit vector load/store.
The primitives can operate on all 128-bit vector types.
The safe primitives raise `Invalid_argument` if any part of the vector lies outside the bounds of the string or bytes; the `u` suffix omits this check.
Aligned load/store is not available because these values may be moved by the GC.

- `bigstring`: `caml_bigstring_{get,set}{u}128{u}` map to `MOVAPD` or `MOVUPD`.
The primitives can operate on all 128-bit vector types.
The prefix `u` indicates an unaligned operation (`MOVUPD`), and the suffix `u` omits bounds checking.
Aligned load/store is available because bigstrings are allocated by `malloc`.

- `float array`, `floatarray`, `float# array`: the corresponding primitives take an index in `float`s and are required to operate on `float64x2`s.
The address is computed as `array + index * 8`; the safe primitives check that `0 <= index < length - 1`, so that both elements are in bounds.
The primitives on `float array` are only available when the float array optimization is enabled.
Aligned load/store is not available because these values may be moved by the GC.

- `nativeint# array`, `int64# array`: the corresponding primitives take an index in `nativeint`s/`int64`s and are required to operate on `int64x2`s.
The address is computed as `array + index * 8`; the safe primitives check that `0 <= index < length - 1`, so that both elements are in bounds.
The primitives on `nativeint# array` are only available in 64-bit mode.
Aligned load/store is not available because these values may be moved by the GC.

- `int32# array`: the corresponding primitives take an index in `int32`s and are required to operate on `int32x4`s.
The address is computed as `array + index * 4`; the safe primitives check that `0 <= index < length - 3`, so that all four elements are in bounds.
Aligned load/store is not available because these values may be moved by the GC.

- `%immediate64 array`: the corresponding primitives take an index in immediates, and are required to operate on `int64x2`s.
The primitives can operate on all `('a : immediate64) array`s and are only available in 64-bit mode.
The address is computed as `array + index * 8`; the safe primitives check that `0 <= index < length - 1`, so that both elements are in bounds.
Aligned load/store is not available because these values may be moved by the GC.
Loads and stores directly read and write two 64-bit **tagged** values. The "safe" primitives only check bounds, not tagging,
so they must not be exposed to users as "safe."
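
The indexing rules in the list above can be modelled in plain OCaml. The sketch below is illustrative only: `element_kind`, `bytes_per_element`, `scalars_per_vector`, `byte_offset`, `is_valid_start`, `tag` and `is_tagged_immediate` are hypothetical names that do not exist in the compiler, and byte offsets are taken relative to the first element. It assumes a 64-bit target and the usual OCaml tagging scheme, where an immediate `n` is stored as the word `2n + 1`.

```ocaml
(* Hypothetical model of the array load/store rules; not compiler code. *)

type element_kind =
  | Float64      (* float array, floatarray, float# array *)
  | Int64        (* int64# array, nativeint# array (64-bit targets) *)
  | Int32        (* int32# array *)
  | Immediate64  (* ('a : immediate64) array (64-bit targets) *)

let bytes_per_element = function
  | Float64 | Int64 | Immediate64 -> 8
  | Int32 -> 4

(* A 128-bit vector spans 16 bytes, i.e. 2 or 4 consecutive elements. *)
let scalars_per_vector kind = 16 / bytes_per_element kind

(* Byte offset of the access, relative to the first element. *)
let byte_offset kind ~index = index * bytes_per_element kind

(* The safe primitives accept [index] only if the whole vector fits:
   0 <= index < length - (scalars_per_vector - 1).
   The bound is clamped at zero, so an array shorter than one vector
   accepts no index at all. *)
let is_valid_start kind ~length ~index =
  let bound = max 0 (length - (scalars_per_vector kind - 1)) in
  0 <= index && index < bound

(* Assumed tagging scheme: the immediate [n] is stored as 2n + 1. *)
let tag n = Int64.add (Int64.shift_left (Int64.of_int n) 1) 1L

(* A 64-bit lane is a valid immediate only if its low bit is set. *)
let is_tagged_immediate word = Int64.logand word 1L = 1L
```

Under this model a `float# array` of length 4 accepts start indices 0 to 2, and an `int32# array` of length 4 accepts only index 0. The last two functions illustrate why the `%immediate64 array` primitives are not safe to expose: a stored lane whose low bit is clear is not a valid immediate, and the bounds check does not catch that.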
12 changes: 9 additions & 3 deletions middle_end/flambda2/from_lambda/closure_conversion.ml
@@ -849,9 +849,15 @@ let close_primitive acc env ~let_bound_ids_with_kinds named
| Pbytes_set_16 _ | Pbytes_set_32 _ | Pbytes_set_64 _ | Pbytes_set_128 _
| Pbigstring_load_16 _ | Pbigstring_load_32 _ | Pbigstring_load_64 _
| Pbigstring_load_128 _ | Pbigstring_set_16 _ | Pbigstring_set_32 _
- | Pbigstring_set_64 _ | Pbigstring_set_128 _ | Pctconst _ | Pbswap16
- | Pbbswap _ | Pint_as_pointer _ | Popaque _ | Pprobe_is_enabled _
- | Pobj_dup | Pobj_magic _ | Punbox_float _
+ | Pbigstring_set_64 _ | Pbigstring_set_128 _ | Pfloatarray_load_128 _
+ | Pfloat_array_load_128 _ | Pint_array_load_128 _
+ | Punboxed_float_array_load_128 _ | Punboxed_int32_array_load_128 _
+ | Punboxed_int64_array_load_128 _ | Punboxed_nativeint_array_load_128 _
+ | Pfloatarray_set_128 _ | Pfloat_array_set_128 _ | Pint_array_set_128 _
+ | Punboxed_float_array_set_128 _ | Punboxed_int32_array_set_128 _
+ | Punboxed_int64_array_set_128 _ | Punboxed_nativeint_array_set_128 _
+ | Pctconst _ | Pbswap16 | Pbbswap _ | Pint_as_pointer _ | Popaque _
+ | Pprobe_is_enabled _ | Pobj_dup | Pobj_magic _ | Punbox_float _
| Pbox_float (_, _)
| Punbox_int _ | Pbox_int _ | Pmake_unboxed_product _
| Punboxed_product_field _ | Pget_header _ | Prunstack | Pperform
28 changes: 28 additions & 0 deletions middle_end/flambda2/from_lambda/lambda_to_flambda.ml
@@ -592,6 +592,20 @@ let primitive_can_raise (prim : Lambda.primitive) =
| Pbigstring_set_32 { unsafe = false; boxed = _ }
| Pbigstring_set_64 { unsafe = false; boxed = _ }
| Pbigstring_set_128 { unsafe = false; _ }
+ | Pfloatarray_load_128 { unsafe = false; _ }
+ | Pfloat_array_load_128 { unsafe = false; _ }
+ | Pint_array_load_128 { unsafe = false; _ }
+ | Punboxed_float_array_load_128 { unsafe = false; _ }
+ | Punboxed_int32_array_load_128 { unsafe = false; _ }
+ | Punboxed_int64_array_load_128 { unsafe = false; _ }
+ | Punboxed_nativeint_array_load_128 { unsafe = false; _ }
+ | Pfloatarray_set_128 { unsafe = false; _ }
+ | Pfloat_array_set_128 { unsafe = false; _ }
+ | Pint_array_set_128 { unsafe = false; _ }
+ | Punboxed_float_array_set_128 { unsafe = false; _ }
+ | Punboxed_int32_array_set_128 { unsafe = false; _ }
+ | Punboxed_int64_array_set_128 { unsafe = false; _ }
+ | Punboxed_nativeint_array_set_128 { unsafe = false; _ }
| Pdivbint { is_safe = Safe; _ }
| Pmodbint { is_safe = Safe; _ }
| Pbigarrayref (false, _, _, _)
@@ -664,6 +678,20 @@ let primitive_can_raise (prim : Lambda.primitive) =
| Pbigstring_set_32 { unsafe = true; boxed = _ }
| Pbigstring_set_64 { unsafe = true; boxed = _ }
| Pbigstring_set_128 { unsafe = true; _ }
+ | Pfloatarray_load_128 { unsafe = true; _ }
+ | Pfloat_array_load_128 { unsafe = true; _ }
+ | Pint_array_load_128 { unsafe = true; _ }
+ | Punboxed_float_array_load_128 { unsafe = true; _ }
+ | Punboxed_int32_array_load_128 { unsafe = true; _ }
+ | Punboxed_int64_array_load_128 { unsafe = true; _ }
+ | Punboxed_nativeint_array_load_128 { unsafe = true; _ }
+ | Pfloatarray_set_128 { unsafe = true; _ }
+ | Pfloat_array_set_128 { unsafe = true; _ }
+ | Pint_array_set_128 { unsafe = true; _ }
+ | Punboxed_float_array_set_128 { unsafe = true; _ }
+ | Punboxed_int32_array_set_128 { unsafe = true; _ }
+ | Punboxed_int64_array_set_128 { unsafe = true; _ }
+ | Punboxed_nativeint_array_set_128 { unsafe = true; _ }
| Pctconst _ | Pbswap16 | Pbbswap _ | Pint_as_pointer _ | Popaque _
| Pprobe_is_enabled _ | Pobj_dup | Pobj_magic _
| Pbox_float (_, _)
165 changes: 149 additions & 16 deletions middle_end/flambda2/from_lambda/lambda_to_flambda_primitives.ml
@@ -110,12 +110,12 @@ let convert_block_shape (shape : L.block_shape) ~num_fields =
num_fields shape_length;
List.map K.With_subkind.from_lambda_value_kind shape

- let check_float_array_optimisation_enabled () =
+ let check_float_array_optimisation_enabled name =
if not (Flambda_features.flat_float_array ())
then
- Misc.fatal_error
- "[Pgenarray] is not expected when the float array optimisation is \
- disabled"
+ Misc.fatal_errorf
+ "[%s] is not expected when the float array optimisation is disabled" name
+ ()

type converted_array_kind =
| Array_kind of P.Array_kind.t
@@ -124,7 +124,7 @@ type converted_array_kind =
let convert_array_kind (kind : L.array_kind) : converted_array_kind =
match kind with
| Pgenarray ->
- check_float_array_optimisation_enabled ();
+ check_float_array_optimisation_enabled "Pgenarray";
Float_array_opt_dynamic
| Paddrarray -> Array_kind Values
| Pintarray -> Array_kind Immediates
@@ -257,7 +257,7 @@ let convert_array_kind_to_duplicate_array_kind (kind : L.array_kind) :
converted_duplicate_array_kind =
match kind with
| Pgenarray ->
- check_float_array_optimisation_enabled ();
+ check_float_array_optimisation_enabled "Pgenarray";
Float_array_opt_dynamic
| Paddrarray -> Duplicate_array_kind Values
| Pintarray -> Duplicate_array_kind Immediates
@@ -565,6 +565,78 @@ let bytes_like_set_safe ~dbg ~size_int ~access_size kind ~boxed bytes index
new_value)
bytes index

+ (* Array vector load/store *)
+
+ let array_vector_access_validity_condition array ~size_int
+ (array_kind : P.Array_kind.t) index =
+ let width_in_scalars =
+ match array_kind with
+ | Naked_floats | Immediates | Naked_int64s | Naked_nativeints -> 2
+ | Naked_int32s -> 4
+ | Values ->
+ Misc.fatal_error
+ "Attempted to load/store a SIMD vector from/to a value array."
+ in
+ let length_untagged =
+ untag_int (H.Prim (Unary (Array_length (Array_kind array_kind), array)))
+ in
+ let reduced_length_untagged =
+ H.Prim
+ (Binary
+ ( Int_arith (Naked_immediate, Sub),
+ length_untagged,
+ Simple
+ (Simple.untagged_const_int
+ (Targetint_31_63.of_int (width_in_scalars - 1))) ))
+ in
+ (* We need to convert the length into a naked_nativeint because the optimised
+ version of the max_with_zero function needs to be on machine-width integers
+ to work (or at least on an integer number of bytes to work). *)
+ let reduced_length_nativeint =
+ H.Prim
+ (Unary
+ ( Num_conv { src = Naked_immediate; dst = Naked_nativeint },
+ reduced_length_untagged ))
+ in
+ let check_nativeint = max_with_zero ~size_int reduced_length_nativeint in
+ let check_untagged =
+ H.Prim
+ (Unary
+ ( Num_conv { src = Naked_nativeint; dst = Naked_immediate },
+ check_nativeint ))
+ in
+ check_bound_tagged index check_untagged
+
+ let check_array_vector_access ~dbg ~size_int ~array array_kind ~index primitive
+ : H.expr_primitive =
+ checked_access ~primitive
+ ~conditions:
+ [array_vector_access_validity_condition ~size_int array array_kind index]
+ ~dbg
+
+ let array_like_load_128 ~dbg ~size_int ~unsafe ~mode ~current_region array_kind
+ array index =
+ let primitive =
+ box_vec128 mode ~current_region
+ (H.Binary (Array_load (array_kind, Vec128, Mutable), array, index))
+ in
+ if unsafe
+ then primitive
+ else
+ check_array_vector_access ~dbg ~size_int ~array array_kind ~index primitive
+
+ let array_like_set_128 ~dbg ~size_int ~unsafe array_kind array index new_value =
+ let primitive =
+ H.Ternary
+ (Array_set (array_kind, Vec128), array, index, unbox_vec128 new_value)
+ in
+ if unsafe
+ then primitive
+ else
+ check_array_vector_access ~dbg ~size_int ~array
+ (P.Array_set_kind.array_kind array_kind)
+ ~index primitive
+
(* Bigarray accesses *)
let bigarray_box_or_tag_raw_value_to_read kind alloc_mode =
let error what =
@@ -688,17 +760,20 @@ let check_array_access ~dbg ~array array_kind ~index primitive :
let array_load_unsafe ~array ~index (array_ref_kind : Array_ref_kind.t)
~current_region : H.expr_primitive =
match array_ref_kind with
- | Immediates -> Binary (Array_load (Immediates, Mutable), array, index)
- | Values -> Binary (Array_load (Values, Mutable), array, index)
+ | Immediates -> Binary (Array_load (Immediates, Scalar, Mutable), array, index)
+ | Values -> Binary (Array_load (Values, Scalar, Mutable), array, index)
| Naked_floats_to_be_boxed mode ->
box_float mode
- (Binary (Array_load (Naked_floats, Mutable), array, index))
+ (Binary (Array_load (Naked_floats, Scalar, Mutable), array, index))
~current_region
- | Naked_floats -> Binary (Array_load (Naked_floats, Mutable), array, index)
- | Naked_int32s -> Binary (Array_load (Naked_int32s, Mutable), array, index)
- | Naked_int64s -> Binary (Array_load (Naked_int64s, Mutable), array, index)
+ | Naked_floats ->
+ Binary (Array_load (Naked_floats, Scalar, Mutable), array, index)
+ | Naked_int32s ->
+ Binary (Array_load (Naked_int32s, Scalar, Mutable), array, index)
+ | Naked_int64s ->
+ Binary (Array_load (Naked_int64s, Scalar, Mutable), array, index)
| Naked_nativeints ->
- Binary (Array_load (Naked_nativeints, Mutable), array, index)
+ Binary (Array_load (Naked_nativeints, Scalar, Mutable), array, index)

let array_set_unsafe ~array ~index ~new_value
(array_set_kind : Array_set_kind.t) : H.expr_primitive =
@@ -710,7 +785,7 @@ let array_set_unsafe ~array ~index ~new_value
| Naked_floats_to_be_unboxed -> unbox_float new_value
in
let array_set_kind = convert_intermediate_array_set_kind array_set_kind in
- Ternary (Array_set array_set_kind, array, index, new_value)
+ Ternary (Array_set (array_set_kind, Scalar), array, index, new_value)

let[@inline always] match_on_array_ref_kind ~array array_ref_kind f :
H.expr_primitive =
@@ -1523,6 +1598,58 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
[ bytes_like_set_safe ~dbg ~size_int
~access_size:(One_twenty_eight { aligned })
Bigstring ~boxed bigstring index new_value ]
+ | Pfloat_array_load_128 { unsafe; mode }, [[array]; [index]] ->
+ check_float_array_optimisation_enabled "Pfloat_array_load_128";
+ [ array_like_load_128 ~dbg ~size_int ~current_region ~unsafe ~mode
+ Naked_floats array index ]
+ | Pfloatarray_load_128 { unsafe; mode }, [[array]; [index]]
+ | Punboxed_float_array_load_128 { unsafe; mode }, [[array]; [index]] ->
+ [ array_like_load_128 ~dbg ~size_int ~current_region ~unsafe ~mode
+ Naked_floats array index ]
+ | Pint_array_load_128 { unsafe; mode }, [[array]; [index]] ->
+ if Targetint.size <> 64
+ then Misc.fatal_error "[Pint_array_load_128]: immediates must be 64 bits.";
+ [ array_like_load_128 ~dbg ~size_int ~current_region ~unsafe ~mode
+ Immediates array index ]
+ | Punboxed_int64_array_load_128 { unsafe; mode }, [[array]; [index]] ->
+ [ array_like_load_128 ~dbg ~size_int ~current_region ~unsafe ~mode
+ Naked_int64s array index ]
+ | Punboxed_nativeint_array_load_128 { unsafe; mode }, [[array]; [index]] ->
+ if Targetint.size <> 64
+ then
+ Misc.fatal_error
+ "[Punboxed_nativeint_array_load_128]: nativeint must be 64 bits.";
+ [ array_like_load_128 ~dbg ~size_int ~current_region ~unsafe ~mode
+ Naked_nativeints array index ]
+ | Punboxed_int32_array_load_128 { unsafe; mode }, [[array]; [index]] ->
+ [ array_like_load_128 ~dbg ~size_int ~current_region ~unsafe ~mode
+ Naked_int32s array index ]
+ | Pfloat_array_set_128 { unsafe }, [[array]; [index]; [new_value]] ->
+ check_float_array_optimisation_enabled "Pfloat_array_set_128";
+ [ array_like_set_128 ~dbg ~size_int ~unsafe Naked_floats array index
+ new_value ]
+ | Pfloatarray_set_128 { unsafe }, [[array]; [index]; [new_value]]
+ | Punboxed_float_array_set_128 { unsafe }, [[array]; [index]; [new_value]] ->
+ [ array_like_set_128 ~dbg ~size_int ~unsafe Naked_floats array index
+ new_value ]
+ | Pint_array_set_128 { unsafe }, [[array]; [index]; [new_value]] ->
+ if Targetint.size <> 64
+ then Misc.fatal_error "[Pint_array_set_128]: immediates must be 64 bits.";
+ [array_like_set_128 ~dbg ~size_int ~unsafe Immediates array index new_value]
+ | Punboxed_int64_array_set_128 { unsafe }, [[array]; [index]; [new_value]] ->
+ [ array_like_set_128 ~dbg ~size_int ~unsafe Naked_int64s array index
+ new_value ]
+ | Punboxed_nativeint_array_set_128 { unsafe }, [[array]; [index]; [new_value]]
+ ->
+ if Targetint.size <> 64
+ then
+ Misc.fatal_error
+ "[Punboxed_nativeint_array_set_128]: nativeint must be 64 bits.";
+ [ array_like_set_128 ~dbg ~size_int ~unsafe Naked_nativeints array index
+ new_value ]
+ | Punboxed_int32_array_set_128 { unsafe }, [[array]; [index]; [new_value]] ->
+ [ array_like_set_128 ~dbg ~size_int ~unsafe Naked_int32s array index
+ new_value ]
| Pcompare_ints, [[i1]; [i2]] ->
[ tag_int
(Binary
@@ -1609,7 +1736,10 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
| Pasrbint _ | Pfield_computed _ | Pdivbint _ | Pmodbint _
| Psetfloatfield _ | Psetufloatfield _ | Pbintcomp _ | Punboxed_int_comp _
| Pbigstring_load_16 _ | Pbigstring_load_32 _ | Pbigstring_load_64 _
- | Pbigstring_load_128 _
+ | Pbigstring_load_128 _ | Pfloatarray_load_128 _ | Pfloat_array_load_128 _
+ | Pint_array_load_128 _ | Punboxed_float_array_load_128 _
+ | Punboxed_int32_array_load_128 _ | Punboxed_int64_array_load_128 _
+ | Punboxed_nativeint_array_load_128 _
| Parrayrefu
( Pgenarray_ref _ | Paddrarray_ref | Pintarray_ref | Pfloatarray_ref _
| Punboxedfloatarray_ref _ | Punboxedintarray_ref _ )
@@ -1636,7 +1766,10 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
| Punboxedfloatarray_set _ | Punboxedintarray_set _ )
| Pbytes_set_16 _ | Pbytes_set_32 _ | Pbytes_set_64 _ | Pbytes_set_128 _
| Pbigstring_set_16 _ | Pbigstring_set_32 _ | Pbigstring_set_64 _
- | Pbigstring_set_128 _ | Patomic_cas ),
+ | Pbigstring_set_128 _ | Pfloatarray_set_128 _ | Pfloat_array_set_128 _
+ | Pint_array_set_128 _ | Punboxed_float_array_set_128 _
+ | Punboxed_int32_array_set_128 _ | Punboxed_int64_array_set_128 _
+ | Punboxed_nativeint_array_set_128 _ | Patomic_cas ),
( []
| [_]
| [_; _]
8 changes: 6 additions & 2 deletions middle_end/flambda2/parser/fexpr.ml
@@ -327,6 +327,10 @@ type string_accessor_width = Flambda_primitive.string_accessor_width =
| Sixty_four
| One_twenty_eight of { aligned : bool }

+ type array_accessor_width = Flambda_primitive.array_accessor_width =
+ | Scalar
+ | Vec128
+
type string_like_value = Flambda_primitive.string_like_value =
| String
| Bytes
@@ -344,7 +348,7 @@ type infix_binop =
| Float_comp of unit comparison_behaviour

type binop =
- | Array_load of array_kind * mutability
+ | Array_load of array_kind * array_accessor_width * mutability
| Block_load of block_access_kind * mutability
| Phys_equal of equality_comparison
| Int_arith of standard_int * binary_int_arith_op
@@ -356,7 +360,7 @@ type binop =

type ternop =
(* CR mshinwell: Array_set should use "array_set_kind" *)
- | Array_set of array_kind * init_or_assign
+ | Array_set of array_kind * array_accessor_width * init_or_assign
| Block_set of block_access_kind * init_or_assign
| Bytes_or_bigstring_set of bytes_like_value * string_accessor_width
