Skip to content

Commit

Permalink
Enhancement: Symbolic Loop Merge Dims (#358)
Browse files Browse the repository at this point in the history
  • Loading branch information
hikettei authored Dec 20, 2024
1 parent fe84007 commit 1ace718
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 15 deletions.
6 changes: 3 additions & 3 deletions source/codegen/scop.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ This function returns the BOUND, otherwise returns error.
collect
(if (expr-affine-p (getattr dom :below))
(format nil "0 <= ~(~a~) and ~(~a~)" (getattr dom :idx) (render-expr device (getattr dom :below)))
(multiple-value-bind (expr-id new-p) (stash-expr (expr-detach-loop-bound (getattr dom :below)))
(when new-p (push expr-id extra-symbolics))
(multiple-value-bind (expr-id) (stash-expr (expr-detach-loop-bound (getattr dom :below)))
(push expr-id extra-symbolics)
(format nil "0 <= ~(~a~) < ~(~a~)" (getattr dom :idx) expr-id)))
collect " and ")
(loop for s in (append symbolics extra-symbolics)
Expand All @@ -101,7 +101,7 @@ This function returns the BOUND, otherwise returns error.
(format out " ~a[];~%" (node-id node)))))))
(format out "}"))
idx2domain
extra-symbolics))
(remove-duplicates extra-symbolics :test #'equalp)))

(defun render-domain-from-loops (node domains &aux (device 'Default-Renderer))
(if domains
Expand Down
25 changes: 13 additions & 12 deletions source/codegen/shape-inference.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -278,18 +278,17 @@ If the shape inference is successfully done and properly deployed to the target
((list (eql 0) (trivia:guard x (expr-scalar-equivalent-p (expr-const x :int64) shape)) (eql 1) _) t)
(_ nil)))

(defun gather-only-scalars (nodes)
(loop for n in nodes
if (and (= 0 (buffer-nrank (car (relay-writes (read-type-relay n))))))
collect n))

(defun %expr-const (graph value dtype)
(let* ((val (reveal-buffer value))
(load (id->value graph val))
(alloc (when load (id->value graph (car (node-reads load))))))
(let* ((val (reveal-buffer value)))
(if (or (numberp val) (null (id->value graph val)))
(expr-const val dtype)
;; If the value directly represents for the dynamic shape, it is worth to replace it with the node.
;; Otherwise, just a loads the symbol.
(if (and alloc load (eql (node-type alloc) :Alloc) (eql (node-type load) :LOAD)
(symbolp (getattr load :value)) (= 0 (getattr alloc :nrank)))
(expr-from-graph val graph)
(expr-const val dtype)))))
;; Merge only scalar path!
(expr-from-graph val (apply #'caten/air:make-graph (graph-nodes graph))))))

(defstruct Iteration-Space
"
Expand Down Expand Up @@ -356,9 +355,11 @@ gids corresponds for the loop idx in the kernel.
(null no-collapse)
(mergeable-view-p last-view last-size)
(mergeable-view-p view size)
(expr-scalar-equivalent-p
last-stride
(expr-mul (%expr-const g size :int64) (%expr-const g stride :int64))))
(or
(when (expr-equal-to last-stride 0) (eql stride 0))
(expr-scalar-equivalent-p
last-stride
(expr-mul (%expr-const g size :int64) (%expr-const g stride :int64)))))
(setf (nth (1- (length ret)) ret)
(list (expr-mul last-size (%expr-const g size :int64)) (%expr-const g stride :int64) nil (append last-pd (list nth))))
(setf ret
Expand Down
20 changes: 20 additions & 0 deletions source/test-suite/test-scheduler.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,26 @@
(let ((bp (caten/codegen/blueprint:lower-schedule-item item (avm-graph avm) schedule)))
(ok (= count (count :FOR bp :key #'node-type)) (format nil "Expected ~a loops, got ~a" count (count :FOR bp :key #'node-type))))))))

(deftest test-symbolic-loop-merge-dims
(testing "B and 10 is merged:"
(multiple-value-bind (schedule avm)
(with-inference-mode ()
(schedule-with-vars (!add (call (Embedding 10 10) (make-tensor `(b 10))) (call (Embedding 10 10) (make-tensor `(b 10))))))
(check-kernel schedule 1)
(caten/codegen/expr-cache:with-expr-cache ()
(loop for item in (gather-kernels schedule)
for count in `(3) do
(let ((bp (caten/codegen/blueprint:lower-schedule-item item (avm-graph avm) schedule)))
(ok (= count (count :FOR bp :key #'node-type)) (format nil "Expected ~a loops, got ~a" count (count :FOR bp :key #'node-type))))))))
(testing "Elementwise is 1D"
(multiple-value-bind (schedule avm) (schedule-with-vars (!relu (make-tensor `(a b c d))))
(check-kernel schedule 1)
(caten/codegen/expr-cache:with-expr-cache ()
(loop for item in (gather-kernels schedule)
for count in `(1) do
(let ((bp (caten/codegen/blueprint:lower-schedule-item item (avm-graph avm) schedule)))
(ok (= count (count :FOR bp :key #'node-type)) (format nil "Expected ~a loops, got ~a" count (count :FOR bp :key #'node-type)))))))))

(deftest test-serialize-reduction-loop
(with-no-grad
(testing "Serialized Reductions should belong to the same loop. (not creating a new inner loop!)"
Expand Down

0 comments on commit 1ace718

Please sign in to comment.