From e73bed11f0f6ab7fe28ed9f43cda86e344d6dc62 Mon Sep 17 00:00:00 2001 From: Mark Shinwell Date: Mon, 23 Sep 2024 12:11:06 +0100 Subject: [PATCH] Refactor multi-word array bounds check code --- .../lambda_to_flambda_primitives.ml | 72 ++++++++++--------- 1 file changed, 40 insertions(+), 32 deletions(-) diff --git a/middle_end/flambda2/from_lambda/lambda_to_flambda_primitives.ml b/middle_end/flambda2/from_lambda/lambda_to_flambda_primitives.ml index 97cd1ace190..2959699848e 100644 --- a/middle_end/flambda2/from_lambda/lambda_to_flambda_primitives.ml +++ b/middle_end/flambda2/from_lambda/lambda_to_flambda_primitives.ml @@ -281,8 +281,8 @@ let convert_array_kind_to_duplicate_array_kind (kind : L.array_kind) : Duplicate_array_kind (Naked_nativeints { length = None }) | Pgcscannableproductarray _ | Pgcignorableproductarray _ -> Misc.fatal_error - "Lambda_to_flambda_primitives.convert_array_kind_to_duplicate_array_kind\ - : unimplemented" + "Lambda_to_flambda_primitives.convert_array_kind_to_duplicate_array_kind: \ + unimplemented" let convert_field_read_semantics (sem : L.field_read_semantics) : Mutability.t = match sem with Reads_agree -> Immutable | Reads_vary -> Mutable @@ -633,20 +633,12 @@ let bytes_like_set ~dbg ~unsafe check_access ~dbg ~size_int ~access_size ~primitive:unsafe_set bytes ~index_kind index -(* Array vector load/store *) +(* Array bounds checks *) -let array_vector_access_validity_condition array ~size_int - (array_kind : P.Array_kind.t) index = - let width_in_scalars = - match array_kind with - | Naked_floats | Immediates | Naked_int64s | Naked_nativeints -> 2 - | Naked_int32s | Naked_float32s -> 4 - | Values -> - Misc.fatal_error - "Attempted to load/store a SIMD vector from/to a value array." - in +let multiple_word_array_access_validity_condition array ~size_int + array_length_kind index_kind ~width_in_scalars ~index = let length_untagged = - untag_int (H.Prim (Unary (Array_length (Array_kind array_kind), array))) + untag_int (H.Prim (Unary (Array_length array_length_kind, array))) in let reduced_length_untagged = H.Prim @@ -667,9 +659,25 @@ let array_vector_access_validity_condition array ~size_int reduced_length_untagged )) in let nativeint_bound = max_with_zero ~size_int reduced_length_nativeint in - check_bound ~index_kind:Ptagged_int_index ~bound_kind:Naked_nativeint ~index + check_bound ~index_kind ~bound_kind:Naked_nativeint ~index ~bound:nativeint_bound +(* Array vector load/store *) + +let array_vector_access_width_in_scalars (array_kind : P.Array_kind.t) = + match array_kind with + | Naked_floats | Immediates | Naked_int64s | Naked_nativeints -> 2 + | Naked_int32s | Naked_float32s -> 4 + | Values -> + Misc.fatal_error + "Attempted to load/store a SIMD vector from/to a value array." + +let array_vector_access_validity_condition array ~size_int + (array_kind : P.Array_kind.t) index = + let width_in_scalars = array_vector_access_width_in_scalars array_kind in + multiple_word_array_access_validity_condition array ~size_int + (Array_kind array_kind) Ptagged_int_index ~width_in_scalars ~index + let check_array_vector_access ~dbg ~size_int ~array array_kind ~index primitive : H.expr_primitive = checked_access ~primitive @@ -819,16 +827,16 @@ let bigarray_set ~dbg ~unsafe kind layout b indexes value = (* Array accesses *) let array_access_validity_condition array array_kind index - ~(index_kind : L.array_index_kind) = - let arr_len_as_tagged_imm = H.Prim (Unary (Array_length array_kind, array)) in - [ check_bound ~index_kind ~bound_kind:Tagged_immediate ~index - ~bound:arr_len_as_tagged_imm ] + ~(index_kind : L.array_index_kind) ~size_int = + [ multiple_word_array_access_validity_condition array ~size_int array_kind + index_kind ~width_in_scalars:1 ~index ] -let check_array_access ~dbg ~array array_kind ~index ~index_kind primitive : - H.expr_primitive = +let check_array_access ~dbg ~array array_kind ~index ~index_kind ~size_int + primitive : H.expr_primitive = checked_access ~primitive ~conditions: - (array_access_validity_condition array array_kind index ~index_kind) + (array_access_validity_condition array array_kind index ~index_kind + ~size_int) ~dbg let array_load_unsafe ~array ~index (array_ref_kind : Array_ref_kind.t) @@ -1590,7 +1598,7 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list) ~current_region) ] | Parrayrefs (array_ref_kind, index_kind), [[array]; [index]] -> let array_kind = convert_array_ref_kind_for_length array_ref_kind in - [ check_array_access ~dbg ~array array_kind ~index ~index_kind + [ check_array_access ~dbg ~array array_kind ~index ~index_kind ~size_int (match_on_array_ref_kind ~array array_ref_kind (array_load_unsafe ~array ~index:(convert_index_to_tagged_int ~index ~index_kind) @@ -1602,7 +1610,7 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list) ~new_value) ] | Parraysets (array_set_kind, index_kind), [[array]; [index]; [new_value]] -> let array_kind = convert_array_set_kind_for_length array_set_kind in - [ check_array_access ~dbg ~array array_kind ~index ~index_kind + [ check_array_access ~dbg ~array array_kind ~index ~index_kind ~size_int (match_on_array_set_kind ~array array_set_kind (array_set_unsafe ~array ~index:(convert_index_to_tagged_int ~index ~index_kind) @@ -1971,14 +1979,14 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list) | Parrayrefu ( ( Pgenarray_ref _ | Paddrarray_ref | Pintarray_ref | Pfloatarray_ref _ | Punboxedfloatarray_ref _ - | Punboxedintarray_ref _ - | Pgcscannableproductarray_ref _ | Pgcignorableproductarray_ref _ ), + | Punboxedintarray_ref _ | Pgcscannableproductarray_ref _ + | Pgcignorableproductarray_ref _ ), _ ) | Parrayrefs ( ( Pgenarray_ref _ | Paddrarray_ref | Pintarray_ref | Pfloatarray_ref _ | Punboxedfloatarray_ref _ - | Punboxedintarray_ref _ - | Pgcscannableproductarray_ref _ | Pgcignorableproductarray_ref _ ), + | Punboxedintarray_ref _ | Pgcscannableproductarray_ref _ + | Pgcignorableproductarray_ref _ ), _ ) | Pcompare_ints | Pcompare_floats _ | Pcompare_bints _ | Patomic_exchange | Patomic_fetch_add ), @@ -1995,14 +2003,14 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list) | Parraysetu ( ( Pgenarray_set _ | Paddrarray_set _ | Pintarray_set | Pfloatarray_set | Punboxedfloatarray_set _ - | Punboxedintarray_set _ - | Pgcscannableproductarray_set _ | Pgcignorableproductarray_set _ ), + | Punboxedintarray_set _ | Pgcscannableproductarray_set _ + | Pgcignorableproductarray_set _ ), _ ) | Parraysets ( ( Pgenarray_set _ | Paddrarray_set _ | Pintarray_set | Pfloatarray_set | Punboxedfloatarray_set _ - | Punboxedintarray_set _ - | Pgcscannableproductarray_set _ | Pgcignorableproductarray_set _ ), + | Punboxedintarray_set _ | Pgcscannableproductarray_set _ + | Pgcignorableproductarray_set _ ), _ ) | Pbytes_set_16 _ | Pbytes_set_32 _ | Pbytes_set_f32 _ | Pbytes_set_64 _ | Pbytes_set_128 _ | Pbigstring_set_16 _ | Pbigstring_set_32 _