Skip to content

Commit

Permalink
Refactor multi-word array bounds check code
Browse files Browse the repository at this point in the history
  • Loading branch information
mshinwell committed Sep 23, 2024
1 parent 19a85ed commit e73bed1
Showing 1 changed file with 40 additions and 32 deletions.
72 changes: 40 additions & 32 deletions middle_end/flambda2/from_lambda/lambda_to_flambda_primitives.ml
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,8 @@ let convert_array_kind_to_duplicate_array_kind (kind : L.array_kind) :
Duplicate_array_kind (Naked_nativeints { length = None })
| Pgcscannableproductarray _ | Pgcignorableproductarray _ ->
Misc.fatal_error
"Lambda_to_flambda_primitives.convert_array_kind_to_duplicate_array_kind\
: unimplemented"
"Lambda_to_flambda_primitives.convert_array_kind_to_duplicate_array_kind: \
unimplemented"

let convert_field_read_semantics (sem : L.field_read_semantics) : Mutability.t =
match sem with Reads_agree -> Immutable | Reads_vary -> Mutable
Expand Down Expand Up @@ -633,20 +633,12 @@ let bytes_like_set ~dbg ~unsafe
check_access ~dbg ~size_int ~access_size ~primitive:unsafe_set bytes
~index_kind index

(* Array vector load/store *)
(* Array bounds checks *)

let array_vector_access_validity_condition array ~size_int
(array_kind : P.Array_kind.t) index =
let width_in_scalars =
match array_kind with
| Naked_floats | Immediates | Naked_int64s | Naked_nativeints -> 2
| Naked_int32s | Naked_float32s -> 4
| Values ->
Misc.fatal_error
"Attempted to load/store a SIMD vector from/to a value array."
in
let multiple_word_array_access_validity_condition array ~size_int
array_length_kind index_kind ~width_in_scalars ~index =
let length_untagged =
untag_int (H.Prim (Unary (Array_length (Array_kind array_kind), array)))
untag_int (H.Prim (Unary (Array_length array_length_kind, array)))
in
let reduced_length_untagged =
H.Prim
Expand All @@ -667,9 +659,25 @@ let array_vector_access_validity_condition array ~size_int
reduced_length_untagged ))
in
let nativeint_bound = max_with_zero ~size_int reduced_length_nativeint in
check_bound ~index_kind:Ptagged_int_index ~bound_kind:Naked_nativeint ~index
check_bound ~index_kind ~bound_kind:Naked_nativeint ~index
~bound:nativeint_bound

(* Array vector load/store *)

let array_vector_access_width_in_scalars (array_kind : P.Array_kind.t) =
match array_kind with
| Naked_floats | Immediates | Naked_int64s | Naked_nativeints -> 2
| Naked_int32s | Naked_float32s -> 4
| Values ->
Misc.fatal_error
"Attempted to load/store a SIMD vector from/to a value array."

let array_vector_access_validity_condition array ~size_int
(array_kind : P.Array_kind.t) index =
let width_in_scalars = array_vector_access_width_in_scalars array_kind in
multiple_word_array_access_validity_condition array ~size_int
(Array_kind array_kind) Ptagged_int_index ~width_in_scalars ~index

let check_array_vector_access ~dbg ~size_int ~array array_kind ~index primitive
: H.expr_primitive =
checked_access ~primitive
Expand Down Expand Up @@ -819,16 +827,16 @@ let bigarray_set ~dbg ~unsafe kind layout b indexes value =

(* Array accesses *)
let array_access_validity_condition array array_kind index
~(index_kind : L.array_index_kind) =
let arr_len_as_tagged_imm = H.Prim (Unary (Array_length array_kind, array)) in
[ check_bound ~index_kind ~bound_kind:Tagged_immediate ~index
~bound:arr_len_as_tagged_imm ]
~(index_kind : L.array_index_kind) ~size_int =
[ multiple_word_array_access_validity_condition array ~size_int array_kind
index_kind ~width_in_scalars:1 ~index ]

let check_array_access ~dbg ~array array_kind ~index ~index_kind primitive :
H.expr_primitive =
let check_array_access ~dbg ~array array_kind ~index ~index_kind ~size_int
primitive : H.expr_primitive =
checked_access ~primitive
~conditions:
(array_access_validity_condition array array_kind index ~index_kind)
(array_access_validity_condition array array_kind index ~index_kind
~size_int)
~dbg

let array_load_unsafe ~array ~index (array_ref_kind : Array_ref_kind.t)
Expand Down Expand Up @@ -1590,7 +1598,7 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
~current_region) ]
| Parrayrefs (array_ref_kind, index_kind), [[array]; [index]] ->
let array_kind = convert_array_ref_kind_for_length array_ref_kind in
[ check_array_access ~dbg ~array array_kind ~index ~index_kind
[ check_array_access ~dbg ~array array_kind ~index ~index_kind ~size_int
(match_on_array_ref_kind ~array array_ref_kind
(array_load_unsafe ~array
~index:(convert_index_to_tagged_int ~index ~index_kind)
Expand All @@ -1602,7 +1610,7 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
~new_value) ]
| Parraysets (array_set_kind, index_kind), [[array]; [index]; [new_value]] ->
let array_kind = convert_array_set_kind_for_length array_set_kind in
[ check_array_access ~dbg ~array array_kind ~index ~index_kind
[ check_array_access ~dbg ~array array_kind ~index ~index_kind ~size_int
(match_on_array_set_kind ~array array_set_kind
(array_set_unsafe ~array
~index:(convert_index_to_tagged_int ~index ~index_kind)
Expand Down Expand Up @@ -1971,14 +1979,14 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
| Parrayrefu
( ( Pgenarray_ref _ | Paddrarray_ref | Pintarray_ref
| Pfloatarray_ref _ | Punboxedfloatarray_ref _
| Punboxedintarray_ref _
| Pgcscannableproductarray_ref _ | Pgcignorableproductarray_ref _ ),
| Punboxedintarray_ref _ | Pgcscannableproductarray_ref _
| Pgcignorableproductarray_ref _ ),
_ )
| Parrayrefs
( ( Pgenarray_ref _ | Paddrarray_ref | Pintarray_ref
| Pfloatarray_ref _ | Punboxedfloatarray_ref _
| Punboxedintarray_ref _
| Pgcscannableproductarray_ref _ | Pgcignorableproductarray_ref _ ),
| Punboxedintarray_ref _ | Pgcscannableproductarray_ref _
| Pgcignorableproductarray_ref _ ),
_ )
| Pcompare_ints | Pcompare_floats _ | Pcompare_bints _ | Patomic_exchange
| Patomic_fetch_add ),
Expand All @@ -1995,14 +2003,14 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
| Parraysetu
( ( Pgenarray_set _ | Paddrarray_set _ | Pintarray_set
| Pfloatarray_set | Punboxedfloatarray_set _
| Punboxedintarray_set _
| Pgcscannableproductarray_set _ | Pgcignorableproductarray_set _ ),
| Punboxedintarray_set _ | Pgcscannableproductarray_set _
| Pgcignorableproductarray_set _ ),
_ )
| Parraysets
( ( Pgenarray_set _ | Paddrarray_set _ | Pintarray_set
| Pfloatarray_set | Punboxedfloatarray_set _
| Punboxedintarray_set _
| Pgcscannableproductarray_set _ | Pgcignorableproductarray_set _ ),
| Punboxedintarray_set _ | Pgcscannableproductarray_set _
| Pgcignorableproductarray_set _ ),
_ )
| Pbytes_set_16 _ | Pbytes_set_32 _ | Pbytes_set_f32 _ | Pbytes_set_64 _
| Pbytes_set_128 _ | Pbigstring_set_16 _ | Pbigstring_set_32 _
Expand Down

0 comments on commit e73bed1

Please sign in to comment.