Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor multi-word array bounds check code #3069

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 57 additions & 41 deletions middle_end/flambda2/from_lambda/lambda_to_flambda_primitives.ml
Original file line number Diff line number Diff line change
Expand Up @@ -620,42 +620,58 @@ let bytes_like_set ~dbg ~unsafe
check_access ~dbg ~size_int ~access_size ~primitive:unsafe_set bytes
~index_kind index

(* Array bounds checks *)

let multiple_word_array_access_validity_condition array ~size_int
array_length_kind index_kind ~width_in_scalars ~index =
let length_tagged = H.Prim (Unary (Array_length array_length_kind, array)) in
if width_in_scalars < 1
then Misc.fatal_errorf "Invalid width_in_scalars value: %d" width_in_scalars
else if width_in_scalars = 1
then
(* Ensure good code generation in the common case. *)
check_bound ~index_kind ~bound_kind:Tagged_immediate ~index
~bound:length_tagged
else
let length_untagged = untag_int length_tagged in
let reduced_length_untagged =
H.Prim
(Binary
( Int_arith (Naked_immediate, Sub),
length_untagged,
Simple
(Simple.untagged_const_int
(Targetint_31_63.of_int (width_in_scalars - 1))) ))
in
(* We need to convert the length into a naked_nativeint because the
optimised version of the max_with_zero function needs to be on
machine-width integers to work (or at least on an integer number of bytes
to work). *)
let reduced_length_nativeint =
H.Prim
(Unary
( Num_conv { src = Naked_immediate; dst = Naked_nativeint },
reduced_length_untagged ))
in
let nativeint_bound = max_with_zero ~size_int reduced_length_nativeint in
check_bound ~index_kind ~bound_kind:Naked_nativeint ~index
~bound:nativeint_bound

(* Array vector load/store *)

let array_vector_access_width_in_scalars (array_kind : P.Array_kind.t) =
match array_kind with
| Naked_floats | Immediates | Naked_int64s | Naked_nativeints -> 2
| Naked_int32s | Naked_float32s -> 4
| Values ->
Misc.fatal_error
"Attempted to load/store a SIMD vector from/to a value array."

let array_vector_access_validity_condition array ~size_int
(array_kind : P.Array_kind.t) index =
let width_in_scalars =
match array_kind with
| Naked_floats | Immediates | Naked_int64s | Naked_nativeints -> 2
| Naked_int32s | Naked_float32s -> 4
| Values ->
Misc.fatal_error
"Attempted to load/store a SIMD vector from/to a value array."
in
let length_untagged =
untag_int (H.Prim (Unary (Array_length (Array_kind array_kind), array)))
in
let reduced_length_untagged =
H.Prim
(Binary
( Int_arith (Naked_immediate, Sub),
length_untagged,
Simple
(Simple.untagged_const_int
(Targetint_31_63.of_int (width_in_scalars - 1))) ))
in
(* We need to convert the length into a naked_nativeint because the optimised
version of the max_with_zero function needs to be on machine-width integers
to work (or at least on an integer number of bytes to work). *)
let reduced_length_nativeint =
H.Prim
(Unary
( Num_conv { src = Naked_immediate; dst = Naked_nativeint },
reduced_length_untagged ))
in
let nativeint_bound = max_with_zero ~size_int reduced_length_nativeint in
check_bound ~index_kind:Ptagged_int_index ~bound_kind:Naked_nativeint ~index
~bound:nativeint_bound
let width_in_scalars = array_vector_access_width_in_scalars array_kind in
multiple_word_array_access_validity_condition array ~size_int
(Array_kind array_kind) Ptagged_int_index ~width_in_scalars ~index

let check_array_vector_access ~dbg ~size_int ~array array_kind ~index primitive
: H.expr_primitive =
Expand Down Expand Up @@ -806,16 +822,16 @@ let bigarray_set ~dbg ~unsafe kind layout b indexes value =

(* Array accesses *)
let array_access_validity_condition array array_kind index
~(index_kind : L.array_index_kind) =
let arr_len_as_tagged_imm = H.Prim (Unary (Array_length array_kind, array)) in
[ check_bound ~index_kind ~bound_kind:Tagged_immediate ~index
~bound:arr_len_as_tagged_imm ]
~(index_kind : L.array_index_kind) ~size_int =
[ multiple_word_array_access_validity_condition array ~size_int array_kind
index_kind ~width_in_scalars:1 ~index ]

let check_array_access ~dbg ~array array_kind ~index ~index_kind primitive :
H.expr_primitive =
let check_array_access ~dbg ~array array_kind ~index ~index_kind ~size_int
primitive : H.expr_primitive =
checked_access ~primitive
~conditions:
(array_access_validity_condition array array_kind index ~index_kind)
(array_access_validity_condition array array_kind index ~index_kind
~size_int)
~dbg

let array_load_unsafe ~array ~index (array_ref_kind : Array_ref_kind.t)
Expand Down Expand Up @@ -1574,7 +1590,7 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
~current_region) ]
| Parrayrefs (array_ref_kind, index_kind), [[array]; [index]] ->
let array_kind = convert_array_ref_kind_for_length array_ref_kind in
[ check_array_access ~dbg ~array array_kind ~index ~index_kind
[ check_array_access ~dbg ~array array_kind ~index ~index_kind ~size_int
(match_on_array_ref_kind ~array array_ref_kind
(array_load_unsafe ~array
~index:(convert_index_to_tagged_int ~index ~index_kind)
Expand All @@ -1586,7 +1602,7 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
~new_value) ]
| Parraysets (array_set_kind, index_kind), [[array]; [index]; [new_value]] ->
let array_kind = convert_array_set_kind_for_length array_set_kind in
[ check_array_access ~dbg ~array array_kind ~index ~index_kind
[ check_array_access ~dbg ~array array_kind ~index ~index_kind ~size_int
(match_on_array_set_kind ~array array_set_kind
(array_set_unsafe ~array
~index:(convert_index_to_tagged_int ~index ~index_kind)
Expand Down