Skip to content

Commit

Permalink
Refactor multi-word array bounds check code (#3069)
Browse files Browse the repository at this point in the history
* Refactor multi-word array bounds check code

(cherry picked from commit e73bed1)

* Special case for width_in_scalars=1
  • Loading branch information
mshinwell authored and TheNumbat committed Oct 4, 2024
1 parent d329146 commit e356d9b
Showing 1 changed file with 57 additions and 41 deletions.
98 changes: 57 additions & 41 deletions middle_end/flambda2/from_lambda/lambda_to_flambda_primitives.ml
Original file line number Diff line number Diff line change
Expand Up @@ -620,42 +620,58 @@ let bytes_like_set ~dbg ~unsafe
check_access ~dbg ~size_int ~access_size ~primitive:unsafe_set bytes
~index_kind index

(* Array bounds checks *)

let multiple_word_array_access_validity_condition array ~size_int
array_length_kind index_kind ~width_in_scalars ~index =
let length_tagged = H.Prim (Unary (Array_length array_length_kind, array)) in
if width_in_scalars < 1
then Misc.fatal_errorf "Invalid width_in_scalars value: %d" width_in_scalars
else if width_in_scalars = 1
then
(* Ensure good code generation in the common case. *)
check_bound ~index_kind ~bound_kind:Tagged_immediate ~index
~bound:length_tagged
else
let length_untagged = untag_int length_tagged in
let reduced_length_untagged =
H.Prim
(Binary
( Int_arith (Naked_immediate, Sub),
length_untagged,
Simple
(Simple.untagged_const_int
(Targetint_31_63.of_int (width_in_scalars - 1))) ))
in
(* We need to convert the length into a naked_nativeint because the
optimised version of the max_with_zero function needs to be on
machine-width integers to work (or at least on an integer number of bytes
to work). *)
let reduced_length_nativeint =
H.Prim
(Unary
( Num_conv { src = Naked_immediate; dst = Naked_nativeint },
reduced_length_untagged ))
in
let nativeint_bound = max_with_zero ~size_int reduced_length_nativeint in
check_bound ~index_kind ~bound_kind:Naked_nativeint ~index
~bound:nativeint_bound

(* Array vector load/store *)

let array_vector_access_width_in_scalars (array_kind : P.Array_kind.t) =
match array_kind with
| Naked_floats | Immediates | Naked_int64s | Naked_nativeints -> 2
| Naked_int32s | Naked_float32s -> 4
| Values ->
Misc.fatal_error
"Attempted to load/store a SIMD vector from/to a value array."

let array_vector_access_validity_condition array ~size_int
(array_kind : P.Array_kind.t) index =
let width_in_scalars =
match array_kind with
| Naked_floats | Immediates | Naked_int64s | Naked_nativeints -> 2
| Naked_int32s | Naked_float32s -> 4
| Values ->
Misc.fatal_error
"Attempted to load/store a SIMD vector from/to a value array."
in
let length_untagged =
untag_int (H.Prim (Unary (Array_length (Array_kind array_kind), array)))
in
let reduced_length_untagged =
H.Prim
(Binary
( Int_arith (Naked_immediate, Sub),
length_untagged,
Simple
(Simple.untagged_const_int
(Targetint_31_63.of_int (width_in_scalars - 1))) ))
in
(* We need to convert the length into a naked_nativeint because the optimised
version of the max_with_zero function needs to be on machine-width integers
to work (or at least on an integer number of bytes to work). *)
let reduced_length_nativeint =
H.Prim
(Unary
( Num_conv { src = Naked_immediate; dst = Naked_nativeint },
reduced_length_untagged ))
in
let nativeint_bound = max_with_zero ~size_int reduced_length_nativeint in
check_bound ~index_kind:Ptagged_int_index ~bound_kind:Naked_nativeint ~index
~bound:nativeint_bound
let width_in_scalars = array_vector_access_width_in_scalars array_kind in
multiple_word_array_access_validity_condition array ~size_int
(Array_kind array_kind) Ptagged_int_index ~width_in_scalars ~index

let check_array_vector_access ~dbg ~size_int ~array array_kind ~index primitive
: H.expr_primitive =
Expand Down Expand Up @@ -806,16 +822,16 @@ let bigarray_set ~dbg ~unsafe kind layout b indexes value =

(* Array accesses *)
let array_access_validity_condition array array_kind index
~(index_kind : L.array_index_kind) =
let arr_len_as_tagged_imm = H.Prim (Unary (Array_length array_kind, array)) in
[ check_bound ~index_kind ~bound_kind:Tagged_immediate ~index
~bound:arr_len_as_tagged_imm ]
~(index_kind : L.array_index_kind) ~size_int =
[ multiple_word_array_access_validity_condition array ~size_int array_kind
index_kind ~width_in_scalars:1 ~index ]

let check_array_access ~dbg ~array array_kind ~index ~index_kind primitive :
H.expr_primitive =
let check_array_access ~dbg ~array array_kind ~index ~index_kind ~size_int
primitive : H.expr_primitive =
checked_access ~primitive
~conditions:
(array_access_validity_condition array array_kind index ~index_kind)
(array_access_validity_condition array array_kind index ~index_kind
~size_int)
~dbg

let array_load_unsafe ~array ~index (array_ref_kind : Array_ref_kind.t)
Expand Down Expand Up @@ -1574,7 +1590,7 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
~current_region) ]
| Parrayrefs (array_ref_kind, index_kind), [[array]; [index]] ->
let array_kind = convert_array_ref_kind_for_length array_ref_kind in
[ check_array_access ~dbg ~array array_kind ~index ~index_kind
[ check_array_access ~dbg ~array array_kind ~index ~index_kind ~size_int
(match_on_array_ref_kind ~array array_ref_kind
(array_load_unsafe ~array
~index:(convert_index_to_tagged_int ~index ~index_kind)
Expand All @@ -1586,7 +1602,7 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
~new_value) ]
| Parraysets (array_set_kind, index_kind), [[array]; [index]; [new_value]] ->
let array_kind = convert_array_set_kind_for_length array_set_kind in
[ check_array_access ~dbg ~array array_kind ~index ~index_kind
[ check_array_access ~dbg ~array array_kind ~index ~index_kind ~size_int
(match_on_array_set_kind ~array array_set_kind
(array_set_unsafe ~array
~index:(convert_index_to_tagged_int ~index ~index_kind)
Expand Down

0 comments on commit e356d9b

Please sign in to comment.