16
16
17
17
[@@@ ocaml.warning " +a-4-30-40-41-42" ]
18
18
19
+ open ! Flambda. Import
20
+
19
21
module KP = Kinded_parameter
20
22
module T = Flambda_type
21
23
module TE = Flambda_type. Typing_env
22
24
23
25
module rec Downwards_env : sig
24
26
include Simplify_env_and_result_intf. Downwards_env
25
27
with type result := Result. t
28
+ with type lifted_constant := Lifted_constant. t
26
29
end = struct
27
30
type t = {
28
31
backend : (module Flambda2_backend_intf .S );
@@ -32,26 +35,30 @@ end = struct
32
35
can_inline : bool ;
33
36
inlining_depth_increment : int ;
34
37
float_const_prop : bool ;
38
+ code : Function_params_and_body .t Code_id.Map .t ;
35
39
}
36
40
37
41
let print ppf { backend = _ ; round; typing_env;
38
42
inlined_debuginfo; can_inline;
39
43
inlining_depth_increment; float_const_prop;
44
+ code;
40
45
} =
41
46
Format. fprintf ppf " @[<hov 1>(\
42
47
@[<hov 1>(round@ %d)@]@ \
43
48
@[<hov 1>(typing_env@ %a)@]@ \
44
49
@[<hov 1>(inlined_debuginfo@ %a)@]@ \
45
50
@[<hov 1>(can_inline@ %b)@]@ \
46
51
@[<hov 1>(inlining_depth_increment@ %d)@]@ \
47
- @[<hov 1>(float_const_prop@ %b)@]\
52
+ @[<hov 1>(float_const_prop@ %b)@] \
53
+ @[<hov 1>(code@ %a)@]\
48
54
)@]"
49
55
round
50
56
TE. print typing_env
51
57
Debuginfo. print inlined_debuginfo
52
58
can_inline
53
59
inlining_depth_increment
54
60
float_const_prop
61
+ (Code_id.Map. print Function_params_and_body. print) code
55
62
56
63
let invariant _t = ()
57
64
@@ -65,6 +72,7 @@ end = struct
65
72
can_inline = true ;
66
73
inlining_depth_increment = 0 ;
67
74
float_const_prop;
75
+ code = Code_id.Map. empty;
68
76
}
69
77
70
78
let resolver t = TE. resolver t.typing_env
@@ -92,7 +100,7 @@ end = struct
92
100
let enter_closure { backend; round; typing_env;
93
101
inlined_debuginfo = _ ; can_inline;
94
102
inlining_depth_increment = _ ;
95
- float_const_prop;
103
+ float_const_prop; code;
96
104
} =
97
105
{ backend;
98
106
round;
@@ -101,6 +109,7 @@ end = struct
101
109
can_inline;
102
110
inlining_depth_increment = 0 ;
103
111
float_const_prop;
112
+ code;
104
113
}
105
114
106
115
let define_variable t var kind =
@@ -131,6 +140,8 @@ end = struct
131
140
let typing_env = TE. add_equation t.typing_env (Name. var var) ty in
132
141
{ t with typing_env; }
133
142
143
+ let mem_name t name = TE. mem t.typing_env name
144
+
134
145
let find_name t name =
135
146
match TE. find t.typing_env name with
136
147
| exception Not_found ->
@@ -172,6 +183,8 @@ end = struct
172
183
in
173
184
{ t with typing_env; }
174
185
186
+ let mem_symbol t sym = mem_name t (Name. symbol sym)
187
+
175
188
let find_symbol t sym = find_name t (Name. symbol sym)
176
189
177
190
let define_name t name kind =
@@ -269,18 +282,21 @@ end = struct
269
282
(* CR mshinwell: Convert [Typing_env] to map from [Simple]s. *)
270
283
| Const _ -> ()
271
284
272
- let add_inlined_debuginfo' t dbg =
273
- Debuginfo. concat t.inlined_debuginfo dbg
274
-
275
- let add_inlined_debuginfo t dbg =
285
+ let define_code t id code =
286
+ if Code_id.Map. mem id t.code then begin
287
+ Misc. fatal_errorf " Code ID %a is already defined, cannot redefine to@ %a"
288
+ Code_id. print id
289
+ Function_params_and_body. print code
290
+ end ;
276
291
{ t with
277
- inlined_debuginfo = add_inlined_debuginfo' t dbg
292
+ code = Code_id.Map. add id code t.code;
278
293
}
279
294
280
- let disable_function_inlining t =
281
- { t with
282
- can_inline = false ;
283
- }
295
+ let find_code t id =
296
+ match Code_id.Map. find id t.code with
297
+ | exception Not_found ->
298
+ Misc. fatal_errorf " Code ID %a not bound" Code_id. print id
299
+ | code -> code
284
300
285
301
(* CR mshinwell: The label should state what order is expected. *)
286
302
let add_lifted_constants t ~lifted =
@@ -289,17 +305,71 @@ end = struct
289
305
(Format.pp_print_list ~pp_sep:Format.pp_print_space
290
306
Lifted_constant.print) lifted;
291
307
*)
292
- let typing_env =
293
- List. fold_left (fun typing_env lifted_constant ->
294
- Lifted_constant. introduce lifted_constant typing_env)
295
- (typing_env t)
296
- (List. rev lifted)
297
- in
298
- with_typing_env t typing_env
308
+ let module LC = Lifted_constant in
309
+ List. fold_left (fun denv lifted_constant ->
310
+ let denv_at_definition = LC. denv_at_definition lifted_constant in
311
+ let types_of_symbols = LC. types_of_symbols lifted_constant in
312
+ let definition = LC. definition lifted_constant in
313
+ let being_defined =
314
+ Flambda_static.Program_body.Definition. being_defined definition
315
+ in
316
+ let already_bound =
317
+ Symbol.Set. filter (fun sym -> mem_symbol denv sym)
318
+ being_defined
319
+ in
320
+ if Symbol.Set. equal being_defined already_bound then denv
321
+ else if not (Symbol.Set. is_empty already_bound) then
322
+ Misc. fatal_errorf " Expected all or none of the following symbols \
323
+ to be found:@ %a@ denv:@ %a"
324
+ LC. print lifted_constant
325
+ print denv
326
+ else
327
+ let typing_env =
328
+ Symbol.Map. fold (fun sym typ typing_env ->
329
+ let sym =
330
+ Name_in_binding_pos. create (Name. symbol sym) Name_mode. normal
331
+ in
332
+ TE. add_definition typing_env sym (T. kind typ))
333
+ types_of_symbols
334
+ denv.typing_env
335
+ in
336
+ let typing_env =
337
+ Symbol.Map. fold (fun sym typ typing_env ->
338
+ let sym = Name. symbol sym in
339
+ let env_extension =
340
+ T. make_suitable_for_environment typ
341
+ denv_at_definition.typing_env
342
+ ~suitable_for: typing_env
343
+ ~bind_to: sym
344
+ in
345
+ TE. add_env_extension typing_env ~env_extension )
346
+ types_of_symbols
347
+ typing_env
348
+ in
349
+ Code_id.Map. fold (fun code_id params_and_body denv ->
350
+ define_code denv code_id params_and_body)
351
+ (LC. pieces_of_code lifted_constant)
352
+ (with_typing_env denv typing_env))
353
+ t
354
+ (List. rev lifted)
299
355
300
356
(* CR mshinwell: Think more about this -- may be re-traversing long lists *)
301
357
let add_lifted_constants_from_r t r =
302
358
add_lifted_constants t ~lifted: (Result. get_lifted_constants r)
359
+
360
+ let add_inlined_debuginfo' t dbg =
361
+ Debuginfo. concat t.inlined_debuginfo dbg
362
+
363
+ let add_inlined_debuginfo t dbg =
364
+ { t with
365
+ inlined_debuginfo = add_inlined_debuginfo' t dbg
366
+ }
367
+
368
+ let disable_function_inlining t =
369
+ { t with
370
+ can_inline = false ;
371
+ }
372
+
303
373
end and Upwards_env : sig
304
374
include Simplify_env_and_result_intf. Upwards_env
305
375
with type downwards_env := Downwards_env. t
@@ -483,6 +553,7 @@ end = struct
483
553
| rewrite -> Some rewrite
484
554
end and Result : sig
485
555
include Simplify_env_and_result_intf. Result
556
+ with type lifted_constant := Lifted_constant. t
486
557
end = struct
487
558
type t =
488
559
{ resolver : (Export_id .t -> Flambda_type .t option );
@@ -521,4 +592,52 @@ end = struct
521
592
{ t with
522
593
lifted_constants_innermost_last = [] ;
523
594
}
595
+ end and Lifted_constant : sig
596
+ include Simplify_env_and_result_intf. Lifted_constant
597
+ with type downwards_env := Downwards_env. t
598
+ end = struct
599
+ module Definition = Flambda_static.Program_body. Definition
600
+
601
+ type t = {
602
+ denv : Downwards_env .t ;
603
+ definition : Definition .t ;
604
+ types_of_symbols : Flambda_type .t Symbol.Map .t ;
605
+ pieces_of_code : Function_params_and_body .t Code_id.Map .t ;
606
+ }
607
+
608
+ let print ppf
609
+ { denv = _ ; definition; types_of_symbols = _ ; pieces_of_code = _ ; } =
610
+ Format. fprintf ppf " @[<hov 1>(\
611
+ @[<hov 1>(definition@ %a)@]\
612
+ )@]"
613
+ Definition. print definition
614
+
615
+ let create denv definition ~types_of_symbols ~pieces_of_code =
616
+ let being_defined = Definition. being_defined definition in
617
+ if not (Symbol.Set. subset (Symbol.Map. keys types_of_symbols) being_defined)
618
+ then begin
619
+ Misc. fatal_errorf " [types_of_symbols]:@ %a@ does not cover all symbols \
620
+ in the [Definition]:@ %a"
621
+ (Symbol.Map. print T. print) types_of_symbols
622
+ Definition. print definition
623
+ end ;
624
+ let code_being_defined = Definition. code_being_defined definition in
625
+ if not (Code_id.Set. subset (Code_id.Map. keys pieces_of_code)
626
+ code_being_defined)
627
+ then begin
628
+ Misc. fatal_errorf " [pieces_of_code]:@ %a@ does not cover all code IDs \
629
+ in the [Definition]:@ %a"
630
+ (Code_id.Map. print Function_params_and_body. print) pieces_of_code
631
+ Definition. print definition
632
+ end ;
633
+ { denv;
634
+ definition;
635
+ types_of_symbols;
636
+ pieces_of_code;
637
+ }
638
+
639
+ let denv_at_definition t = t.denv
640
+ let definition t = t.definition
641
+ let types_of_symbols t = t.types_of_symbols
642
+ let pieces_of_code t = t.pieces_of_code
524
643
end
0 commit comments