Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Prerequisite for AutoScheduler #362

Merged
merged 57 commits into from
Dec 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
6782d42
gemm
hikettei Dec 20, 2024
44a490f
Tile
hikettei Dec 20, 2024
610316a
search tile
hikettei Dec 20, 2024
0cd518a
Merge branch 'main' into opt-gemm
hikettei Dec 20, 2024
f6bcc57
OpenBLAS
hikettei Dec 20, 2024
253cf6f
90GFLops
hikettei Dec 20, 2024
d4c08c4
wip
hikettei Dec 20, 2024
db52790
wip
hikettei Dec 20, 2024
082ae9d
wip
hikettei Dec 20, 2024
9446987
tile param search
hikettei Dec 20, 2024
6a92698
finish an experiment
hikettei Dec 20, 2024
0e435c1
apply tiling without causing segv
hikettei Dec 20, 2024
e32cec2
tiling all bands
hikettei Dec 20, 2024
c4215bc
things are moved
hikettei Dec 20, 2024
4c2de70
move everything into auto-scheduler
hikettei Dec 21, 2024
785d2eb
asd
hikettei Dec 21, 2024
34656d2
fix: tiling-size related gc error
hikettei Dec 21, 2024
6a0c211
An initial attempt to unroll: tile based
hikettei Dec 21, 2024
0ab186d
moved
hikettei Dec 21, 2024
2ce3075
split
hikettei Dec 21, 2024
b22d20f
.
hikettei Dec 21, 2024
874cd17
moved isl things to auto-scheduler
hikettei Dec 21, 2024
b47fd09
refactor
hikettei Dec 21, 2024
e92725d
synchronize baseline
hikettei Dec 21, 2024
e1e5dc7
updt
hikettei Dec 21, 2024
dfca38d
updt
hikettei Dec 21, 2024
0ed2740
typo
hikettei Dec 21, 2024
0adb742
typo
hikettei Dec 21, 2024
378c310
typo
hikettei Dec 21, 2024
d2de5a1
Tweak
hikettei Dec 21, 2024
cd01fe9
nested unroll directive
hikettei Dec 21, 2024
64b8c7c
Unroll for static and simple case
hikettei Dec 21, 2024
3ac9213
Fold MOD
hikettei Dec 21, 2024
0d4095f
nope
hikettei Dec 21, 2024
642acc9
Initial attempt of unrolling
hikettei Dec 21, 2024
ea829fd
copy and remove UNROLL_PARENT for the reminder
hikettei Dec 21, 2024
b7ba19f
Loop Unrolling is worked
hikettei Dec 21, 2024
6df4461
Unrolling softmax
hikettei Dec 21, 2024
994a615
wip: generating a sketch
hikettei Dec 21, 2024
ca032f8
Identify the sketch
hikettei Dec 21, 2024
4cdec95
TODO
hikettei Dec 21, 2024
1c5096e
remove extra unroll reminder
hikettei Dec 21, 2024
bfb2d11
coincident for elwise
hikettei Dec 22, 2024
771fa79
fix: unroll
hikettei Dec 22, 2024
aad075a
quickload
hikettei Dec 22, 2024
1b56342
MemoryPlanner should produce the same result
hikettei Dec 22, 2024
e77703a
Fix for unroll reminder
hikettei Dec 22, 2024
0e9fff8
try unrolling tiled bands
hikettei Dec 22, 2024
6674d7b
ignore tiled band dims
hikettei Dec 22, 2024
ce60a52
base-item
hikettei Dec 22, 2024
5ba4fa3
remove: caten/polyhedral
hikettei Dec 22, 2024
26a4a66
clean up
hikettei Dec 22, 2024
e63ac07
rem dep
hikettei Dec 22, 2024
546d0c4
caten instead
hikettei Dec 22, 2024
f73c49f
updt
hikettei Dec 22, 2024
04cb131
updt
hikettei Dec 22, 2024
c832f5e
updt
hikettei Dec 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
things are moved
  • Loading branch information
hikettei committed Dec 20, 2024
commit c4215bc96cd32ad284581adda7e80d9650498a33
8 changes: 8 additions & 0 deletions source/codegen/auto-scheduler/auto-scheduler.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
(defpackage :caten/codegen/auto-scheduler
(:use :cl))

(in-package :caten/codegen/auto-scheduler)

(defun auto-schedule ()

)
88 changes: 88 additions & 0 deletions source/codegen/auto-scheduler/coincidence.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
(defpackage :caten/codegen/coincidence
(:use :cl :caten/codegen/polyhedral))

(in-package :caten/codegen/coincidence)

(defun get-zeros-on-union-set (delta-uset)
(declare (type isl::union-set delta-uset))
(let* ((delta-set (isl:set-from-union-set delta-uset))
(ma (isl:multi-aff-zero (isl:set-get-space delta-set))))
(isl:union-set-from-set (isl:set-from-multi-aff ma))))

(defun check-legality-parallel (node dep)
"
```
(check-legality-parallel node dep)
```
Returns T if the band node is legal to be parallelized with respect to the dep.
Reference: https://github.com/hikettei/tadashi/blob/main/src/legality.c#L91-L122"
(declare (type isl::schedule-node node) (type isl::union-map dep))
(when (isl:union-map-is-empty dep) (return-from check-legality-parallel t))
(let* ((map (isl:schedule-node-band-get-partial-schedule-union-map node))
(domain (isl:union-map-apply-range (isl:union-map-apply-domain dep map) map))
(delta (isl:union-map-deltas domain))
(_ (when (isl:union-set-is-empty delta) (return-from check-legality-parallel t)))
(zeros (get-zeros-on-union-set delta))
(cmp (isl:union-set-lex-lt-union-set delta zeros))
(retval (isl:union-set-is-empty cmp))
(cmp (isl:union-set-lex-gt-union-set delta zeros)))
(declare (ignore _))
(and retval (isl:union-set-is-empty cmp))))

(defun check-legality (schedule dep)
"
```
(check-legality schedule dep)
```
Returns T if the current schedule does not break any dependences in dep."
(declare (type isl::schedule schedule) (type isl::union-map dep))
(when (isl:union-map-is-empty dep) (return-from check-legality t))
(let* ((map (isl:schedule-get-map schedule))
(domain (isl:union-map-apply-domain dep map))
(domain (isl:union-map-apply-range domain map))
(delta (isl:union-map-deltas domain))
(zeros (get-zeros-on-union-set delta))
(le (isl:union-set-lex-le-union-set delta zeros))
(retval (isl:union-set-is-empty le)))
retval))

(defun insert-parallel (band)
(isl:schedule-node-insert-mark band (isl::make-id-from-str "parallel")))

(defun get-coincident-points (poly)
(map-schedule-nodes
#'(lambda (type band)
(when (eql type :schedule-node-band)
(when (check-legality-parallel band (poly-dependencies poly)) band)))
poly))

(defun polyir-set-coincident (poly level)
(declare (type Polyhedral-IR poly))
;; TODO(hikettei) there should be more clever way to do this:
(let ((insertable-points (get-coincident-points poly)))
(dotimes (i (length insertable-points))
(when (<= i level)
(setf (poly-schedule poly) (isl:schedule-node-get-schedule (insert-parallel (nth i (get-coincident-points poly)))))))
(length insertable-points)))

(defun %loop-interchange (schedule-node)
(declare (type isl::schedule-node schedule-node))
(let* ((mupa (isl:schedule-node-band-get-partial-schedule schedule-node))
(node (isl:schedule-node-delete schedule-node))
(n-child (isl::%isl-schedule-node-n-children (isl::schedule-node-handle node)))
(_ (when (= 0 n-child) (return-from %loop-interchange nil)))
(node (isl:schedule-node-first-child node))
(__ (when (find (isl:schedule-node-get-type node) `(:schedule-node-filter)) (return-from %loop-interchange nil)))
(node (isl:schedule-node-insert-partial-schedule node mupa)))
(declare (ignore _ __))
node))

(defun polyir-loop-interchange (poly nth)
(declare (type polyhedral-ir poly) (type fixnum nth))
(let ((bands (map-schedule-nodes #'(lambda (type band) (when (eql type :schedule-node-band) band)) poly)))
(unless (<= nth (length bands)) (return-from polyir-loop-interchange nil))
(let* ((new-sched (%loop-interchange (nth nth bands)))
(new-sched (when new-sched (isl:schedule-node-get-schedule new-sched))))
(if (and new-sched (check-legality new-sched (poly-dependencies poly)))
(progn (setf (poly-schedule poly) new-sched) t)
nil))))
168 changes: 168 additions & 0 deletions source/codegen/auto-scheduler/polyhedral.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
(defpackage :caten/codegen/polyhedral
(:import-from :cffi #:foreign-funcall)
(:shadow #:set #:space)
(:shadowing-import-from :cl :map)
(:use :cl :caten/isl)
(:export
#:make-polyhedral-ir
#:debug-render-to-clang
#:Polyhedral-IR
#:poly-schedule
#:poly-domain
#:poly-dependencies
#:map-schedule-nodes))

(in-package :caten/codegen/polyhedral)

(defclass Polyhedral-IR ()
((schedule :accessor poly-schedule)
(domain :accessor poly-domain)
(dependencies :accessor poly-dependencies)))

(defun make-polyhedral-ir (domain read write schedule)
(let ((pg (make-instance 'Polyhedral-IR)))
(setf (poly-schedule pg) schedule
(poly-domain pg) domain)
(let* ((access (union-access-info-from-sink read))
(access (union-access-info-set-must-source access write))
(access (union-access-info-set-schedule access schedule))
(flow (union-access-info-compute-flow access))
(RaW (union-flow-get-must-dependence flow))
(access (union-access-info-from-sink write))
(access (union-access-info-set-must-source access write))
(access (union-access-info-set-may-source access read))
(access (union-access-info-set-schedule access schedule))
(flow (union-access-info-compute-flow access))
(WaW (union-flow-get-must-dependence flow))
(WaR (union-flow-get-may-dependence flow))
(dependencies
(union-map-union
(union-map-union WaR RaW)
WaW)))
(setf (poly-dependencies pg) dependencies)
pg)))

(defmethod debug-render-to-clang ((pg Polyhedral-IR))
(let* ((schedule (schedule-set-options (copy (poly-schedule pg)) :separate))
(build (ast-build-from-context (set-from-str "{:}")))
(p (isl::%isl-printer-to-str (isl::context-handle isl::*context*)))
(ast (ast-build-node-from-schedule build schedule))
(p (isl::%isl-printer-set-output-format p 4)) ;; 4 == Clang
(q (isl::%isl-printer-print-ast-node p (isl::ast-node-handle ast)))
(str (isl::%isl-printer-get-str q)))
str))

(defmethod pprint-schedule ((schedule schedule))
(let ((schedule (yaml:parse (schedule-to-str schedule))))
(with-output-to-string (out)
(format out "~%")
(labels ((indent (n)
(make-string n :initial-element #\space))
(separate-screen (indent &key (n 120))
(format out "~%~a~a~%" (indent indent) (make-string n :initial-element #\-)))
(explore (schedule key &key (indent 0))
(cond
((string= key "domain")
(format out "~adomain(~%" (indent indent))
(let ((domains (cl-ppcre:split
";"
(cl-ppcre:regex-replace-all
"{|}"
(gethash key schedule)
""))))
(format out "~a"
(apply
#'concatenate
'string
(butlast
(loop for dom in domains
collect (format nil "~a~a" (indent (+ indent 2)) dom)
collect (format nil "~%"))))))
(format out "~a)" (indent indent)))
((string= key "child")
(format out "~%~achild()" (indent indent))
(separate-screen indent)
(mapc
#'(lambda (x)
(explore (gethash key schedule) x :indent (+ indent 2)))
(reverse (alexandria:hash-table-keys (gethash key schedule)))))
((string= key "schedule")
(let ((schedules (cl-ppcre:split
" , "
(cl-ppcre:regex-replace-all
"{|}"
(subseq (gethash key schedule) 1 (1- (length (gethash key schedule))))
""))))
(format out "~aschedule()" (indent indent))
(when schedules (format out "~%"))
(format out "~a"
(apply
#'concatenate
'string
(butlast
(loop for s in schedules
for nth upfrom 0
for separator = (if (= 1 (length schedules)) "-" (if (zerop nth) "┏" (if (= (length schedules) (1+ nth)) "┗" "┃")))
collect (format nil "~a ~a~a" (indent indent) separator s)
collect (format nil "~%")))))))
((or (string= key "sequence") (string= key "set"))
(format out "~a~a()" (indent indent) key)
(mapc
#'(lambda (x)
(mapc
#'(lambda (k)
(explore x k :indent (+ 2 indent)))
(alexandria:hash-table-keys x)))
(gethash key schedule)))
((string= key "filter")
(format out "~%~afilter(~%" (indent indent))
(let ((domains (cl-ppcre:split
";"
(cl-ppcre:regex-replace-all
"{|}"
(gethash key schedule)
""))))
(format
out
"~a"
(apply
#'concatenate
'string
(butlast
(loop for dom in domains
collect (format nil "~a~a" (indent (+ indent 2)) dom)
collect (format nil "~%")))))
(format out ")")))
((or (string= key "permutable") (string= key "coincident"))
(format out "~%~a~a(~a)" (indent indent) key (gethash key schedule)))
((or (string= key "mark"))
(format out "~amark(~a)" (indent indent) (gethash key schedule)))
(t (warn "pprint: the key ~a is not implemented." key)))))
(mapc #'(lambda (x) (explore schedule x)) (reverse (alexandria:hash-table-keys schedule)))))))

(defmethod print-object ((pg Polyhedral-IR) stream)
(print-unreadable-object (pg stream :type t)
(format stream "~a~%[Kernel]:~%~a" (pprint-schedule (copy (poly-schedule pg))) (debug-render-to-clang pg))))

(defun map-schedule-nodes (f polyhedral-ir)
"
```
(map-schedule-nodes f polyhedral-ir)
```
Iterates over the schedule nodes of a polyhedral-ir object. f is a lambda function which takes (type[keyword] node[schedule-node]) as an argument.
This function returns a list of the results of applying f to each node. NIL is excluded in the list."
(declare (type Polyhedral-IR polyhedral-ir) (type function f))
(let* ((node (schedule-get-root (poly-schedule polyhedral-ir)))
(next-nodes)
(outputs))
(loop named map-search
for n-children = (isl::%isl-schedule-node-n-children (isl::schedule-node-handle node))
while (>= n-children 0) do
(loop for nth upfrom 0 below n-children
for band = (schedule-node-get-child node nth)
for type = (schedule-node-get-type band) do
(let ((out (funcall f type band))) (when out (push out outputs)))
(push band next-nodes))
(when (= (length next-nodes) 0) (return-from map-search))
(setf node (pop next-nodes)))
(nreverse outputs)))
81 changes: 81 additions & 0 deletions source/codegen/auto-scheduler/tiling.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
(defpackage :caten/codegen/tiling
(:documentation "
To apply tiling to a reduce dim, use apply-tiling.
```
(apply-tile polyhedral-ir `(16 16))
```

[TODO]
- Tiling sizes can be automatically optimized by the measurer.
")
(:shadow #:set #:space)
(:shadowing-import-from :cl :map)
(:use :cl :caten/isl :caten/codegen/polyhedral))

(in-package :caten/codegen/tiling)

(defun tiling-sizes (band &key (size-default 32) (dims))
(declare (type list dims) (type fixnum size-default))
(let* ((band-space (schedule-node-band-get-space band))
(dim (space-dim band-space 3)))
(multi-val-from-val-list
band-space
(apply
#'make-value-list
(loop for i upfrom 0 below dim
collect
(or (nth i dims) size-default))))))

(defun shift-band-zero (band)
"Refernece: https://github.com/hikettei/cl-polyhedral/blob/main/source/tiling.lisp#L52C1-L79C37"
(let* ((domain (schedule-node-get-domain band))
(partial-schedule (schedule-node-band-get-partial-schedule band))
(mupa (multi-union-pw-aff-intersect-domain partial-schedule domain))
(n (multi-union-pw-aff-size mupa))
(multi-val (multi-union-pw-aff-min-multi-val mupa)))
(loop for i upfrom 0 below n
for v = (multi-val-get-val multi-val i)
do (when (value-negative-infinity-p v)
(setf multi-val (multi-val-set-val multi-val i (value 1)))))
(let* ((shift (multi-union-pw-aff-multi-val-on-domain domain multi-val))
(shift-neg (multi-union-pw-aff-neg shift))
(partial-schedule (multi-union-pw-aff-add partial-schedule shift-neg)))
(values partial-schedule shift))))

(defun tile-partial-schedule (partial-schedule tile-size &key (scale-tile-loops nil))
(let ((n (multi-union-pw-aff-size partial-schedule)))
(loop for i upfrom 0 below n
for upa1 = (multi-union-pw-aff-get-union-pw-aff partial-schedule i)
for v = (multi-val-get-val tile-size i)
for upa2 = (union-pw-aff-scale-down-val upa1 v)
for upa3 = (union-pw-aff-floor upa2)
for upa = (if scale-tile-loops (union-pw-aff-scale-val upa3 v) upa3)
do (setf partial-schedule (multi-union-pw-aff-set-union-pw-aff partial-schedule i upa)))
partial-schedule))

(defun schedule-tile-band (band &key (size-default 32) (dims))
(multiple-value-bind (partial-schedule shift)
(shift-band-zero band)
(let* ((tiling-sizes (tiling-sizes band :size-default size-default :dims dims))
(partial-schedule (tile-partial-schedule partial-schedule tiling-sizes))
(tiled-sched (multi-union-pw-aff-add partial-schedule shift)))
(schedule-node-get-schedule
(schedule-node-insert-partial-schedule band tiled-sched)))))

(defun get-tileable-bands (poly)
(map-schedule-nodes #'(lambda (type node) (when (eql type :schedule-node-band) node)) poly))
;; [TODO]
;; - the tiling sizes are 2D, optimized by measuring the computation
;; - Create a heatmap on the tiling sizes for debugging
(defun apply-tile (ir size)
"`tile-bands` helps you execute the computation tile by tile over the two axes"
(declare (type Polyhedral-IR ir))
;; [NOTE] Tile last 2d reductions
(let* ((bands (get-tileable-bands ir)))
;; Tile all bands
(dotimes (i (length bands))
(let ((i (- (1- (length bands)) i)))
(setf (poly-schedule ir)
(schedule-tile-band
(nth i (get-tileable-bands ir))
:size-default size))))))
1 change: 1 addition & 0 deletions source/polyhedral/auto-scheduler.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ An entrypoint for auto-scheduling.
;; But if (in the future) we implement more advanced Fusion Rules like Matmul+Softmax+Matmul, ConvND+Activation+Pooling
;; we may want isl scheduler to judge loop interchange?
(apply-parallelize scheduler poly)
(caten/polyhedral/tiling:tile-bands scheduler poly)
;; Unroll
;; Array Packing
;; 2D WMMA (8x8)
Expand Down
4 changes: 3 additions & 1 deletion source/polyhedral/tiling.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,13 @@
(defun tile-bands (scheduler ir)
"`tile-bands` helps you execute the computation tile by tile over the two axes"
(declare (type Polyhedral-IR ir))
;; [NOTE] Tile last 2d reductions
(let* ((bands (get-tileable-bands ir)))
(when (not (= 0 (auto-scheduler-tile-size scheduler)))
;; Tile all bands
(dotimes (i (length bands))
(let ((i (- (1- (length bands)) i)))
(setf (poly-schedule ir)
(schedule-tile-band
(nth i (get-tileable-bands ir))
:size-default (auto-scheduler-tile-size scheduler)))))))
:size-default (auto-scheduler-tile-size scheduler))))))))
Loading