Skip to content

Commit

Permalink
Float32 array load/store primitives (ocaml-flambda#2697)
Browse files Browse the repository at this point in the history
* start

* progress

* impl

* missing fns

* bigarrays

* tests

* fix r5

* fix cvtsi2ss instruction

* fix bounds check tests

* address comments

* Code review

* format

---------

Co-authored-by: Mark Shinwell <mshinwell@pm.me>
  • Loading branch information
TheNumbat and mshinwell authored Jul 9, 2024
1 parent 8f14849 commit ddc1ec9
Show file tree
Hide file tree
Showing 31 changed files with 2,106 additions and 68 deletions.
11 changes: 11 additions & 0 deletions backend/cmm_helpers.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1704,6 +1704,7 @@ let curry_function_sym function_kind arity result =
let bigarray_elt_size_in_bytes : Lambda.bigarray_kind -> int = function
| Pbigarray_unknown -> assert false
| Pbigarray_float32 -> 4
| Pbigarray_float32_t -> 4
| Pbigarray_float64 -> 8
| Pbigarray_sint8 -> 1
| Pbigarray_uint8 -> 1
Expand All @@ -1719,6 +1720,7 @@ let bigarray_elt_size_in_bytes : Lambda.bigarray_kind -> int = function
let bigarray_word_kind : Lambda.bigarray_kind -> memory_chunk = function
| Pbigarray_unknown -> assert false
| Pbigarray_float32 -> Single { reg = Float64 }
| Pbigarray_float32_t -> Single { reg = Float32 }
| Pbigarray_float64 -> Double
| Pbigarray_sint8 -> Byte_signed
| Pbigarray_uint8 -> Byte_unsigned
Expand Down Expand Up @@ -2173,6 +2175,15 @@ let unaligned_set_64 ptr idx newval dbg =
[add_int (add_int ptr idx dbg) (cconst_int 7) dbg; b8],
dbg ) ) ) )

(* [unaligned_load_f32 ptr idx dbg] builds the Cmm expression that loads a
   [Single { reg = Float32 }] memory chunk from the address [ptr + idx].
   The address is formed with [add_int] and the load operation is the one
   produced by [mk_load_mut]; [dbg] is attached to both. *)
let unaligned_load_f32 ptr idx dbg =
  let addr = add_int ptr idx dbg in
  Cop (mk_load_mut (Single { reg = Float32 }), [addr], dbg)

(* [unaligned_set_f32 ptr idx newval dbg] builds the Cmm expression that
   stores [newval] as a [Single { reg = Float32 }] memory chunk at the
   address [ptr + idx]. The store is a [Cstore] with kind [Assignment];
   [dbg] is attached to the address computation and to the store. *)
let unaligned_set_f32 ptr idx newval dbg =
  let addr = add_int ptr idx dbg in
  let store_op = Cstore (Single { reg = Float32 }, Assignment) in
  Cop (store_op, [addr; newval], dbg)

(* [unaligned_load_128 ptr idx dbg] builds the Cmm expression that loads an
   [Onetwentyeight_unaligned] memory chunk from the address [ptr + idx],
   using the load operation produced by [mk_load_mut]. *)
let unaligned_load_128 ptr idx dbg =
  (* Fail early if [size_vec128] is not the 16 bytes this helper expects. *)
  assert (size_vec128 = 16);
  let addr = add_int ptr idx dbg in
  Cop (mk_load_mut Onetwentyeight_unaligned, [addr], dbg)
Expand Down
5 changes: 5 additions & 0 deletions backend/cmm_helpers.mli
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,11 @@ val unaligned_load_32 : expression -> expression -> Debuginfo.t -> expression
(* NOTE(review): the implementation is not visible in this hunk. By analogy
   with [unaligned_set_f32], the arguments are presumably [ptr], [idx],
   [newval] and [dbg], storing a 32-bit value at [ptr + idx] -- confirm
   against cmm_helpers.ml. *)
val unaligned_set_32 :
expression -> expression -> expression -> Debuginfo.t -> expression

(** [unaligned_load_f32 ptr idx dbg] is the Cmm expression that loads a
    [Single { reg = Float32 }] memory chunk from the address [ptr + idx]. *)
val unaligned_load_f32 : expression -> expression -> Debuginfo.t -> expression

(** [unaligned_set_f32 ptr idx newval dbg] is the Cmm expression that stores
    [newval] as a [Single { reg = Float32 }] memory chunk at the address
    [ptr + idx], using an [Assignment] store. *)
val unaligned_set_f32 :
expression -> expression -> expression -> Debuginfo.t -> expression

(* NOTE(review): the implementation is not visible in this hunk. By analogy
   with [unaligned_load_f32], the arguments are presumably [ptr], [idx] and
   [dbg], loading a 64-bit value from [ptr + idx] -- confirm against
   cmm_helpers.ml. *)
val unaligned_load_64 : expression -> expression -> Debuginfo.t -> expression

val unaligned_set_64 :
Expand Down
22 changes: 12 additions & 10 deletions middle_end/flambda2/from_lambda/closure_conversion.ml
Original file line number Diff line number Diff line change
Expand Up @@ -930,16 +930,18 @@ let close_primitive acc env ~let_bound_ids_with_kinds named
| Pandbint _ | Porbint _ | Pxorbint _ | Plslbint _ | Plsrbint _
| Pasrbint _ | Pbintcomp _ | Punboxed_int_comp _ | Pbigarrayref _
| Pbigarrayset _ | Pbigarraydim _ | Pstring_load_16 _ | Pstring_load_32 _
| Pstring_load_64 _ | Pstring_load_128 _ | Pbytes_load_16 _
| Pbytes_load_32 _ | Pbytes_load_64 _ | Pbytes_load_128 _
| Pbytes_set_16 _ | Pbytes_set_32 _ | Pbytes_set_64 _ | Pbytes_set_128 _
| Pbigstring_load_16 _ | Pbigstring_load_32 _ | Pbigstring_load_64 _
| Pbigstring_load_128 _ | Pbigstring_set_16 _ | Pbigstring_set_32 _
| Pbigstring_set_64 _ | Pbigstring_set_128 _ | Pfloatarray_load_128 _
| Pfloat_array_load_128 _ | Pint_array_load_128 _
| Punboxed_float_array_load_128 _ | Punboxed_int32_array_load_128 _
| Punboxed_int64_array_load_128 _ | Punboxed_nativeint_array_load_128 _
| Pfloatarray_set_128 _ | Pfloat_array_set_128 _ | Pint_array_set_128 _
| Pstring_load_f32 _ | Pstring_load_64 _ | Pstring_load_128 _
| Pbytes_load_16 _ | Pbytes_load_32 _ | Pbytes_load_f32 _
| Pbytes_load_64 _ | Pbytes_load_128 _ | Pbytes_set_16 _ | Pbytes_set_32 _
| Pbytes_set_f32 _ | Pbytes_set_64 _ | Pbytes_set_128 _
| Pbigstring_load_16 _ | Pbigstring_load_32 _ | Pbigstring_load_f32 _
| Pbigstring_load_64 _ | Pbigstring_load_128 _ | Pbigstring_set_16 _
| Pbigstring_set_32 _ | Pbigstring_set_f32 _ | Pbigstring_set_64 _
| Pbigstring_set_128 _ | Pfloatarray_load_128 _ | Pfloat_array_load_128 _
| Pint_array_load_128 _ | Punboxed_float_array_load_128 _
| Punboxed_int32_array_load_128 _ | Punboxed_int64_array_load_128 _
| Punboxed_nativeint_array_load_128 _ | Pfloatarray_set_128 _
| Pfloat_array_set_128 _ | Pint_array_set_128 _
| Punboxed_float_array_set_128 _ | Punboxed_int32_array_set_128 _
| Punboxed_int64_array_set_128 _ | Punboxed_nativeint_array_set_128 _
| Pctconst _ | Pbswap16 | Pbbswap _ | Pint_as_pointer _ | Popaque _
Expand Down
57 changes: 47 additions & 10 deletions middle_end/flambda2/from_lambda/lambda_to_flambda.ml
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ let transform_primitive env (prim : L.primitive) args loc =
Primitive
(L.Pnot, [L.Lprim (Punboxed_float_comp (bf, CFge), args, loc)], loc)
| Pbigarrayref (_unsafe, num_dimensions, kind, layout), args -> (
(* CR mshinwell: factor out with the [Pbigarrayset] case *)
match
P.Bigarray_kind.from_lambda kind, P.Bigarray_layout.from_lambda layout
with
Expand All @@ -299,7 +300,19 @@ let transform_primitive env (prim : L.primitive) args loc =
if 1 <= num_dimensions && num_dimensions <= 3
then
let arity = 1 + num_dimensions in
let name = "caml_ba_get_" ^ string_of_int num_dimensions in
let is_float32_t =
match kind with
| Pbigarray_float32_t -> "float32_"
| Pbigarray_unknown | Pbigarray_float32 | Pbigarray_float64
| Pbigarray_sint8 | Pbigarray_uint8 | Pbigarray_sint16
| Pbigarray_uint16 | Pbigarray_int32 | Pbigarray_int64
| Pbigarray_caml_int | Pbigarray_native_int | Pbigarray_complex32
| Pbigarray_complex64 ->
""
in
let name =
"caml_ba_" ^ is_float32_t ^ "get_" ^ string_of_int num_dimensions
in
let desc = Lambda.simple_prim_on_values ~name ~arity ~alloc:true in
Primitive (L.Pccall desc, args, loc)
else
Expand All @@ -316,7 +329,19 @@ let transform_primitive env (prim : L.primitive) args loc =
if 1 <= num_dimensions && num_dimensions <= 3
then
let arity = 2 + num_dimensions in
let name = "caml_ba_set_" ^ string_of_int num_dimensions in
let is_float32_t =
match kind with
| Pbigarray_float32_t -> "float32_"
| Pbigarray_unknown | Pbigarray_float32 | Pbigarray_float64
| Pbigarray_sint8 | Pbigarray_uint8 | Pbigarray_sint16
| Pbigarray_uint16 | Pbigarray_int32 | Pbigarray_int64
| Pbigarray_caml_int | Pbigarray_native_int | Pbigarray_complex32
| Pbigarray_complex64 ->
""
in
let name =
"caml_ba_" ^ is_float32_t ^ "set_" ^ string_of_int num_dimensions
in
let desc = Lambda.simple_prim_on_values ~name ~arity ~alloc:true in
Primitive (L.Pccall desc, args, loc)
else
Expand Down Expand Up @@ -589,22 +614,27 @@ let primitive_can_raise (prim : Lambda.primitive) =
| Pstringrefs | Pbytesrefs | Pbytessets
| Pstring_load_16 false
| Pstring_load_32 (false, _)
| Pstring_load_f32 (false, _)
| Pstring_load_64 (false, _)
| Pstring_load_128 { unsafe = false; _ }
| Pbytes_load_16 false
| Pbytes_load_32 (false, _)
| Pbytes_load_f32 (false, _)
| Pbytes_load_64 (false, _)
| Pbytes_load_128 { unsafe = false; _ }
| Pbytes_set_16 false
| Pbytes_set_32 false
| Pbytes_set_f32 false
| Pbytes_set_64 false
| Pbytes_set_128 { unsafe = false; _ }
| Pbigstring_load_16 { unsafe = false }
| Pbigstring_load_32 { unsafe = false; mode = _; boxed = _ }
| Pbigstring_load_f32 { unsafe = false; mode = _; boxed = _ }
| Pbigstring_load_64 { unsafe = false; mode = _; boxed = _ }
| Pbigstring_load_128 { unsafe = false; _ }
| Pbigstring_set_16 { unsafe = false }
| Pbigstring_set_32 { unsafe = false; boxed = _ }
| Pbigstring_set_f32 { unsafe = false; boxed = _ }
| Pbigstring_set_64 { unsafe = false; boxed = _ }
| Pbigstring_set_128 { unsafe = false; _ }
| Pfloatarray_load_128 { unsafe = false; _ }
Expand Down Expand Up @@ -662,37 +692,44 @@ let primitive_can_raise (prim : Lambda.primitive) =
| Pbigarrayref
( true,
_,
( Pbigarray_float32 | Pbigarray_float64 | Pbigarray_sint8
| Pbigarray_uint8 | Pbigarray_sint16 | Pbigarray_uint16
| Pbigarray_int32 | Pbigarray_int64 | Pbigarray_caml_int
| Pbigarray_native_int | Pbigarray_complex32 | Pbigarray_complex64 ),
( Pbigarray_float32 | Pbigarray_float32_t | Pbigarray_float64
| Pbigarray_sint8 | Pbigarray_uint8 | Pbigarray_sint16
| Pbigarray_uint16 | Pbigarray_int32 | Pbigarray_int64
| Pbigarray_caml_int | Pbigarray_native_int | Pbigarray_complex32
| Pbigarray_complex64 ),
_ )
| Pbigarrayset
( true,
_,
( Pbigarray_float32 | Pbigarray_float64 | Pbigarray_sint8
| Pbigarray_uint8 | Pbigarray_sint16 | Pbigarray_uint16
| Pbigarray_int32 | Pbigarray_int64 | Pbigarray_caml_int
| Pbigarray_native_int | Pbigarray_complex32 | Pbigarray_complex64 ),
( Pbigarray_float32 | Pbigarray_float32_t | Pbigarray_float64
| Pbigarray_sint8 | Pbigarray_uint8 | Pbigarray_sint16
| Pbigarray_uint16 | Pbigarray_int32 | Pbigarray_int64
| Pbigarray_caml_int | Pbigarray_native_int | Pbigarray_complex32
| Pbigarray_complex64 ),
(Pbigarray_c_layout | Pbigarray_fortran_layout) )
| Pstring_load_16 true
| Pstring_load_32 (true, _)
| Pstring_load_f32 (true, _)
| Pstring_load_64 (true, _)
| Pstring_load_128 { unsafe = true; _ }
| Pbytes_load_16 true
| Pbytes_load_32 (true, _)
| Pbytes_load_f32 (true, _)
| Pbytes_load_64 (true, _)
| Pbytes_load_128 { unsafe = true; _ }
| Pbytes_set_16 true
| Pbytes_set_32 true
| Pbytes_set_f32 true
| Pbytes_set_64 true
| Pbytes_set_128 { unsafe = true; _ }
| Pbigstring_load_16 { unsafe = true }
| Pbigstring_load_32 { unsafe = true; mode = _; boxed = _ }
| Pbigstring_load_f32 { unsafe = true; mode = _; boxed = _ }
| Pbigstring_load_64 { unsafe = true; mode = _; boxed = _ }
| Pbigstring_load_128 { unsafe = true; _ }
| Pbigstring_set_16 { unsafe = true }
| Pbigstring_set_32 { unsafe = true; boxed = _ }
| Pbigstring_set_f32 { unsafe = true; boxed = _ }
| Pbigstring_set_64 { unsafe = true; boxed = _ }
| Pbigstring_set_128 { unsafe = true; _ }
| Pfloatarray_load_128 { unsafe = true; _ }
Expand Down
Loading

0 comments on commit ddc1ec9

Please sign in to comment.