forked from Ekdohibs/PolyGen
-
Notifications
You must be signed in to change notification settings - Fork 0
/
PolyLang.v
601 lines (545 loc) · 28.9 KB
/
PolyLang.v
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
(* *****************************************************************)
(* *)
(* Verified polyhedral AST generation *)
(* *)
(* Nathanaël Courant, Inria Paris *)
(* *)
(* Copyright Inria. All rights reserved. This file is distributed *)
(* under the terms of the GNU Lesser General Public License as *)
(* published by the Free Software Foundation, either version 2.1 *)
(* of the License, or (at your option) any later version. *)
(* *)
(* *****************************************************************)
Require Import ZArith.
Require Import List.
Require Import Bool.
Require Import Psatz.
Require Import Setoid Morphisms.
Require Import Misc.
Require Import Linalg.
Require Import Semantics.
Require Import Instr.
Open Scope Z_scope.
Open Scope list_scope.
(** * The semantics of polyhedral programs with schedules *)
(** A polyhedral instruction: the instruction [pi_instr] is run at integer points
    [p] of its iteration domain [pi_poly]. Through [affine_product], each point is
    mapped by [pi_schedule] to a timestamp that orders the executions (see
    [poly_semantics]), and by [pi_transformation] to the vector on which
    [pi_instr] is run (see [instr_semantics] in [poly_semantics]). *)
Record Polyhedral_Instruction := {
pi_instr : instr ;
(** Iteration domain; [polyhedron] = list of [constraint] = list (list Z * Z). *)
pi_poly : polyhedron ;
(** Affine rows (coefficients, constant), one per timestamp dimension. *)
pi_schedule : list (list Z * Z)%type ;
(** Affine rows (coefficients, constant) producing the vector given to [pi_instr]. *)
pi_transformation : list (list Z * Z)%type ;
}.
(** A polyhedral program. Instructions are referred to by their index in the list
    (looked up with [nth_error]). *)
Definition Poly_Program := list Polyhedral_Instruction.
(** [scanned to_scan n p] is [to_scan] with point [p] of instruction [n] marked as
    executed. [scanned to_scan n p m q] is true iff [to_scan m q] is true and it is
    not the case that both [q] is [is_eq] to [p] and [m = n]. *)
Definition scanned to_scan n p m q := to_scan m q && negb (is_eq p q && (n =? m)%nat).
Hint Unfold scanned.
(** Reading notes.
    - Applying [scanned] repeatedly on top of a base predicate [to_scan] yields a
      closure that remembers every pair (n, p) removed so far. Evaluating it at
      (m, q) first checks whether (m, q) is one of those removed pairs. Otherwise it
      defers to the base predicate, e.g. [env_scan], which requires [q] to lie in
      the domain of instruction [m] and to agree with the environment.
    - Each point p/q is one instruction point. The scheduled semantics tolerates
      ties: distinct instruction points may be mapped to the same schedule value.
      Their relative order is then unconstrained.
    - Open question: can the validator be proved directly against this semantics?
      Or is it better to first prove an equivalence and then build on the s2sloop
      work? TODO: revisit the validation algorithm.
*)
(**
scanned
: (nat -> list Z -> bool) -> nat -> list Z -> nat -> list Z -> bool
Meaning of the arguments, as read from the definitions below:
- [to_scan n p = true] means that point [p] of instruction [n] still has to be
  executed. [PolyDone] requires [to_scan] to be false everywhere. [PolyProgress]
  executes a point where it is true, then continues with [scanned to_scan n p].
  So [scanned] removes exactly one executed point per step.
- [p] is the iteration vector of instruction [n]. It is passed to [affine_product]
  together with the schedule and the transformation. [env_scan] requires it to
  satisfy [pi_poly] and to start with [env].
- [n] is the index of the instruction in the program ([nth_error prog n]).
  [n] and [p] are independent arguments.
*)
(** [scanned] is compatible with these relations on its arguments:
    - on the predicate: agreement on equal indices and [veq]-equal points;
    - on instruction indices: [eq];
    - on points: [veq]. *)
Instance scanned_proper : Proper ((eq ==> veq ==> eq) ==> eq ==> veq ==> (eq ==> veq ==> eq)) scanned.
Proof.
intros to_scan1 to_scan2 Hto_scan n1 n2 Hn p1 p2 Hp m1 m2 Hm q1 q2 Hq.
unfold scanned.
erewrite Hto_scan by (exact Hm || exact Hq).
rewrite Hn. rewrite Hm. rewrite Hp. rewrite Hq.
reflexivity.
Qed.
(** [wf_scan f]: the scanning predicate [f] returns the same result on equal
    instruction indices and on [veq]-equal points. *)
Notation "'wf_scan'" := (Proper (eq ==> veq ==> eq)) (only parsing).
(** [poly_semantics to_scan prog mem1 mem2]: executing [prog] on the instruction
    points selected by [to_scan], in schedule order, can take memory [mem1] to [mem2].
    - [PolyDone]: no point is left to execute, and memory is unchanged.
    - [PolyProgress]: pick a pending point [p] of instruction [n] such that no
      pending point of any instruction has a lexicographically smaller timestamp.
      Run the instruction on [affine_product pi_transformation p], then continue
      with [p] removed ([scanned]).
    Only strictly earlier timestamps ([Lt]) constrain the choice, so points with
    equal timestamps may run in any order. *)
Inductive poly_semantics : (nat -> list Z -> bool) -> Poly_Program -> mem -> mem -> Prop :=
| PolyDone : forall to_scan prog mem, (forall n p, to_scan n p = false) -> poly_semantics to_scan prog mem mem
| PolyProgress : forall to_scan prog mem1 mem2 mem3 poly_instr n p,
to_scan n p = true -> nth_error prog n = Some poly_instr ->
(forall n2 p2 poly_instr2, nth_error prog n2 = Some poly_instr2 ->
(* No other pending instruction point is scheduled strictly before the one executed now. *)
lex_compare (affine_product poly_instr2.(pi_schedule) p2) (affine_product poly_instr.(pi_schedule) p) = Lt ->
to_scan n2 p2 = false) ->
instr_semantics poly_instr.(pi_instr) (affine_product poly_instr.(pi_transformation) p) mem1 mem2 ->
poly_semantics (scanned to_scan n p) prog mem2 mem3 ->
poly_semantics to_scan prog mem1 mem3.
(** [poly_semantics] depends on the scanning predicate only pointwise. It may be
    replaced by any predicate returning the same booleans everywhere. *)
Theorem poly_semantics_extensionality :
forall to_scan1 prog mem1 mem2,
poly_semantics to_scan1 prog mem1 mem2 -> forall to_scan2, (forall n p, to_scan1 n p = to_scan2 n p) -> poly_semantics to_scan2 prog mem1 mem2.
Proof.
intros to_scan1 prog mem1 mem2 Hsem.
induction Hsem as [to_scan3 prog1 mem4 Hdone|to_scan3 prog1 mem3 mem4 mem5 pi n p Hscanp Heqpi Hts Hsem1 Hsem2 IH].
- intros. constructor. intros. eauto.
- intros to_scan2 Heq. econstructor; eauto. apply IH. intros. autounfold. rewrite Heq. auto.
Qed.
(** Removing a point with [scanned] preserves [wf_scan]. *)
Lemma scanned_wf_compat :
forall to_scan n p, wf_scan to_scan -> wf_scan (scanned to_scan n p).
Proof.
intros to_scan n p Hwf. apply scanned_proper; [exact Hwf | reflexivity | reflexivity].
Qed.
(** Sequential composition for the scheduled semantics. Suppose [to_scan1] runs
    from [mem1] to [mem2] and [to_scan2] runs from [mem2] to [mem3]. Then the union
    of the two predicates runs from [mem1] to [mem3], provided that:
    - [to_scan1] is well-formed;
    - the two predicates are disjoint;
    - no point of [to_scan2] has a timestamp strictly smaller than a point of
      [to_scan1]. *)
Theorem poly_semantics_concat :
forall to_scan1 prog mem1 mem2,
poly_semantics to_scan1 prog mem1 mem2 ->
forall to_scan2 mem3,
wf_scan to_scan1 ->
(forall n p, to_scan1 n p = false \/ to_scan2 n p = false) ->
(forall n1 p1 pi1 n2 p2 pi2, nth_error prog n1 = Some pi1 -> nth_error prog n2 = Some pi2 ->
lex_compare (affine_product pi2.(pi_schedule) p2) (affine_product pi1.(pi_schedule) p1) = Lt ->
to_scan1 n1 p1 = false \/ to_scan2 n2 p2 = false) ->
poly_semantics to_scan2 prog mem2 mem3 ->
poly_semantics (fun n p => to_scan1 n p || to_scan2 n p) prog mem1 mem3.
Proof.
intros to_scan1 prog mem1 mem2 Hsem.
induction Hsem as [to_scan3 prog1 mem4 Hdone|to_scan3 prog1 mem4 mem5 mem6 pi n p Hscanp Heqpi Hts Hsem1 Hsem2 IH].
(* Base case: [to_scan1] selects nothing, so the union is pointwise [to_scan2]. *)
- intros to_scan2 mem3 Hwf1 Hdisj Hcmp Hsem1. eapply poly_semantics_extensionality; eauto. intros. rewrite Hdone. auto.
(* Step: execute the same point first. It is still minimal in the union, by
   [Hts] for the [to_scan3] part and [Hcmp] for the [to_scan2] part. *)
- intros to_scan2 mem3 Hwf1 Hdisj Hcmp Hsem3. eapply PolyProgress with (n := n) (p := p); eauto.
+ rewrite Hscanp. auto.
+ intros n2 p2 pi2 Heqpi2 Hts2.
reflect; split.
* apply (Hts n2 p2 pi2); auto.
* destruct (Hcmp n p pi n2 p2 pi2) as [H | H]; auto; congruence.
+ eapply poly_semantics_extensionality; [apply IH|]; eauto.
* apply scanned_wf_compat; auto.
* intros n2 p2. autounfold. destruct (Hdisj n2 p2) as [H | H]; rewrite H; auto.
* intros n1 p1 pi1 n2 p2 pi2 Heqpi1 Heqpi2 Hcmp1.
destruct (Hcmp n1 p1 pi1 n2 p2 pi2) as [H | H]; autounfold; auto. rewrite H. simpl.
destruct (is_eq p p1 && (n =? n1)%nat) eqn:Hd; simpl; auto.
* intros n0 p0. autounfold. simpl.
destruct (to_scan3 n0 p0) eqn:Hscan3; simpl; auto.
-- destruct (Hdisj n0 p0) as [H | H]; [congruence|rewrite H; auto using orb_false_r].
-- destruct (is_eq p p0 && (n =? n0)%nat) eqn:Hd; simpl; auto using andb_true_r.
reflect. destruct Hd as [Heqp Hn]. rewrite Heqp, Hn in Hscanp. congruence.
Qed.
(** * The semantics of polyhedral programs with lexicographical ordering *)
(** Same shape as [poly_semantics], but the order is given directly by the points.
    A pending point [p] may run only if no pending point of any instruction is
    lexicographically smaller than [p]. [pi_schedule] is not consulted. *)
Inductive poly_lex_semantics : (nat -> list Z -> bool) -> Poly_Program -> mem -> mem -> Prop :=
| PolyLexDone : forall to_scan prog mem, (forall n p, to_scan n p = false) -> poly_lex_semantics to_scan prog mem mem
| PolyLexProgress : forall to_scan prog mem1 mem2 mem3 poly_instr n p,
to_scan n p = true -> nth_error prog n = Some poly_instr ->
(forall n2 p2, lex_compare p2 p = Lt -> to_scan n2 p2 = false) ->
instr_semantics poly_instr.(pi_instr) (affine_product poly_instr.(pi_transformation) p) mem1 mem2 ->
poly_lex_semantics (scanned to_scan n p) prog mem2 mem3 ->
poly_lex_semantics to_scan prog mem1 mem3.
(** Scanning predicates that are pointwise equal give the same [poly_lex_semantics]. *)
Theorem poly_lex_semantics_extensionality :
forall to_scan1 prog mem1 mem2,
poly_lex_semantics to_scan1 prog mem1 mem2 -> forall to_scan2, (forall n p, to_scan1 n p = to_scan2 n p) -> poly_lex_semantics to_scan2 prog mem1 mem2.
Proof.
intros to_scan1 prog mem1 mem2 Hsem.
induction Hsem as [to_scan3 prog1 mem4 Hdone|to_scan3 prog1 mem3 mem4 mem5 pi n p Hscanp Heqpi Hts Hsem1 Hsem2 IH].
- intros. constructor. intros. eauto.
- intros to_scan2 Heq. econstructor; eauto. apply IH. intros. autounfold. rewrite Heq. auto.
Qed.
(** [poly_lex_semantics] reads only [pi_instr] and [pi_transformation] of each
    instruction. An execution of [pis1] is therefore also an execution of [pis2]
    whenever the two programs agree on these two fields, position by position. *)
Lemma poly_lex_semantics_pis_ext_single :
forall pis1 pis2 to_scan mem1 mem2,
Forall2 (fun pi1 pi2 => pi1.(pi_instr) = pi2.(pi_instr) /\ pi1.(pi_transformation) = pi2.(pi_transformation)) pis1 pis2 ->
poly_lex_semantics to_scan pis1 mem1 mem2 -> poly_lex_semantics to_scan pis2 mem1 mem2.
Proof.
intros pis1 pis2 to_scan mem1 mem2 Hsame Hsem.
induction Hsem as [to_scan1 prog mem Hdone|to_scan1 prog mem1 mem2 mem3 pi n p Hscanp Heqpi Hts Hsem1 Hsem2 IH].
- apply PolyLexDone; auto.
- destruct (Forall2_nth_error _ _ _ _ _ _ _ Hsame Heqpi) as [pi2 [Hpi2 [H1 H2]]].
eapply PolyLexProgress; [exact Hscanp|exact Hpi2|exact Hts| |apply IH; auto].
rewrite H1, H2 in *; auto.
Qed.
(** Both directions of [poly_lex_semantics_pis_ext_single]. The reverse direction
    uses the symmetric [Forall2] relation. *)
Lemma poly_lex_semantics_pis_ext_iff :
forall pis1 pis2 to_scan mem1 mem2,
Forall2 (fun pi1 pi2 => pi1.(pi_instr) = pi2.(pi_instr) /\ pi1.(pi_transformation) = pi2.(pi_transformation)) pis1 pis2 ->
poly_lex_semantics to_scan pis1 mem1 mem2 <-> poly_lex_semantics to_scan pis2 mem1 mem2.
Proof.
intros pis1 pis2 to_scan mem1 mem2 Hsame.
split.
- apply poly_lex_semantics_pis_ext_single; auto.
- apply poly_lex_semantics_pis_ext_single.
eapply Forall2_imp; [|apply Forall2_sym; exact Hsame].
intros x y H; simpl in *; destruct H; auto.
Qed.
(** Equivalence form of [poly_lex_semantics_extensionality]. *)
Lemma poly_lex_semantics_ext_iff :
forall pis to_scan1 to_scan2 mem1 mem2,
(forall n p, to_scan1 n p = to_scan2 n p) ->
poly_lex_semantics to_scan1 pis mem1 mem2 <-> poly_lex_semantics to_scan2 pis mem1 mem2.
Proof.
intros pis to_scan1 to_scan2 mem1 mem2 Hsame.
split; intros H.
- eapply poly_lex_semantics_extensionality; [exact H|]. auto.
- eapply poly_lex_semantics_extensionality; [exact H|]. auto.
Qed.
(** Sequential composition for the lexicographic semantics. Suppose [to_scan1] runs
    from [mem1] to [mem2] and [to_scan2] runs from [mem2] to [mem3]. Then their
    union runs from [mem1] to [mem3], provided that:
    - [to_scan1] is well-formed;
    - the two predicates are disjoint;
    - no point of [to_scan2] is lexicographically smaller than a point of
      [to_scan1]. *)
Theorem poly_lex_concat :
forall to_scan1 prog mem1 mem2,
poly_lex_semantics to_scan1 prog mem1 mem2 ->
forall to_scan2 mem3,
wf_scan to_scan1 ->
(forall n p, to_scan1 n p = false \/ to_scan2 n p = false) ->
(forall n1 p1 n2 p2, lex_compare p2 p1 = Lt -> to_scan1 n1 p1 = false \/ to_scan2 n2 p2 = false) ->
poly_lex_semantics to_scan2 prog mem2 mem3 ->
poly_lex_semantics (fun n p => to_scan1 n p || to_scan2 n p) prog mem1 mem3.
Proof.
intros to_scan1 prog mem1 mem2 Hsem.
induction Hsem as [to_scan3 prog1 mem4 Hdone|to_scan3 prog1 mem4 mem5 mem6 pi n p Hscanp Heqpi Hts Hsem1 Hsem2 IH].
(* Base case: [to_scan1] selects nothing, so the union is pointwise [to_scan2]. *)
- intros to_scan2 mem3 Hwf1 Hdisj Hcmp Hsem1. eapply poly_lex_semantics_extensionality; eauto. intros. rewrite Hdone. auto.
(* Step: execute the same point first. It is still minimal in the union, by
   [Hts] for the [to_scan3] part and [Hcmp] for the [to_scan2] part. *)
- intros to_scan2 mem3 Hwf1 Hdisj Hcmp Hsem3. eapply PolyLexProgress with (n := n) (p := p); eauto.
+ rewrite Hscanp. auto.
+ intros n2 p2 Hts2.
reflect. split.
* apply (Hts n2 p2); auto.
* destruct (Hcmp n p n2 p2) as [H | H]; auto; congruence.
+ eapply poly_lex_semantics_extensionality; [apply IH|]; eauto.
* apply scanned_wf_compat; auto.
* intros n2 p2. autounfold. destruct (Hdisj n2 p2) as [H | H]; rewrite H; auto.
* intros n1 p1 n2 p2 Hcmp1.
destruct (Hcmp n1 p1 n2 p2) as [H | H]; autounfold; auto. rewrite H. simpl.
destruct (is_eq p p1 && (n =? n1)%nat) eqn:Hd; simpl; auto.
* intros n0 p0. autounfold. simpl.
destruct (to_scan3 n0 p0) eqn:Hscan3; simpl; auto.
-- destruct (Hdisj n0 p0) as [H | H]; [congruence|rewrite H; auto using orb_false_r].
-- destruct (is_eq p p0 && (n =? n0)%nat) eqn:Hd; simpl; auto using andb_true_r.
reflect. destruct Hd as [Heqp Hn]. rewrite Heqp, Hn in Hscanp. congruence.
Qed.
(** Generalization of [poly_lex_concat] to a sequence of scanning predicates.
    Suppose [iter_semantics] runs the predicates [to_scans x] one after the other,
    for [x] in [l]. Then the whole run is a [poly_lex_semantics] of their union
    ([existsb]), provided that:
    - every predicate is well-formed;
    - a point (n, p) is selected by predicates at one position of [l] only;
    - a lexicographically smaller point never appears at a strictly later position
      of [l]. *)
Theorem poly_lex_concat_seq :
forall A to_scans (l : list A) prog mem1 mem2,
iter_semantics (fun x => poly_lex_semantics (to_scans x) prog) l mem1 mem2 ->
(forall x, wf_scan (to_scans x)) ->
(forall x1 k1 x2 k2 n p, to_scans x1 n p = true -> to_scans x2 n p = true -> nth_error l k1 = Some x1 -> nth_error l k2 = Some x2 -> k1 = k2) ->
(forall x1 n1 p1 k1 x2 n2 p2 k2, lex_compare p2 p1 = Lt -> to_scans x1 n1 p1 = true -> to_scans x2 n2 p2 = true -> nth_error l k1 = Some x1 -> nth_error l k2 = Some x2 -> (k2 <= k1)%nat) ->
poly_lex_semantics (fun n p => existsb (fun x => to_scans x n p) l) prog mem1 mem2.
Proof.
intros A to_scans l1 prog mem1 mem3 Hsem.
induction Hsem as [mem|x l mem1 mem2 mem3 Hsem1 Hsem2 IH].
(* Base case: empty sequence, so the union selects nothing. *)
- intros Hwf Hscans Hcmp.
simpl.
apply PolyLexDone; auto.
(* Step: run [to_scans x] (position 0) first, then the tail via [IH], glued with
   [poly_lex_concat]. Tail positions are shifted by one ([S u], [S k]). *)
- intros Hwf Hscans Hcmp.
eapply poly_lex_semantics_extensionality.
+ eapply poly_lex_concat; [exact Hsem1| | | |apply IH; auto].
* apply Hwf.
* intros n p. simpl.
destruct (to_scans x n p) eqn:Hscanl; [|auto]. right.
apply not_true_is_false; rewrite existsb_exists; intros [x1 [Hin Hscanx1]].
apply In_nth_error in Hin; destruct Hin as [u Hu].
specialize (Hscans x O x1 (S u) n p Hscanl Hscanx1).
simpl in Hscans. intuition congruence.
* intros n1 p1 n2 p2 H.
destruct (to_scans x n1 p1) eqn:Hscanl; [|auto]. right.
apply not_true_is_false; rewrite existsb_exists; intros [x1 [Hin Hscanx1]].
apply In_nth_error in Hin; destruct Hin as [u Hu].
specialize (Hcmp x n1 p1 O x1 n2 p2 (S u) H Hscanl Hscanx1).
intuition lia.
* intros x1 k1 x2 k2 n p H1 H2 H3 H4; specialize (Hscans x1 (S k1) x2 (S k2) n p).
intuition congruence.
* intros x1 n1 p1 k1 x2 n2 p2 k2 H1 H2 H3 H4 H5; specialize (Hcmp x1 n1 p1 (S k1) x2 n2 p2 (S k2)).
intuition lia.
+ intros n p. simpl. reflexivity.
Qed.
(** * Translating a program from explicit scheduling to lexicographical scanning *)
(** [insert_zeros d i l] inserts [d] zeros at position [i] of [l]. The result is
    the first [i] entries of [l] (as a list of length exactly [i], via [resize]),
    then [d] zeros, then the entries of [l] after index [i]. *)
Definition insert_zeros (d : nat) (i : nat) (l : list Z) := resize i l ++ repeat 0 d ++ skipn i l.
(** Applies [insert_zeros d i] to the coefficient vector of a constraint, keeping its constant. *)
Definition insert_zeros_constraint (d : nat) (i : nat) (c : list Z * Z) := (insert_zeros d i (fst c), snd c).
(** [make_null_poly d n] creates a polyhedron with the constraints that the variables
    from [d] to [d+n-1] are null. Each variable gets two constraints, with
    coefficients -1 and 1 and constant 0. Together they force the variable to zero
    (see [make_null_poly_correct]). *)
Fixpoint make_null_poly (d : nat) (n : nat) :=
match n with
| O => nil
| S n => (repeat 0 d ++ (-1 :: nil), 0) :: (repeat 0 d ++ (1 :: nil), 0) :: make_null_poly (S d) n
end.
(** [make_sched_poly d i env_size l] adds the lexicographical constraints in [l] as
    equalities. It keeps the first [env_size] variables and inserts [d] variables
    after them.
    - Row number k of [l], say (v, c), yields two opposite constraints. They state
      that variable [env_size + i + k] equals v applied to the point without the
      inserted variables, plus c.
    - Once [l] is exhausted, the remaining inserted variables are forced to zero
      with [make_null_poly].
    See [make_sched_poly_correct_aux] for the precise statement. Callers keep
    [length l <= d - i], so the [nat] subtractions below do not truncate. *)
Fixpoint make_sched_poly (d : nat) (i : nat) (env_size : nat) (l : list (list Z * Z)) :=
(* add scheduling constraints in polyhedron after env, so that with fixed env, lexicographical ordering preserves semantics *)
match l with
| nil => make_null_poly (i + env_size)%nat (d - i)%nat
| (v, c) :: l =>
let vpref := resize env_size v in
let vsuf := skipn env_size v in
(vpref ++ repeat 0 i ++ (-1 :: repeat 0 (d - i - 1)%nat) ++ vsuf, -c)
:: (mult_vector (-1) vpref ++ repeat 0 i ++ (1 :: repeat 0 (d - i - 1)%nat) ++ (mult_vector (-1) vsuf), c)
:: make_sched_poly d (S i) env_size l
end.
(** Correctness of [make_null_poly]. With [length p = d] and [length q = n], the
    point [p ++ q ++ r] satisfies [make_null_poly d n] iff [q] is all zeros. *)
Theorem make_null_poly_correct :
forall n d p q r, length p = d -> length q = n -> in_poly (p ++ q ++ r) (make_null_poly d n) = is_null q.
Proof.
induction n.
- intros; destruct q; simpl in *; auto; lia.
- intros d p q r Hlp Hlq.
destruct q as [|x q]; simpl in *; [lia|].
unfold satisfies_constraint; simpl.
repeat (rewrite dot_product_app; [|rewrite repeat_length; lia]; simpl).
autorewrite with vector.
assert (He : p ++ x :: q ++ r = (p ++ (x :: nil)) ++ q ++ r).
{ rewrite <- app_assoc; auto. }
rewrite He. rewrite IHn; [|rewrite app_length; simpl; lia|lia].
rewrite andb_assoc. f_equal.
destruct (x =? 0) eqn:Hx; reflect; lia.
Qed.
(** Correctness of [make_sched_poly], general form. Assume [length p = es],
    [length q = i] and [length r = d - i]. Then [p ++ q ++ r ++ s] satisfies
    [make_sched_poly d i es l] iff [r] is [is_eq] to the image of [p ++ s] by the
    affine rows [l]. The entries of [q] are unconstrained. *)
Theorem make_sched_poly_correct_aux :
forall l i d es, (length l <= d - i)%nat ->
forall p q r s, length p = es -> length q = i -> length r = (d - i)%nat ->
in_poly (p ++ q ++ r ++ s) (make_sched_poly d i es l) = is_eq r (affine_product l (p ++ s)).
Proof.
induction l.
- intros. simpl in *. rewrite is_eq_nil_right. rewrite app_assoc. apply make_null_poly_correct; auto. rewrite app_length; lia.
- intros i d es Hlength p q r s Hlp Hlq Hlr.
simpl in *. destruct a as [v c]. simpl in *.
destruct r as [|x r]; simpl in *; [lia|].
unfold satisfies_constraint; simpl.
repeat (rewrite dot_product_app; [|rewrite ?repeat_length, ?mult_vector_length, ?resize_length; lia]; simpl).
autorewrite with vector.
assert (He : p ++ q ++ x :: r ++ s = p ++ (q ++ (x :: nil)) ++ r ++ s).
{ rewrite <- app_assoc. auto. }
rewrite He. rewrite IHl; [|lia|auto|rewrite app_length;simpl;lia|lia].
rewrite andb_assoc. f_equal.
assert (Hde : dot_product v (p ++ s) = dot_product p (resize es v) + dot_product s (skipn es v)).
{ rewrite <- dot_product_app by (rewrite resize_length; lia).
rewrite dot_product_commutative. rewrite resize_skipn_eq. reflexivity.
}
destruct (x =? dot_product v (p ++ s) + c) eqn:Hx; reflect; lia.
Qed.
(** Correctness of [make_sched_poly] starting at [i = 0]. With [length p = es] and
    [length q = d], [p ++ q ++ r] satisfies the polyhedron iff [q] is the timestamp
    of [p ++ r] under [l]. *)
Theorem make_sched_poly_correct :
forall l d es, (length l <= d)%nat ->
forall p q r, length p = es -> length q = d ->
in_poly (p ++ q ++ r) (make_sched_poly d 0%nat es l) = is_eq q (affine_product l (p ++ r)).
Proof.
intros. apply make_sched_poly_correct_aux with (q := nil); auto; lia.
Qed.
(** Size bound: [poly_nrl (make_null_poly d n)] is at most [d + n]. Presumably
    this means the constraints only involve the first [d + n] variables; this
    depends on the definition of [poly_nrl], which is not shown here. *)
Theorem make_null_poly_nrl :
forall n d, (poly_nrl (make_null_poly d n) <= d + n)%nat.
Proof.
induction n.
- intros; unfold poly_nrl; simpl; lia.
- intros d. simpl. unfold poly_nrl; simpl.
rewrite !Nat.max_lub_iff.
split; [|split; [|specialize (IHn (S d)); unfold poly_nrl in *; lia]];
rewrite <- nrlength_def, resize_app_le, repeat_length by (rewrite repeat_length; lia);
replace (d + S n - d)%nat with (S n) by lia; simpl;
f_equiv; f_equiv; rewrite resize_eq; simpl; (reflexivity || lia).
Qed.
(** Size bound for [make_sched_poly], general form. If [length l + i <= d], then
    [poly_nrl] of the result is at most [d + max es (poly_nrl l)]. *)
Theorem make_sched_poly_nrl_aux :
forall l i d es, (length l + i <= d)%nat -> (poly_nrl (make_sched_poly d i es l) <= d + (Nat.max es (poly_nrl l)))%nat.
Proof.
induction l.
- simpl; intros i d es H.
generalize (make_null_poly_nrl (d - i)%nat (i + es)%nat). lia.
- intros i d es H; simpl in *. destruct a as [a c]. unfold poly_nrl in *; simpl in *.
rewrite !Nat.max_lub_iff. split; [|split; [|rewrite IHl; lia]].
all: rewrite nrlength_app; transitivity (es + (i + S ((d - i - 1) + (nrlength a - es))))%nat; [|lia].
all: rewrite ?mult_vector_length, resize_length; apply Nat.add_le_mono_l.
all: rewrite nrlength_app, repeat_length; apply Nat.add_le_mono_l.
all: rewrite nrlength_cons; apply -> Nat.succ_le_mono.
all: rewrite nrlength_app, repeat_length; apply Nat.add_le_mono_l.
all: rewrite ?nrlength_mult, nrlength_skipn; lia.
Qed.
(** Size bound for [make_sched_poly] starting at [i = 0]. *)
Theorem make_sched_poly_nrl :
forall l d es, (length l <= d)%nat -> (poly_nrl (make_sched_poly d 0%nat es l) <= d + (Nat.max es (poly_nrl l)))%nat.
Proof.
intros; apply make_sched_poly_nrl_aux; lia.
Qed.
(** Inserting [d] zeros increases [nrlength] by at most [d]. *)
Lemma insert_zeros_nrl :
forall d es v, (nrlength (insert_zeros d es v) <= d + nrlength v)%nat.
Proof.
induction es.
- intros v; unfold insert_zeros; simpl. rewrite nrlength_app, repeat_length; lia.
- intros [|x v]; unfold insert_zeros in *; simpl.
+ case_if eq H; reflect; [lia|].
exfalso; apply H. apply nrlength_null_zero.
unfold is_null. rewrite !forallb_app; reflect.
split; [apply resize_nil_null|]. split; [apply repeat_zero_is_null|auto].
+ case_if eq H1; reflect; [lia|].
case_if eq H2; reflect.
* destruct H2 as [-> H2]; apply nrlength_zero_null in H2. destruct H1 as [H1 | H1]; [lia|].
exfalso; apply H1. apply nrlength_null_zero.
rewrite resize_null_repeat by auto.
unfold is_null; rewrite !forallb_app; reflect.
split; [apply repeat_zero_is_null|]. split; [apply repeat_zero_is_null|].
apply nrlength_zero_null; apply nrlength_null_zero in H2.
rewrite nrlength_skipn; lia.
* specialize (IHes v). lia.
Qed.
(** Makes the schedule of [pi] explicit in its points. [d] timestamp variables are
    inserted after the first [env_size] variables, and:
    - the domain gets [make_sched_poly] equalities tying them to [pi_schedule];
    - the original domain and transformation are padded with [insert_zeros_constraint]
      so that they ignore the new variables;
    - [pi_schedule] becomes [nil], since the order is now meant to come from
      lexicographic scanning of the extended points. *)
Definition pi_elim_schedule (d : nat) (env_size : nat) (pi : Polyhedral_Instruction) :=
{|
pi_instr := pi.(pi_instr) ;
pi_schedule := nil ;
pi_transformation := map (insert_zeros_constraint d env_size) pi.(pi_transformation) ;
pi_poly := make_sched_poly d 0%nat env_size pi.(pi_schedule) ++
map (insert_zeros_constraint d env_size) pi.(pi_poly) ;
|}.
(** Size bound on the domain produced by [pi_elim_schedule], for schedules with at
    most [d] rows. *)
Lemma pi_elim_schedule_nrl :
forall d es pi,
(length pi.(pi_schedule) <= d)%nat ->
(poly_nrl (pi_elim_schedule d es pi).(pi_poly) <= d + (Nat.max es (Nat.max (poly_nrl pi.(pi_poly)) (poly_nrl pi.(pi_schedule)))))%nat.
Proof.
intros d es pi H. simpl.
rewrite poly_nrl_app. rewrite Nat.max_lub_iff; split.
- rewrite make_sched_poly_nrl; lia.
- unfold poly_nrl, insert_zeros_constraint in *. rewrite map_map. apply list_le_max; intros u Hu.
rewrite in_map_iff in Hu. destruct Hu as [c [Hu Hc]]; simpl in *.
transitivity (d + nrlength (fst c))%nat;
[|apply Nat.add_le_mono_l; rewrite !Nat.max_le_iff; right; left; apply list_max_ge; rewrite in_map_iff; exists c; auto].
rewrite <- Hu; apply insert_zeros_nrl.
Qed.
(** Applies [pi_elim_schedule d env_size] to every instruction of the program. *)
Definition elim_schedule (d : nat) (env_size : nat) (p : Poly_Program) := map (pi_elim_schedule d env_size) p.
(** Splitting a vector into its first [i] entries, the next [d] entries and the
    rest, then concatenating the pieces, gives back the same vector (up to [veq]). *)
Lemma split3_eq :
forall i d l, resize i l ++ resize d (skipn i l) ++ skipn (d + i)%nat l =v= l.
Proof.
intros.
rewrite <- is_eq_veq.
rewrite is_eq_app_left. autorewrite with vector. rewrite is_eq_reflexive. simpl.
rewrite is_eq_app_left. autorewrite with vector. rewrite is_eq_reflexive. simpl.
rewrite skipn_skipn. apply is_eq_reflexive.
Qed.
(** A dot product with [insert_zeros d i l1] ignores the [d] entries of [l2] that
    start at index [i]. It equals the dot product of [l1] with [l2] minus those
    entries. *)
Lemma insert_zeros_product_skipn :
forall d i l1 l2, dot_product (insert_zeros d i l1) l2 = dot_product l1 (resize i l2 ++ skipn (d + i)%nat l2).
Proof.
intros.
unfold insert_zeros.
rewrite !dot_product_app_left, dot_product_app_right.
autorewrite with vector. rewrite repeat_length.
rewrite skipn_skipn. lia.
Qed.
(** [insert_zeros_product_skipn] lifted to all rows of an affine map. *)
Lemma affine_product_skipn :
forall d i m l, affine_product (map (insert_zeros_constraint d i) m) l = affine_product m (resize i l ++ skipn (d + i)%nat l).
Proof.
intros. unfold affine_product. rewrite map_map.
apply map_ext. intros.
unfold insert_zeros_constraint; simpl.
rewrite insert_zeros_product_skipn. auto.
Qed.
(** * Schedule elimination is correct *)
(** Any lexicographic execution of [elim_schedule d es prog] is a scheduled
    execution of [prog]. The hypotheses relate the two scanning predicates:
    - both are well-formed;
    - every schedule has at most [d] rows;
    - [to_scan_lex] selects [p ++ ts ++ q] (with [length p = es], [length ts = d])
      exactly when [to_scan] selects [p ++ q] and [ts] is the timestamp of [p ++ q];
    - [to_scan] only selects points whose first [es] entries are [veq] to [env];
    - [to_scan] selects nothing for indices outside [prog]. *)
Theorem poly_elim_schedule_semantics_preserve :
forall d es env to_scan_lex prog_lex mem1 mem2,
poly_lex_semantics to_scan_lex prog_lex mem1 mem2 ->
forall to_scan prog,
prog_lex = elim_schedule d es prog ->
wf_scan to_scan -> wf_scan to_scan_lex ->
(forall n pi, nth_error prog n = Some pi -> (length pi.(pi_schedule) <= d)%nat) ->
(forall n p q ts pi, nth_error prog n = Some pi -> length p = es -> length ts = d ->
to_scan_lex n (p ++ ts ++ q) = is_eq ts (affine_product pi.(pi_schedule) (p ++ q)) && to_scan n (p ++ q)) ->
(forall n p q, length p = es -> to_scan n (p ++ q) = true -> p =v= env) ->
(forall n p, nth_error prog n = None -> to_scan n p = false) ->
poly_semantics to_scan prog mem1 mem2.
Proof.
intros d es env to_scan_lex prog mem1 mem2 Hsem.
induction Hsem as [to_scan_l1 prog_l1 mem3 Hdone|to_scan_l1 prog_l1 mem3 mem4 mem5 pi n p Hscanp Heqpi Hts Hsem1 Hsem2 IH].
(* Base case: no extended point is pending. Instantiating [Hcompat] with the
   timestamp of [p] shows that no original point is pending either. *)
- intros to_scan prog Hprogeq Hwf Hwflex Hsched_length Hcompat Hscanenv Hout.
apply PolyDone. intros n p.
destruct (nth_error prog n) as [pi|] eqn:Heq.
+ specialize (Hcompat n (resize es p) (skipn es p) (resize d (affine_product pi.(pi_schedule) p)) pi).
rewrite Hdone in Hcompat.
rewrite resize_skipn_eq in Hcompat. rewrite resize_eq in Hcompat.
* simpl in Hcompat. symmetry; apply Hcompat; auto.
* unfold affine_product. rewrite map_length. eauto.
+ auto.
(* Step: split the executed extended point [p] into its first [es] entries, a
   [d]-entry timestamp and [skipn (d + es) p]. The original program then executes
   the point [resize es p ++ skipn (d + es) p]. *)
- intros to_scan prog Hprogeq Hwf Hwflex Hsched_length Hcompat Hscanenv Hout.
rewrite <- split3_eq with (d := d) (i := es) in Hscanp.
rewrite Hprogeq in *; unfold elim_schedule in Heqpi.
destruct (nth_error prog n) as [pi1|] eqn:Hpi1; [| rewrite map_nth_error_none in Heqpi; congruence ].
erewrite map_nth_error in Heqpi; eauto; inversion Heqpi as [Heqpi1].
rewrite Hcompat with (pi := pi1) in Hscanp; auto.
reflect; destruct Hscanp as [Heqp Hscan].
eapply PolyProgress with (n := n) (p := resize es p ++ skipn (d + es)%nat p); eauto.
+ intros n2 p2 pi2 Heqpi2 Hcmp.
specialize (Hts n2 (resize es p2 ++ (resize d (affine_product pi2.(pi_schedule) p2)) ++ skipn es p2)).
rewrite Hcompat with (pi := pi2) in Hts; auto.
rewrite resize_skipn_eq in Hts.
rewrite resize_eq in Hts by (unfold affine_product; rewrite map_length; eauto). simpl in Hts.
destruct (to_scan n2 p2) eqn:Hscan2; auto. apply Hts.
rewrite <- split3_eq with (l := p) (d := d) (i := es).
rewrite !lex_compare_app by (rewrite !resize_length; reflexivity).
rewrite Hscanenv with (p := resize es p2) by (apply resize_length || rewrite resize_skipn_eq; apply Hscan2).
rewrite Hscanenv with (p := resize es p) by (apply resize_length || apply Hscan).
rewrite lex_compare_reflexive. simpl.
rewrite Heqp. rewrite resize_eq by (unfold affine_product; rewrite map_length; eauto).
rewrite Hcmp. reflexivity.
+ rewrite <- Heqpi1 in Hsem1; unfold pi_elim_schedule in Hsem1; simpl in *.
rewrite affine_product_skipn in Hsem1. apply Hsem1.
+ apply IH; auto.
* apply scanned_wf_compat; auto.
* apply scanned_wf_compat; auto.
* intros n0 p0 q0 ts pi0 Hpi0 Hlp0 Hlts.
unfold scanned. rewrite Hcompat with (pi := pi0); auto.
destruct (is_eq ts (affine_product (pi_schedule pi0) (p0 ++ q0))) eqn:Htseq; auto.
simpl.
f_equal; f_equal.
destruct (n =? n0)%nat eqn:Heqn; [|rewrite !andb_false_r; auto]. rewrite !andb_true_r.
rewrite <- split3_eq with (l := p) (d := d) (i := es) at 1.
rewrite !is_eq_app by (rewrite resize_length; auto).
destruct (is_eq (resize es p) p0) eqn:Heqp0; simpl; auto.
destruct (is_eq (skipn (d + es)%nat p) q0) eqn:Heqq0; simpl; auto using andb_false_r.
rewrite andb_true_r.
reflect. rewrite Heqn in *.
assert (Heqpi0 : pi0 = pi1) by congruence. rewrite Heqpi0 in *.
rewrite Heqp. rewrite Htseq. f_equal.
assert (H : p0 ++ q0 =v= resize es p ++ skipn (d + es) p); [|rewrite H; reflexivity].
rewrite <- is_eq_veq. rewrite is_eq_app by (rewrite resize_length; auto).
reflect; split; symmetry; assumption.
* intros n0 p0 q0 H. unfold scanned. reflect. intros [H1 H2]. eapply Hscanenv; eauto.
* intros n0 p0 H. unfold scanned. rewrite Hout; auto.
Qed.
(** * Semantics in a fixed environment *)
(**
Notes on the helpers used below:
- [resize n l] returns a list of length exactly [n]. It pads [l] with zeros when
  [length l < n] and truncates it when [length l > n].
- [is_eq] compares vectors while ignoring trailing zeros.
*)
(** Question: is [env] usually longer or shorter than the points, and what exactly
    is it? It appears to be the context, i.e. constant values: [env_scan] requires
    the first [length env] entries of every selected point to equal [env]. *)
(** [env_scan prog env dim n p] selects point [p] of instruction [n] iff instruction
    [n] exists in [prog] and [p] meets the three conditions listed in the body. *)
Definition env_scan (prog : Poly_Program) (env : list Z) (dim : nat) (n : nat) (p : list Z) :=
match nth_error prog n with
| Some pi => is_eq env (resize (length env) p) && is_eq p (resize dim p) && in_poly p pi.(pi_poly)
(** 1. [p] cut (or zero-padded) to the length of [env] equals [env]: the
       environment (context) part of [p] is fixed.
    2. Every entry of [p] at index [dim] or beyond is zero. In other words, [p] has
       at most [dim] meaningful dimensions; this does NOT require [length p >= dim].
    3. [p] is a point of the instruction's domain [pi_poly].
*)
| None => false
end.
(** Semantics of a program in a fixed environment. The pending points are
    initially those selected by [env_scan], and each executed point is then
    removed with [scanned]. *)
(** Scheduled semantics: points run in the order given by [pi_schedule]. *)
Definition env_poly_semantics (env : list Z) (dim : nat) (prog : Poly_Program) (mem1 mem2 : mem) :=
poly_semantics (env_scan prog env dim) prog mem1 mem2.
(** Lexicographic semantics: points run in lexicographic order of the points themselves. *)
Definition env_poly_lex_semantics (env : list Z) (dim : nat) (prog : Poly_Program) (mem1 mem2 : mem) :=
poly_lex_semantics (env_scan prog env dim) prog mem1 mem2.
(** [env_scan] satisfies [wf_scan]: it gives the same result on equal instruction
    indices and on [veq]-equal points. *)
Instance env_scan_proper : forall prog env dim, Proper (eq ==> veq ==> eq) (env_scan prog env dim).
Proof.
intros prog env dim n1 n2 Hn p1 p2 Hp. rewrite Hn. unfold env_scan.
destruct (nth_error prog n2) as [pi|]; simpl; auto.
rewrite Hp at 1 2 4; rewrite Hp at 1. reflexivity.
Qed.
(** * Schedule elimination in a fixed environment is correct as well *)
(** [poly_elim_schedule_semantics_preserve] instantiated with [env_scan].
    Assume [es = length env <= dim] and every schedule has at most [d] rows.
    Suppose the schedule-eliminated program, scanned lexicographically in [env] with
    [dim + d] dimensions, runs from [mem1] to [mem2]. Then the original program,
    scanned in [env] with [dim] dimensions under its schedules, does too. *)
Theorem poly_elim_schedule_semantics_env_preserve :
forall d es env dim prog mem1 mem2,
es = length env ->
(es <= dim)%nat ->
env_poly_lex_semantics env (dim + d) (elim_schedule d es prog) mem1 mem2 ->
(forall n pi, nth_error prog n = Some pi -> (length pi.(pi_schedule) <= d)%nat) ->
env_poly_semantics env dim prog mem1 mem2.
Proof.
intros d es env dim prog mem1 mem2 Hlength Hdim Hsem Hsched_length.
unfold env_poly_semantics. unfold env_poly_lex_semantics in Hsem.
eapply poly_elim_schedule_semantics_preserve.
- exact Hsem.
- reflexivity.
- apply env_scan_proper.
- apply env_scan_proper.
- exact Hsched_length.
(* Compatibility of the two [env_scan] predicates: split the extended domain into
   the [make_sched_poly] part and the padded original domain. *)
- intros n p q ts pi Heqpi Hlp Hlts.
unfold env_scan. unfold elim_schedule. rewrite map_nth_error with (d := pi); auto. rewrite Heqpi.
rewrite <- Hlength. unfold pi_elim_schedule; simpl.
rewrite !resize_app with (n := es) by apply Hlp.
destruct (is_eq env p); simpl; auto using andb_false_r.
rewrite in_poly_app. rewrite andb_comm. rewrite <- andb_assoc. f_equal.
+ apply make_sched_poly_correct; eauto.
+ rewrite andb_comm. f_equal.
* rewrite !resize_app_le by lia.
rewrite !is_eq_app by lia. rewrite !is_eq_reflexive. simpl.
f_equal. f_equal. lia.
* unfold in_poly. rewrite forallb_map. apply forallb_ext.
intros c. unfold satisfies_constraint. unfold insert_zeros_constraint. simpl.
f_equal. rewrite dot_product_commutative. rewrite insert_zeros_product_skipn.
rewrite resize_app by apply Hlp.
rewrite app_assoc. rewrite skipn_app; [|rewrite app_length; lia].
apply dot_product_commutative.
(* Selected points start with [env]. *)
- intros n p q Hlp Hscan. unfold env_scan in Hscan.
destruct (nth_error prog n) as [pi|]; [|congruence].
reflect. destruct Hscan as [[He Hr] Hp].
rewrite resize_app in He by congruence. symmetry. exact He.
(* Nothing is selected outside the program. *)
- intros n p Hout. unfold env_scan. rewrite Hout. auto.
Qed.