Skip to content

Commit

Permalink
flambda-backend: Allow unboxed float32s in mixed blocks (#2550)
Browse files Browse the repository at this point in the history
  • Loading branch information
TheNumbat authored May 10, 2024
1 parent 66fbd07 commit bd65b13
Show file tree
Hide file tree
Showing 21 changed files with 4,537 additions and 2,365 deletions.
5 changes: 3 additions & 2 deletions lambda/lambda.ml
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ and block_shape =
value_kind list option

and flat_element = Types.flat_element =
Imm | Float | Float64 | Bits32 | Bits64 | Word
Imm | Float | Float64 | Float32 | Bits32 | Bits64 | Word
and flat_element_read =
| Flat_read of flat_element (* invariant: not [Float] *)
| Flat_read_float of alloc_mode
Expand Down Expand Up @@ -1249,7 +1249,7 @@ let get_mixed_block_element = Types.get_mixed_product_element
let flat_read_non_float flat_element =
match flat_element with
| Float -> Misc.fatal_error "flat_element_read_non_float Float"
| Imm | Float64 | Bits32 | Bits64 | Word as flat_element ->
| Imm | Float64 | Float32 | Bits32 | Bits64 | Word as flat_element ->
Flat_read flat_element

let flat_read_float alloc_mode = Flat_read_float alloc_mode
Expand Down Expand Up @@ -1789,6 +1789,7 @@ let layout_of_mixed_field (kind : mixed_block_read) =
match proj with
| Imm -> layout_int
| Float64 -> layout_unboxed_float Pfloat64
| Float32 -> layout_unboxed_float Pfloat32
| Bits32 -> layout_unboxed_int32
| Bits64 -> layout_unboxed_int64
| Word -> layout_unboxed_nativeint
Expand Down
2 changes: 1 addition & 1 deletion lambda/lambda.mli
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ and block_shape =
value_kind list option

and flat_element = Types.flat_element =
Imm | Float | Float64 | Bits32 | Bits64 | Word
Imm | Float | Float64 | Float32 | Bits32 | Bits64 | Word
and flat_element_read = private
| Flat_read of flat_element (* invariant: not [Float] *)
| Flat_read_float of alloc_mode
Expand Down
2 changes: 1 addition & 1 deletion lambda/matching.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2183,7 +2183,7 @@ let get_expr_args_record ~scopes head (arg, _mut, sort, layout) rem =
else
let read =
match flat_suffix.(pos - value_prefix_len) with
| Imm | Float64 | Bits32 | Bits64 | Word as non_float ->
| Imm | Float64 | Float32 | Bits32 | Bits64 | Word as non_float ->
flat_read_non_float non_float
| Float ->
(* TODO: could optimise to Alloc_local sometimes *)
Expand Down
5 changes: 1 addition & 4 deletions lambda/translcore.ml
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,7 @@ let layout_pat sort p = layout p.pat_env p.pat_loc sort p.pat_type

let check_record_field_sort loc sort =
match Jkind.Sort.get_default_value sort with
| Value | Float64 | Bits32 | Bits64 | Word -> ()
| Float32 ->
(* CR mslater: (float32) float32# records *)
Misc.fatal_error "Found unboxed float32 record field."
| Value | Float64 | Float32 | Bits32 | Bits64 | Word -> ()
| Void -> raise (Error (loc, Illegal_void_record_field))

(* Forward declaration -- to be filled in by Translmod.transl_module *)
Expand Down
67 changes: 52 additions & 15 deletions testsuite/tests/mixed-blocks/constructor_args.ml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
(* TEST
flags = "-extension layouts_beta";
flags = "-extension layouts_beta -extension small_numbers";
include beta;
flambda2;
{
native;
Expand All @@ -11,6 +12,7 @@
(*****************************************)
(* Prelude: Functions on unboxed numbers *)

module Float32_u = Beta.Float32_u
module Float_u = Stdlib__Float_u
module Int32_u = Stdlib__Int32_u
module Int64_u = Stdlib__Int64_u
Expand All @@ -34,6 +36,9 @@ type t =
| Mixed6 of float * int32# * float#
| Mixed7 of float * int64# * float# * nativeint#
| Mixed8 of float * int32# * float# * int64# * float#
| Mixed9 of float * float# * float32#
| Mixed10 of float * float32# * float# * int64# * float#
| Mixed11 of float * int32# * float32# * float# * int64# * nativeint#
| Uniform2 of float * float

type t_ext = ..
Expand All @@ -47,6 +52,9 @@ type t_ext +=
| Ext_mixed6 of float * int32# * float#
| Ext_mixed7 of float * int64# * float# * nativeint#
| Ext_mixed8 of float * int32# * float# * int64# * float#
| Ext_mixed9 of float * float# * float32#
| Ext_mixed10 of float * float32# * float# * int64# * float#
| Ext_mixed11 of float * int32# * float32# * float# * int64# * nativeint#

let sprintf = Printf.sprintf

Expand Down Expand Up @@ -75,6 +83,17 @@ let to_string = function
sprintf "Mixed8 (%f, %i, %f, %i, %f)"
x1 (Int32_u.to_int x2) (Float_u.to_float x3) (Int64_u.to_int x4)
(Float_u.to_float x5)
| Mixed9 (x1, x2, x3) ->
sprintf "Mixed9 (%f, %f, %f)" x1 (Float_u.to_float x2)
(Float_u.to_float (Float32_u.to_float x3))
| Mixed10 (x1, x2, x3, x4, x5) ->
sprintf "Mixed10 (%f, %f, %f, %i, %f)"
x1 (Float_u.to_float (Float32_u.to_float x2)) (Float_u.to_float x3)
(Int64_u.to_int x4) (Float_u.to_float x5)
| Mixed11 (x1, x2, x3, x4, x5, x6) ->
sprintf "Mixed11 (%f, %i, %f, %f, %i, %i)"
x1 (Int32_u.to_int x2) (Float_u.to_float (Float32_u.to_float x3))
(Float_u.to_float x4) (Int64_u.to_int x5) (Nativeint_u.to_int x6)
| Uniform2 (x1, x2) -> sprintf "Uniform2 (%f, %f)" x1 x2

let ext_to_string = function
Expand All @@ -100,6 +119,17 @@ let ext_to_string = function
sprintf "Ext_mixed8 (%f, %i, %f, %i, %f)"
x1 (Int32_u.to_int x2) (Float_u.to_float x3) (Int64_u.to_int x4)
(Float_u.to_float x5)
| Ext_mixed9 (x1, x2, x3) ->
sprintf "Ext_mixed9 (%f, %f, %f)" x1 (Float_u.to_float x2)
(Float_u.to_float (Float32_u.to_float x3))
| Ext_mixed10 (x1, x2, x3, x4, x5) ->
sprintf "Ext_mixed10 (%f, %f, %f, %i, %f)"
x1 (Float_u.to_float (Float32_u.to_float x2)) (Float_u.to_float x3)
(Int64_u.to_int x4) (Float_u.to_float x5)
| Ext_mixed11 (x1, x2, x3, x4, x5, x6) ->
sprintf "Ext_mixed11 (%f, %i, %f, %f, %i, %i)"
x1 (Int32_u.to_int x2) (Float_u.to_float (Float32_u.to_float x3))
(Float_u.to_float x4) (Int64_u.to_int x5) (Nativeint_u.to_int x6)
| _ -> "<ext>"

let print t = print_endline (" " ^ to_string t)
Expand Down Expand Up @@ -128,12 +158,12 @@ let () = run #17.0
exercise an optimization code path.
*)

let sum uf uf' f f' i i32 i64 i_n =
let sum uf uf' f f' i i32 i64 i_n f32 =
Float_u.to_float uf +. Float_u.to_float uf' +. f +. f' +.
Int32_u.to_float i32 +. Int64_u.to_float i64 +. Nativeint_u.to_float i_n
+. float_of_int i
+. float_of_int i +. (Float_u.to_float (Float32_u.to_float f32))

let construct_and_destruct uf uf' f f' i i32 i64 i_n =
let construct_and_destruct uf uf' f f' i i32 i64 i_n f32 =
let Constant = Constant in
let Uniform1 f = Uniform1 f in
let Mixed1 uf = Mixed1 uf in
Expand All @@ -144,6 +174,9 @@ let construct_and_destruct uf uf' f f' i i32 i64 i_n =
let Mixed6 (f, i32, uf) = Mixed6 (f, i32, uf) in
let Mixed7 (f, i64, uf, i_n) = Mixed7 (f, i64, uf, i_n) in
let Mixed8 (f, i32, uf, i64, uf') = Mixed8 (f, i32, uf, i64, uf') in
let Mixed9 (f, uf, f32) = Mixed9 (f, uf, f32) in
let Mixed10 (f, f32, uf, i64, uf') = Mixed10 (f, f32, uf, i64, uf') in
let Mixed11 (f, i32, f32, uf, i64, i_n) = Mixed11 (f, i32, f32, uf, i64, i_n) in
let Ext_mixed1 uf = Ext_mixed1 uf in
let Ext_mixed2 (f, uf) = Ext_mixed2 (f, uf) in
let Ext_mixed3 (f, uf, uf') = Ext_mixed3 (f, uf, uf') in
Expand All @@ -152,8 +185,11 @@ let construct_and_destruct uf uf' f f' i i32 i64 i_n =
let Ext_mixed6 (f, i32, uf) = Ext_mixed6 (f, i32, uf) in
let Ext_mixed7 (f, i64, uf, i_n) = Ext_mixed7 (f, i64, uf, i_n) in
let Ext_mixed8 (f, i32, uf, i64, uf') = Ext_mixed8 (f, i32, uf, i64, uf') in
let Ext_mixed9 (f, uf, f32) = Ext_mixed9 (f, uf, f32) in
let Ext_mixed10 (f, f32, uf, i64, uf') = Ext_mixed10 (f, f32, uf, i64, uf') in
let Ext_mixed11 (f, i32, f32, uf, i64, i_n) = Ext_mixed11 (f, i32, f32, uf, i64, i_n) in
let Uniform2 (f, f') = Uniform2 (f, f') in
sum uf uf' f f' i i32 i64 i_n
sum uf uf' f f' i i32 i64 i_n f32
[@@ocaml.warning "-partial-match"]

let () =
Expand All @@ -165,10 +201,11 @@ let () =
and i32 = #12l
and i64 = #42L
and i_n = #56n
and f32 = #1.2s
in
let () =
let sum1 = sum uf uf' f f' i i32 i64 i_n in
let sum2 = construct_and_destruct uf uf' f f' i i32 i64 i_n in
let sum1 = sum uf uf' f f' i i32 i64 i_n f32 in
let sum2 = construct_and_destruct uf uf' f f' i i32 i64 i_n f32 in
Printf.printf
"Test (construct and destruct): %f = %f (%s)\n"
sum1
Expand Down Expand Up @@ -218,7 +255,7 @@ let _ =
let go x y z =
let f =
match x with
| Mixed5 (f1, uf1, i, i32_1, i_n, i64) ->
| Mixed11 (f1, i32_1, f32, uf1, i64, i_n) ->
(* Close over the fields we projected out *)
(fun () ->
match y, z with
Expand All @@ -228,7 +265,6 @@ let go x y z =
Mixed3 (f2, uf2, uf3) ->
[ f1;
Float_u.to_float uf1;
float_of_int i;
Int32_u.to_float i32_1;
Nativeint_u.to_float i_n;
Int64_u.to_float i64;
Expand All @@ -238,6 +274,7 @@ let go x y z =
f3;
Float_u.to_float uf4;
Int32_u.to_float i32_2;
Float32.to_float (Float32_u.to_float32 f32);
]
| _ -> assert false
)
Expand All @@ -249,7 +286,6 @@ let test () =
let f1 = 4.0
and f2 = 42.0
and f3 = 36.0
and i = 3
and i32_1 = #3l
and i32_2 = -#10l
and i64 = -#20L
Expand All @@ -258,8 +294,9 @@ let test () =
and uf2 = #32.0
and uf3 = #47.5
and uf4 = #47.8
and f32 = #1.2s
in
let x = Mixed5 (f1, uf1, i, i32_1, i_n, i64) in
let x = Mixed11 (f1, i32_1, f32, uf1, i64, i_n) in
let y = Mixed3 (f2, uf2, uf3) in
let z = Mixed4 (f3, uf4, i32_2) in
(* These results should match as [go] is symmetric in
Expand Down Expand Up @@ -292,11 +329,10 @@ let go_recursive x y z =
let rec f_odd n =
if n < 7 then f_even (n+1)
else match x with
| Mixed5 (f1, uf1, i, i32_1, i_n, i64) ->
| Mixed11 (f1, i32_1, f32, uf1, i64, i_n) ->
[ float_of_int n;
f1;
Float_u.to_float uf1;
float_of_int i;
Int32_u.to_float i32_1;
Nativeint_u.to_float i_n;
Int64_u.to_float i64;
Expand All @@ -306,6 +342,7 @@ let go_recursive x y z =
f3;
Float_u.to_float uf4;
Int32_u.to_float i32_2;
Float32.to_float (Float32_u.to_float32 f32);
]
| _ -> assert false
and f_even n = f_odd (n+1) in
Expand All @@ -318,7 +355,6 @@ let test_recursive () =
let f1 = 4.0
and f2 = 42.0
and f3 = 36.0
and i = 3
and i32_1 = #3l
and i32_2 = -#10l
and i64 = -#20L
Expand All @@ -327,8 +363,9 @@ let test_recursive () =
and uf2 = #32.0
and uf3 = #47.5
and uf4 = #47.8
and f32 = #1.2s
in
let x = Mixed5 (f1, uf1, i, i32_1, i_n, i64) in
let x = Mixed11 (f1, i32_1, f32, uf1, i64, i_n) in
let y = Mixed3 (f2, uf2, uf3) in
let z = Mixed4 (f3, uf4, i32_2) in
(* These results should match as [go_recursive] is symmetric in
Expand Down
6 changes: 3 additions & 3 deletions testsuite/tests/mixed-blocks/constructor_args.reference
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Test (construction)
Ext_mixed1 8.000000
Mixed2 (3.000000, 4.500000)
Mixed3 (6.000000, 17.000000, 5.000000)
Test (construct and destruct): 149.700000 = 149.700000 (PASS)
Test (construct and destruct): 150.900000 = 150.900000 (PASS)
Test (mixed constructors in recursive groups):
Mixed1 4.000000
Mixed2 (5.000000, 4.000000)
Expand All @@ -19,7 +19,6 @@ Test (pattern matching).
4.000
17.000
3.000
3.000
174.000
-20.000
42.000
Expand All @@ -28,13 +27,13 @@ Test (pattern matching).
36.000
47.800
-10.000
1.200
Test (pattern matching, recursive closure).
Contents of fields:
7.000
4.000
17.000
3.000
3.000
174.000
-20.000
42.000
Expand All @@ -43,3 +42,4 @@ Test (pattern matching, recursive closure).
36.000
47.800
-10.000
1.200
Loading

0 comments on commit bd65b13

Please sign in to comment.