Skip to content

Improved reinterpret casts for integers and floats #2686

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jul 10, 2024
2 changes: 2 additions & 0 deletions backend/cmm_helpers.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3648,6 +3648,8 @@ let binary op ~dbg x y = Cop (op, [x; y], dbg)

let int64_as_float = unary (Creinterpret_cast Float_of_int64)

let float_as_int64 = unary (Creinterpret_cast Int64_of_float)

let int_of_float = unary (Cstatic_cast (Int_of_float Float64))

let float_of_int = unary (Cstatic_cast (Float_of_int Float64))
Expand Down
2 changes: 2 additions & 0 deletions backend/cmm_helpers.mli
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,8 @@ val asr_int_caml_raw : dbg:Debuginfo.t -> expression -> expression -> expression

val int64_as_float : dbg:Debuginfo.t -> expression -> expression

val float_as_int64 : dbg:Debuginfo.t -> expression -> expression

(** Conversions functions between integers and floats. *)

val int_of_float : dbg:Debuginfo.t -> expression -> expression
Expand Down
33 changes: 18 additions & 15 deletions middle_end/flambda2/from_lambda/closure_conversion.ml
Original file line number Diff line number Diff line change
Expand Up @@ -585,27 +585,20 @@ let close_c_call acc env ~loc ~let_bound_ids_with_kinds
in
let call args acc =
(* Some C primitives have implementations within Flambda itself. *)
match prim_native_name with
| "caml_int64_float_of_bits_unboxed"
(* There is only one case where this operation is not the identity: on
32-bit pre-EABI ARM platforms. It is very unlikely anyone would still be
using one of those, but just in case, we only optimise this primitive on
64-bit systems. (There is no easy way here of detecting just the specific
ARM case in question.) *)
when match Targetint_32_64.num_bits with
| Thirty_two -> false
| Sixty_four -> true -> (
let[@inline] unboxed_int64_to_and_from_unboxed_float ~src_kind ~dst_kind ~op
=
if prim_arity <> 1
then Misc.fatal_errorf "Expected arity one for %s" prim_native_name
else
match prim_native_repr_args, prim_native_repr_res with
| [(_, Unboxed_integer Pint64)], (_, Unboxed_float Pfloat64) -> (
| [(_, src)], (_, dst)
when Stdlib.( = ) src src_kind && Stdlib.( = ) dst dst_kind -> (
match args with
| [arg] ->
let result = Variable.create "reinterpreted_int64" in
let result = Variable.create "reinterpreted" in
let result' = Bound_var.create result Name_mode.normal in
let bindable = Bound_pattern.singleton result' in
let prim = P.Unary (Reinterpret_int64_as_float, arg) in
let prim = P.Unary (Reinterpret_64_bit_word op, arg) in
let acc, return_result =
Apply_cont_with_acc.create acc return_continuation
~args:[Simple.var result]
Expand All @@ -621,7 +614,15 @@ let close_c_call acc env ~loc ~let_bound_ids_with_kinds
Misc.fatal_errorf "Expected one arg for %s" prim_native_name)
| _, _ ->
Misc.fatal_errorf "Wrong argument and/or result kind(s) for %s"
prim_native_name)
prim_native_name
in
match prim_native_name with
| "caml_int64_float_of_bits_unboxed" ->
unboxed_int64_to_and_from_unboxed_float ~src_kind:(Unboxed_integer Pint64)
~dst_kind:(Unboxed_float Pfloat64) ~op:Unboxed_int64_as_unboxed_float64
| "caml_int64_bits_of_float_unboxed" ->
unboxed_int64_to_and_from_unboxed_float ~src_kind:(Unboxed_float Pfloat64)
~dst_kind:(Unboxed_integer Pint64) ~op:Unboxed_float64_as_unboxed_int64
| _ ->
let callee = Simple.symbol call_symbol in
let apply =
Expand Down Expand Up @@ -948,7 +949,9 @@ let close_primitive acc env ~let_bound_ids_with_kinds named
| Punbox_int _ | Pbox_int _ | Pmake_unboxed_product _
| Punboxed_product_field _ | Pget_header _ | Prunstack | Pperform
| Presume | Preperform | Patomic_exchange | Patomic_cas
| Patomic_fetch_add | Pdls_get | Patomic_load _ ->
| Patomic_fetch_add | Pdls_get | Patomic_load _
| Preinterpret_tagged_int63_as_unboxed_int64
| Preinterpret_unboxed_int64_as_tagged_int63 ->
(* Inconsistent with outer match *)
assert false
in
Expand Down
4 changes: 3 additions & 1 deletion middle_end/flambda2/from_lambda/lambda_to_flambda.ml
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,9 @@ let primitive_can_raise (prim : Lambda.primitive) =
false
| Patomic_exchange | Patomic_cas | Patomic_fetch_add | Patomic_load _ -> false
| Prunstack | Pperform | Presume | Preperform -> true (* XXX! *)
| Pdls_get -> false
| Pdls_get | Preinterpret_tagged_int63_as_unboxed_int64
| Preinterpret_unboxed_int64_as_tagged_int63 ->
false

type non_tail_continuation =
Acc.t ->
Expand Down
18 changes: 17 additions & 1 deletion middle_end/flambda2/from_lambda/lambda_to_flambda_primitives.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1889,6 +1889,20 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
| Patomic_fetch_add, [[atomic]; [i]] ->
[Binary (Atomic_fetch_and_add, atomic, i)]
| Pdls_get, _ -> [Nullary Dls_get]
| Preinterpret_unboxed_int64_as_tagged_int63, [[i]] ->
if not (Target_system.is_64_bit ())
then
Misc.fatal_error
"Preinterpret_unboxed_int64_as_tagged_int63 can only be used on 64-bit \
targets";
[Unary (Reinterpret_64_bit_word Unboxed_int64_as_tagged_int63, i)]
| Preinterpret_tagged_int63_as_unboxed_int64, [[i]] ->
if not (Target_system.is_64_bit ())
then
Misc.fatal_error
"Preinterpret_tagged_int63_as_unboxed_int64 can only be used on 64-bit \
targets";
[Unary (Reinterpret_64_bit_word Tagged_int63_as_unboxed_int64, i)]
| ( ( Pmodint Unsafe
| Pdivbint { is_safe = Unsafe; size = _; mode = _ }
| Pmodbint { is_safe = Unsafe; size = _; mode = _ }
Expand Down Expand Up @@ -1917,7 +1931,9 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
| Punbox_float _
| Pbox_float (_, _)
| Punbox_int _ | Pbox_int _ | Punboxed_product_field _ | Pget_header _
| Pufloatfield _ | Patomic_load _ | Pmixedfield _ ),
| Pufloatfield _ | Patomic_load _ | Pmixedfield _
| Preinterpret_unboxed_int64_as_tagged_int63
| Preinterpret_tagged_int63_as_unboxed_int64 ),
([] | _ :: _ :: _ | [([] | _ :: _ :: _)]) ) ->
Misc.fatal_errorf
"Closure_conversion.convert_primitive: Wrong arity for unary primitive \
Expand Down
4 changes: 4 additions & 0 deletions middle_end/flambda2/numbers/numeric_types.ml
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ module type Float_by_bit_pattern = sig

val of_bits : bits -> t

val to_bits : t -> bits

val of_string : string -> t

val to_float : t -> float
Expand Down Expand Up @@ -226,6 +228,8 @@ struct

let of_bits bits = bits

let to_bits bits = bits

let of_string str = Bits.of_string str

let to_float t = Bits.float_of_bits t
Expand Down
2 changes: 2 additions & 0 deletions middle_end/flambda2/numbers/numeric_types.mli
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ module type Float_by_bit_pattern = sig

val of_bits : bits -> t

val to_bits : t -> bits

val of_string : string -> t

val to_float : t -> float
Expand Down
2 changes: 1 addition & 1 deletion middle_end/flambda2/parser/flambda_to_fexpr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ let unop env (op : Flambda_primitive.unary_primitive) : Fexpr.unop =
| String_length string_or_bytes -> String_length string_or_bytes
| Boolean_not -> Boolean_not
| Int_as_pointer _ | Duplicate_block _ | Duplicate_array _ | Bigarray_length _
| Float_arith _ | Reinterpret_int64_as_float | Is_boxed_float | Obj_dup
| Float_arith _ | Reinterpret_64_bit_word _ | Is_boxed_float | Obj_dup
| Get_header | Atomic_load _ ->
Misc.fatal_errorf "TODO: Unary primitive: %a"
Flambda_primitive.Without_args.print
Expand Down
130 changes: 111 additions & 19 deletions middle_end/flambda2/simplify/simplify_unary_primitive.ml
Original file line number Diff line number Diff line change
Expand Up @@ -431,24 +431,115 @@ let simplify_boolean_not dacc ~original_term ~arg:_ ~arg_ty ~result_var =
SPR.create original_term ~try_reify:false dacc
| Invalid -> SPR.create_invalid dacc

let simplify_reinterpret_int64_as_float dacc ~original_term ~arg:_ ~arg_ty
~result_var =
let typing_env = DE.typing_env (DA.denv dacc) in
let proof = T.meet_naked_int64s typing_env arg_ty in
match proof with
| Known_result int64s ->
let floats =
Int64.Set.fold
(fun int64 floats -> Float.Set.add (Float.of_bits int64) floats)
int64s Float.Set.empty
in
let ty = T.these_naked_floats floats in
let dacc = DA.add_variable dacc result_var ty in
SPR.create original_term ~try_reify:true dacc
| Need_meet ->
let dacc = DA.add_variable dacc result_var T.any_naked_float in
SPR.create original_term ~try_reify:false dacc
| Invalid -> SPR.create_invalid dacc
module Make_simplify_reinterpret_64_bit_word (P : sig
module Src : Container_types.S

module Dst : Container_types.S

val prover : TE.t -> T.t -> Src.Set.t meet_shortcut

val convert : Src.t -> Dst.t

val these : Dst.Set.t -> T.t

val any_dst : T.t
end) =
struct
let simplify dacc ~original_term ~arg:_ ~arg_ty ~result_var =
let typing_env = DE.typing_env (DA.denv dacc) in
let proof = P.prover typing_env arg_ty in
match proof with
| Known_result src ->
let dst =
P.Src.Set.fold
(fun src dst -> P.Dst.Set.add (P.convert src) dst)
src P.Dst.Set.empty
in
let ty = P.these dst in
let dacc = DA.add_variable dacc result_var ty in
SPR.create original_term ~try_reify:true dacc
| Need_meet ->
let dacc = DA.add_variable dacc result_var P.any_dst in
SPR.create original_term ~try_reify:false dacc
| Invalid -> SPR.create_invalid dacc
end

module Simplify_reinterpret_unboxed_int64_as_unboxed_float64 =
Make_simplify_reinterpret_64_bit_word (struct
module Src = Int64
module Dst = Float

let prover = T.meet_naked_int64s

let convert = Float.of_bits

let these = T.these_naked_floats

let any_dst = T.any_naked_float
end)

module Simplify_reinterpret_unboxed_float64_as_unboxed_int64 =
Make_simplify_reinterpret_64_bit_word (struct
module Src = Float
module Dst = Int64

let prover = T.meet_naked_floats

let convert = Float.to_bits

let these = T.these_naked_int64s

let any_dst = T.any_naked_int64
end)

module Simplify_reinterpret_unboxed_int64_as_tagged_int63 =
Make_simplify_reinterpret_64_bit_word (struct
module Src = Int64
module Dst = Targetint_31_63

let prover = T.meet_naked_int64s

(* This primitive is logical OR with 1 on machine words, but here, we are
working in the tagged world. As such a different computation is
required. *)
let convert i = Targetint_31_63.of_int64 (Int64.shift_right_logical i 1)

let these = T.these_tagged_immediates

let any_dst = T.any_tagged_immediate
end)

module Simplify_reinterpret_tagged_int63_as_unboxed_int64 =
Make_simplify_reinterpret_64_bit_word (struct
module Src = Targetint_31_63
module Dst = Int64

let prover = T.meet_equals_tagged_immediates

(* This primitive is the identity on machine words, but as above, we are
working in the tagged world. *)
let convert i = Int64.add (Int64.mul (Targetint_31_63.to_int64 i) 2L) 1L

let these = T.these_naked_int64s

let any_dst = T.any_naked_int64
end)

let simplify_reinterpret_64_bit_word (reinterpret : P.Reinterpret_64_bit_word.t)
dacc ~original_term ~arg ~arg_ty ~result_var =
match reinterpret with
| Unboxed_int64_as_unboxed_float64 ->
Simplify_reinterpret_unboxed_int64_as_unboxed_float64.simplify dacc
~original_term ~arg ~arg_ty ~result_var
| Unboxed_float64_as_unboxed_int64 ->
Simplify_reinterpret_unboxed_float64_as_unboxed_int64.simplify dacc
~original_term ~arg ~arg_ty ~result_var
| Unboxed_int64_as_tagged_int63 ->
Simplify_reinterpret_unboxed_int64_as_tagged_int63.simplify dacc
~original_term ~arg ~arg_ty ~result_var
| Tagged_int63_as_unboxed_int64 ->
Simplify_reinterpret_tagged_int63_as_unboxed_int64.simplify dacc
~original_term ~arg ~arg_ty ~result_var

module Make_simplify_float_arith_op (FP : sig
module F : Numeric_types.Float_by_bit_pattern
Expand Down Expand Up @@ -712,7 +803,8 @@ let simplify_unary_primitive dacc original_prim (prim : P.unary_primitive) ~arg
| Naked_int64 -> Simplify_int_conv_naked_int64.simplify ~dst
| Naked_nativeint -> Simplify_int_conv_naked_nativeint.simplify ~dst)
| Boolean_not -> simplify_boolean_not
| Reinterpret_int64_as_float -> simplify_reinterpret_int64_as_float
| Reinterpret_64_bit_word reinterpret ->
simplify_reinterpret_64_bit_word reinterpret
| Is_boxed_float -> simplify_is_boxed_float
| Is_flat_float_array -> simplify_is_flat_float_array
| Int_as_pointer mode -> simplify_int_as_pointer ~mode
Expand Down
7 changes: 6 additions & 1 deletion middle_end/flambda2/terms/code_size.ml
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,12 @@ let unary_prim_size prim =
| Float_arith _ -> 2
| Num_conv { src; dst } -> arith_conversion_size src dst
| Boolean_not -> 1
| Reinterpret_int64_as_float -> 0
| Reinterpret_64_bit_word reinterpret -> (
match reinterpret with
| Tagged_int63_as_unboxed_int64 -> 0
| Unboxed_int64_as_tagged_int63 -> (* Needs a logical OR. *) 1
| Unboxed_int64_as_unboxed_float64 | Unboxed_float64_as_unboxed_int64 ->
(* Needs a move between register classes. *) 1)
| Unbox_number k -> unbox_number k
| Untag_immediate -> 1 (* 1 shift *)
| Box_number (k, _alloc_mode) -> box_number k
Expand Down
Loading
Loading