Skip to content

Commit

Permalink
BEAM Search v1 (#385)
Browse files Browse the repository at this point in the history
  • Loading branch information
hikettei authored Jan 4, 2025
1 parent 67923c0 commit 679df19
Show file tree
Hide file tree
Showing 19 changed files with 456 additions and 409 deletions.
4 changes: 2 additions & 2 deletions source/byoc/clang.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

(defclass ClangBuffer (LispBuffer) nil)
(defclass ClangRuntime (GraphRuntime) nil)
(define-auto-scheduler (Clang-Auto-Scheduler (&key (n-global-loop (1- (ctx:getenv :OMP)))))
(define-auto-scheduler (Clang-Auto-Scheduler (&key (n-global-loop (ctx:getenv :OMP))))
;; Use outermost loop parallelism for maximize memory locality (better softmax/layernorm scheduling)
:n-global-loop n-global-loop ;; OMP=1 -> The outermost loop is GLOBAL, otherwise everything is a local loop
:tile-size 32) ;; [TODO] Autotuned
:tile-sizes `(2 4 8 16 32))
(define-backend :clang ClangBuffer ClangRuntime CStyle-Renderer Clang-Auto-Scheduler t)

(defvar *indent*)
Expand Down
31 changes: 21 additions & 10 deletions source/byoc/native.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -124,23 +124,34 @@
(z (render-node renderer (nth 2 (node-reads node)))))
`(+ ,x (* ,y ,z))))

(defun extract-scop-from-loop (for)
(declare (type node for))
(assert (eql (node-type for) :FOR))
(let ((below (expr-detach-loop-bound (getattr for :below) :allow-failed t)))
(when (and below (expr-equal-to (getattr for :upfrom) 0) (expr-equal-to (getattr for :by) 1))
below)))

(defun recursive-render-bp (rest-blueprints)
(let ((bp (car rest-blueprints)))
(when (null bp) (return-from recursive-render-bp nil))
(ecase (node-type bp)
(:FOR
(let* ((endfor (position-if #'(lambda (x) (and (eql (node-type x) :ENDFOR) (equal (getattr x :idx) (getattr bp :idx)))) rest-blueprints)))
(assert endfor () "recursive-render-bp: :FOR without :ENDFOR is not allowed. Malformed blueprint?")
(when (eql (getattr bp :scope) :global)
(warn "LispStyle-Renderer: global loop is not supported yet."))
;; [TODO] Simplify the loop code and to use lparallel
;; [TODO] There is useful macro from scop
`(progn
(loop with ,(intern (getattr bp :idx)) fixnum = ,(render-expr 'LispStyle-Renderer (getattr bp :upfrom))
while ,(render-expr 'LispStyle-Renderer (getattr bp :below))
do ,(recursive-render-bp (subseq rest-blueprints 1 endfor))
(incf ,(intern (getattr bp :idx)) ,(render-expr 'LispStyle-Renderer (getattr bp :by))))
,(recursive-render-bp (subseq rest-blueprints (1+ endfor))))))
(let ((below (extract-scop-from-loop bp)))
(if below
`(progn
(,(if (eql (getattr bp :scope) :local) 'dotimes 'lparallel:pdotimes) (,(intern (getattr bp :idx)) ,(render-expr 'LispStyle-Renderer below))
,(recursive-render-bp (subseq rest-blueprints 1 endfor)))
,(recursive-render-bp (subseq rest-blueprints (1+ endfor))))
(progn
(when (eql (getattr bp :scope) :global) (warn "recursive-render-bp: The node ~a is scheduled as global but the upfrom/below/by is too complicated to handle.~%Thus this loop is not parallelized." bp))
`(progn
(loop with ,(intern (getattr bp :idx)) fixnum = ,(render-expr 'LispStyle-Renderer (getattr bp :upfrom))
while ,(render-expr 'LispStyle-Renderer (getattr bp :below))
do ,(recursive-render-bp (subseq rest-blueprints 1 endfor))
(incf ,(intern (getattr bp :idx)) ,(render-expr 'LispStyle-Renderer (getattr bp :by))))
,(recursive-render-bp (subseq rest-blueprints (1+ endfor)))))))))
(:ENDFOR
(error ":ENDFOR should not be appeared here. Malformed blueprint?"))
(:IF
Expand Down
25 changes: 12 additions & 13 deletions source/codegen/auto-scheduler/ast-parser.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ Transform the Polyhedral IR into the Blueprint IR.
```
scop.lisp for the opposite things.
")
(:use :cl :caten/codegen/expr :caten/codegen/expr-cache :caten/air :caten/codegen/shape-inference :trivia :caten/codegen/polyhedral-ast)
(:import-from :caten/codegen/unroll :mark-unroll-parent-p :mark-unroll-body-p :parse-unroll-directive)
(:use :cl :caten/codegen/expr :caten/codegen/expr-cache :caten/air :caten/codegen/shape-inference :trivia :caten/codegen/polyhedral-ast :caten/codegen/transform)
(:export #:lower-into-bp-from-polyhedral))

(in-package :caten/codegen/ast-parser)
Expand Down Expand Up @@ -50,31 +49,30 @@ scop.lisp for the opposite things.

(declaim (ftype (function (cffi:foreign-pointer) t) parse-isl-ast-mark))
(defun parse-isl-ast-mark (ast)
(let* ((mark (cffi:foreign-string-to-lisp (isl::%isl-id-get-name (isl::%isl-ast-node-mark-get-id ast))))
(let* ((directive (str->directive (cffi:foreign-string-to-lisp (isl::%isl-id-get-name (isl::%isl-ast-node-mark-get-id ast)))))
(user (parse-isl-ast (isl::%isl-ast-node-mark-get-node ast))))
(typecase user
;; Mark(Nested?)
(AstFor
(cond
((equalp mark "parallel")
((equalp (directive-type directive) "GLOBAL")
(setf (astfor-scope user) :global))
((mark-unroll-parent-p mark)
((equalp (directive-type directive) "UNROLL_OUTER")
(let ((body (astfor-body user)))
(when (or
(not (typep body 'ASTFor))
(null (and (astfor-marks body) (every #'mark-unroll-body-p (astfor-marks body)))))
(null (and (astfor-marks body) (every #'(lambda (x) (equalp (directive-type x) "UNROLL_INNER")) (astfor-marks body)))))
(return-from parse-isl-ast-mark user))
(let* ((n-unroll (parse-unroll-directive mark))
(let* ((n-unroll (directive-amount directive))
(user (copy-astfor user))
(unrolled (caten/codegen/directive:make-unrolled-body user body n-unroll))
(reminder (caten/codegen/directive:compute-reminder-for-unroll user body n-unroll)))
(setf (astfor-body user) unrolled)
(return-from parse-isl-ast-mark (make-block (list user reminder))))))
((mark-unroll-body-p mark)
((equalp (directive-type directive) "UNROLL_INNER")
;; UNROLL_BODY is triggered by the UNROLL_PARENT. Without it the form is ignored.
(assert (null (astfor-marks user)) () "UNROLL_BODY should be orthogonal with other directives.")
(setf (astfor-marks user) (list mark)))
((equalp mark "TILE_BAND")
(push "TILE_BAND" (astfor-marks user)))
(assert (null (astfor-marks user)) () "UNROLL_INNER should be orthogonal with other directives.")
(setf (astfor-marks user) (list directive)))
(T
;(warn "mark: ignored the mark ~a for ~a" mark user)
)))
Expand Down Expand Up @@ -210,7 +208,8 @@ scop.lisp for the opposite things.
(ematch object
((ASTBlock :body body) (map 'list #'lower body))
((AstFor :idx idx :from upfrom :to to :by by :body body :scope scope)
;; remove an empty loop
;; [TODO] Generalize this
;; TILE_BAND is not an unroll idx?
(let ((is-tile-band (find "TILE_BAND" (astfor-marks object) :test #'equalp)))
(when (not (expr-scalar-equivalent-p upfrom (expr-detach-loop-bound to)))
(when (null is-tile-band) (push idx space))
Expand Down
Loading

0 comments on commit 679df19

Please sign in to comment.