Skip to content

Commit 25f9259

Browse files
committed
Compiler: make tailcall optim more robust
We introduce an intermediate block at the beginning of the closure (if necessary) and jump back to it when we see a tailcall. With this, we can: - remove the "skip_param" argument from "Flow.f" - simplify the O3 profile
1 parent 1fa0eca commit 25f9259

17 files changed

+4873
-4766
lines changed

compiler/lib/driver.ml

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,6 @@ let flow p =
8787
if debug () then Format.eprintf "Data flow...@.";
8888
Flow.f p
8989

90-
let flow_simple p =
91-
if debug () then Format.eprintf "Data flow...@.";
92-
Flow.f ~skip_param:true p
93-
9490
let phi p =
9591
if debug () then Format.eprintf "Variable passing simplification...@.";
9692
Phisimpl.f p
@@ -161,7 +157,7 @@ let identity x = x
161157
let o1 : 'a -> 'a =
162158
print
163159
+> tailcall
164-
+> flow_simple (* flow simple to keep information for future tailcall opt *)
160+
+> flow
165161
+> specialize'
166162
+> eval
167163
+> inline (* inlining may reveal new tailcall opt *)
@@ -190,19 +186,7 @@ let o2 : 'a -> 'a = loop 10 "o1" o1 1 +> print
190186

191187
(* o3 *)
192188

193-
let round1 : 'a -> 'a =
194-
print
195-
+> tailcall
196-
+> inline (* inlining may reveal new tailcall opt *)
197-
+> deadcode (* deadcode required before flow simple -> provided by constant *)
198-
+> flow_simple (* flow simple to keep information for future tailcall opt *)
199-
+> specialize'
200-
+> eval
201-
+> identity
202-
203-
let round2 = flow +> specialize' +> eval +> deadcode +> o1
204-
205-
let o3 = loop 10 "tailcall+inline" round1 1 +> loop 10 "flow" round2 1 +> print
189+
let o3 = loop 10 "o1" o1 1 +> print
206190

207191
let generate
208192
d

compiler/lib/flow.ml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -289,9 +289,9 @@ let program_escape defs known_origins { blocks; _ } =
289289

290290
(****)
291291

292-
let propagate2 ?(skip_param = false) defs known_origins possibly_mutable st x =
292+
let propagate2 defs known_origins possibly_mutable st x =
293293
match defs.(Var.idx x) with
294-
| Param -> skip_param
294+
| Param -> false
295295
| Phi s -> Var.Set.exists (fun y -> Var.Tbl.get st y) s
296296
| Expr e -> (
297297
match e with
@@ -318,11 +318,11 @@ end
318318

319319
module Solver2 = G.Solver (Domain2)
320320

321-
let solver2 ?skip_param vars deps defs known_origins possibly_mutable =
321+
let solver2 vars deps defs known_origins possibly_mutable =
322322
let g =
323323
{ G.domain = vars; G.iter_children = (fun f x -> Var.Set.iter f deps.(Var.idx x)) }
324324
in
325-
Solver2.f () g (propagate2 ?skip_param defs known_origins possibly_mutable)
325+
Solver2.f () g (propagate2 defs known_origins possibly_mutable)
326326

327327
let get_approx
328328
{ Info.info_defs = _; info_known_origins; info_maybe_unknown; _ }
@@ -492,7 +492,7 @@ let build_subst (info : Info.t) vars =
492492

493493
(****)
494494

495-
let f ?skip_param p =
495+
let f p =
496496
Code.invariant p;
497497
let t = Timer.make () in
498498
let t1 = Timer.make () in
@@ -505,7 +505,7 @@ let f ?skip_param p =
505505
let possibly_mutable = program_escape defs known_origins p in
506506
if times () then Format.eprintf " flow analysis 3: %a@." Timer.print t3;
507507
let t4 = Timer.make () in
508-
let maybe_unknown = solver2 ?skip_param vars deps defs known_origins possibly_mutable in
508+
let maybe_unknown = solver2 vars deps defs known_origins possibly_mutable in
509509
if times () then Format.eprintf " flow analysis 4: %a@." Timer.print t4;
510510
if debug ()
511511
then

compiler/lib/flow.mli

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,4 @@ val the_block_contents_of : Info.t -> Code.prim_arg -> Code.Var.t array option
6565

6666
val the_int : Info.t -> Code.prim_arg -> Targetint.t option
6767

68-
val f : ?skip_param:bool -> Code.program -> Code.program * Info.t
68+
val f : Code.program -> Code.program * Info.t

compiler/lib/tailcall.ml

Lines changed: 84 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -39,53 +39,101 @@ let rec tail_call x f l =
3939
-> Some args
4040
| _ :: rem -> tail_call x f rem
4141

42-
let rewrite_block (f, f_params, f_pc, args) pc blocks =
42+
let rewrite_block (f, f_params, f_pc, used, stats) pc blocks =
4343
let block = Addr.Map.find pc blocks in
4444
match block.branch with
4545
| Return x -> (
4646
match tail_call x f block.body with
47-
| Some f_args when List.length f_params = List.length f_args ->
48-
let m = Subst.build_mapping f_params f_args in
49-
List.iter2 f_params f_args ~f:(fun p a -> Code.Var.propagate_name p a);
50-
Addr.Map.add
51-
pc
52-
{ params = block.params
53-
; body = remove_last block.body
54-
; branch = Branch (f_pc, List.map args ~f:(fun x -> Var.Map.find x m))
55-
}
56-
blocks
57-
| _ -> blocks)
47+
| Some f_args ->
48+
if List.compare_lengths f_params f_args = 0
49+
then (
50+
incr stats;
51+
List.iter2 f_params f_args ~f:(fun p a -> Code.Var.propagate_name p a);
52+
used := true;
53+
Addr.Map.add
54+
pc
55+
{ params = block.params
56+
; body = remove_last block.body
57+
; branch = Branch (f_pc, f_args)
58+
}
59+
blocks)
60+
else blocks
61+
| None -> blocks)
5862
| _ -> blocks
5963

6064
let rec traverse f pc visited blocks =
6165
if not (Addr.Set.mem pc visited)
6266
then
6367
let visited = Addr.Set.add pc visited in
64-
let blocks = rewrite_block f pc blocks in
65-
let visited, blocks =
66-
Code.fold_children_skip_try_body
67-
blocks
68-
pc
69-
(fun pc (visited, blocks) ->
70-
let visited, blocks = traverse f pc visited blocks in
71-
visited, blocks)
72-
(visited, blocks)
73-
in
74-
visited, blocks
68+
let blocks' = rewrite_block f pc blocks in
69+
if not (phys_equal blocks blocks')
70+
then (* the block was rewritten *)
71+
visited, blocks'
72+
else
73+
let blocks = blocks' in
74+
let visited, blocks =
75+
Code.fold_children_skip_try_body
76+
blocks
77+
pc
78+
(fun pc (visited, blocks) ->
79+
let visited, blocks = traverse f pc visited blocks in
80+
visited, blocks)
81+
(visited, blocks)
82+
in
83+
visited, blocks
7584
else visited, blocks
7685

7786
let f p =
87+
let free_pc = ref p.free_pc in
88+
let blocks = ref p.blocks in
89+
let stats = ref 0 in
7890
let t = Timer.make () in
79-
let blocks =
80-
fold_closures
81-
p
82-
(fun f params (pc, args) blocks ->
83-
match f with
84-
| Some f when List.length params = List.length args ->
85-
let _, blocks = traverse (f, params, pc, args) pc Addr.Set.empty blocks in
86-
blocks
87-
| _ -> blocks)
88-
p.blocks
89-
in
90-
if times () then Format.eprintf " tail calls: %a@." Timer.print t;
91-
{ p with blocks }
91+
Addr.Map.iter
92+
(fun pc _ ->
93+
let block = Addr.Map.find pc !blocks in
94+
let rewrite_body = ref false in
95+
let body =
96+
List.map block.body ~f:(function
97+
| Let (f, Closure (params, (pc_head, args))) as i ->
98+
if List.equal ~eq:Code.Var.equal params args
99+
then (
100+
blocks :=
101+
snd
102+
(traverse
103+
(f, params, pc_head, ref false, stats)
104+
pc_head
105+
Addr.Set.empty
106+
!blocks);
107+
i)
108+
else
109+
let intermediate_pc = !free_pc in
110+
let need_to_create_intermediate_block = ref false in
111+
blocks :=
112+
snd
113+
(traverse
114+
( f
115+
, params
116+
, intermediate_pc
117+
, need_to_create_intermediate_block
118+
, stats )
119+
pc_head
120+
Addr.Set.empty
121+
!blocks);
122+
if !need_to_create_intermediate_block
123+
then (
124+
incr free_pc;
125+
let new_params = List.map params ~f:Code.Var.fork in
126+
blocks :=
127+
Addr.Map.add
128+
intermediate_pc
129+
{ params; body = []; branch = Branch (pc_head, args) }
130+
!blocks;
131+
rewrite_body := true;
132+
Let (f, Closure (new_params, (intermediate_pc, new_params))))
133+
else i
134+
| i -> i)
135+
in
136+
if !rewrite_body then blocks := Addr.Map.add pc { block with body } !blocks)
137+
p.blocks;
138+
if times () then Format.eprintf " tail calls: %a #%d@." Timer.print t !stats;
139+
{ p with blocks = !blocks; free_pc = !free_pc }

compiler/tests-compiler/double-translation/effects_continuations.ml

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -101,42 +101,44 @@ let%expect_test "test-compiler/lib-effects/test1.ml" =
101101
[%expect
102102
{|
103103
function exceptions$0(s){
104-
try{var _K_ = caml_int_of_string(s), n = _K_;}
105-
catch(_N_){
106-
var _F_ = caml_wrap_exception(_N_);
107-
if(_F_[1] !== Stdlib[7]) throw caml_maybe_attach_backtrace(_F_, 0);
104+
try{var _C_ = caml_int_of_string(s), n = _C_;}
105+
catch(exn$0){
106+
var exn = caml_wrap_exception(exn$0);
107+
if(exn[1] !== Stdlib[7]) throw caml_maybe_attach_backtrace(exn, 0);
108108
var n = 0;
109109
}
110110
try{
111111
if(caml_string_equal(s, cst$0))
112112
throw caml_maybe_attach_backtrace(Stdlib[8], 1);
113-
var _J_ = 7, m = _J_;
113+
var _B_ = 7, m = _B_;
114114
}
115-
catch(_M_){
116-
var _G_ = caml_wrap_exception(_M_);
117-
if(_G_ !== Stdlib[8]) throw caml_maybe_attach_backtrace(_G_, 0);
115+
catch(exn){
116+
var exn$0 = caml_wrap_exception(exn);
117+
if(exn$0 !== Stdlib[8]) throw caml_maybe_attach_backtrace(exn$0, 0);
118118
var m = 0;
119119
}
120120
try{
121121
if(caml_string_equal(s, cst))
122122
throw caml_maybe_attach_backtrace(Stdlib[8], 1);
123-
var _I_ = [0, [0, caml_call1(Stdlib[79], cst_toto), n, m]];
124-
return _I_;
123+
var _A_ = [0, [0, caml_call1(Stdlib[79], cst_toto), n, m]];
124+
return _A_;
125125
}
126-
catch(_L_){
127-
var _H_ = caml_wrap_exception(_L_);
128-
if(_H_ === Stdlib[8]) return 0;
129-
throw caml_maybe_attach_backtrace(_H_, 0);
126+
catch(exn){
127+
var exn$1 = caml_wrap_exception(exn);
128+
if(exn$1 === Stdlib[8]) return 0;
129+
throw caml_maybe_attach_backtrace(exn$1, 0);
130130
}
131131
}
132132
//end
133133
function exceptions$1(s, cont){
134-
try{var _z_ = caml_int_of_string(s), n = _z_;}
135-
catch(_E_){
136-
var _A_ = caml_wrap_exception(_E_);
137-
if(_A_[1] !== Stdlib[7]){
138-
var raise$1 = caml_pop_trap();
139-
return raise$1(caml_maybe_attach_backtrace(_A_, 0));
134+
try{var _y_ = caml_int_of_string(s), n = _y_;}
135+
catch(exn){
136+
var exn$2 = caml_wrap_exception(exn);
137+
if(exn$2[1] !== Stdlib[7]){
138+
var
139+
raise$1 = caml_pop_trap(),
140+
exn$0 = caml_maybe_attach_backtrace(exn$2, 0);
141+
return raise$1(exn$0);
140142
}
141143
var n = 0;
142144
}
@@ -145,25 +147,25 @@ let%expect_test "test-compiler/lib-effects/test1.ml" =
145147
throw caml_maybe_attach_backtrace(Stdlib[8], 1);
146148
var _x_ = 7, m = _x_;
147149
}
148-
catch(_D_){
149-
var _y_ = caml_wrap_exception(_D_);
150-
if(_y_ !== Stdlib[8]){
151-
var raise$0 = caml_pop_trap();
152-
return raise$0(caml_maybe_attach_backtrace(_y_, 0));
150+
catch(exn$0){
151+
var exn$1 = caml_wrap_exception(exn$0);
152+
if(exn$1 !== Stdlib[8]){
153+
var raise$0 = caml_pop_trap(), exn = caml_maybe_attach_backtrace(exn$1, 0);
154+
return raise$0(exn);
153155
}
154156
var m = 0;
155157
}
156158
runtime.caml_push_trap
157-
(function(_C_){
158-
if(_C_ === Stdlib[8]) return cont(0);
159-
var raise = caml_pop_trap();
160-
return raise(caml_maybe_attach_backtrace(_C_, 0));
159+
(function(exn$0){
160+
if(exn$0 === Stdlib[8]) return cont(0);
161+
var raise = caml_pop_trap(), exn = caml_maybe_attach_backtrace(exn$0, 0);
162+
return raise(exn);
161163
});
162164
if(! caml_string_equal(s, cst))
163165
return caml_trampoline_cps_call2
164166
(Stdlib[79],
165167
cst_toto,
166-
function(_B_){caml_pop_trap(); return cont([0, [0, _B_, n, m]]);});
168+
function(_z_){caml_pop_trap(); return cont([0, [0, _z_, n, m]]);});
167169
var _w_ = Stdlib[8], raise = caml_pop_trap();
168170
return raise(caml_maybe_attach_backtrace(_w_, 1));
169171
}

0 commit comments

Comments
 (0)