From 1ace718ce5324fe158b1704f308fcc367f3981cc Mon Sep 17 00:00:00 2001 From: hikettei <88639579+hikettei@users.noreply.github.com> Date: Fri, 20 Dec 2024 16:07:43 +0900 Subject: [PATCH] Enhancement: Symbolic Loop Merge Dims (#358) --- source/codegen/scop.lisp | 6 +++--- source/codegen/shape-inference.lisp | 25 +++++++++++++------------ source/test-suite/test-scheduler.lisp | 20 ++++++++++++++++++++ 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/source/codegen/scop.lisp b/source/codegen/scop.lisp index 8918bcec2..8a2eba728 100644 --- a/source/codegen/scop.lisp +++ b/source/codegen/scop.lisp @@ -91,8 +91,8 @@ This function returns the BOUND, otherwise returns error. collect (if (expr-affine-p (getattr dom :below)) (format nil "0 <= ~(~a~) and ~(~a~)" (getattr dom :idx) (render-expr device (getattr dom :below))) - (multiple-value-bind (expr-id new-p) (stash-expr (expr-detach-loop-bound (getattr dom :below))) - (when new-p (push expr-id extra-symbolics)) + (multiple-value-bind (expr-id) (stash-expr (expr-detach-loop-bound (getattr dom :below))) + (push expr-id extra-symbolics) (format nil "0 <= ~(~a~) < ~(~a~)" (getattr dom :idx) expr-id))) collect " and ") (loop for s in (append symbolics extra-symbolics) @@ -101,7 +101,7 @@ This function returns the BOUND, otherwise returns error. (format out " ~a[];~%" (node-id node))))))) (format out "}")) idx2domain - extra-symbolics)) + (remove-duplicates extra-symbolics :test #'equalp))) (defun render-domain-from-loops (node domains &aux (device 'Default-Renderer)) (if domains diff --git a/source/codegen/shape-inference.lisp b/source/codegen/shape-inference.lisp index e4dbb34d0..a0eac62ff 100644 --- a/source/codegen/shape-inference.lisp +++ b/source/codegen/shape-inference.lisp @@ -278,18 +278,17 @@ If the shape inference is successfully done and properly deployed to the target ((list (eql 0) (trivia:guard x (expr-scalar-equivalent-p (expr-const x :int64) shape)) (eql 1) _) t) (_ nil))) +(defun gather-only-scalars (nodes) + (loop for n in nodes + if (and (= 0 (buffer-nrank (car (relay-writes (read-type-relay n)))))) + collect n)) + (defun %expr-const (graph value dtype) - (let* ((val (reveal-buffer value)) - (load (id->value graph val)) - (alloc (when load (id->value graph (car (node-reads load)))))) + (let* ((val (reveal-buffer value))) (if (or (numberp val) (null (id->value graph val))) (expr-const val dtype) - ;; If the value directly represents for the dynamic shape, it is worth to replace it with the node. - ;; Otherwise, just a loads the symbol. - (if (and alloc load (eql (node-type alloc) :Alloc) (eql (node-type load) :LOAD) - (symbolp (getattr load :value)) (= 0 (getattr alloc :nrank))) - (expr-from-graph val graph) - (expr-const val dtype))))) + ;; Merge only scalar path! + (expr-from-graph val (apply #'caten/air:make-graph (graph-nodes graph)))))) (defstruct Iteration-Space " @@ -356,9 +355,11 @@ gids corresponds for the loop idx in the kernel. (null no-collapse) (mergeable-view-p last-view last-size) (mergeable-view-p view size) - (expr-scalar-equivalent-p - last-stride - (expr-mul (%expr-const g size :int64) (%expr-const g stride :int64)))) + (or + (when (expr-equal-to last-stride 0) (eql stride 0)) + (expr-scalar-equivalent-p + last-stride + (expr-mul (%expr-const g size :int64) (%expr-const g stride :int64))))) (setf (nth (1- (length ret)) ret) (list (expr-mul last-size (%expr-const g size :int64)) (%expr-const g stride :int64) nil (append last-pd (list nth)))) (setf ret diff --git a/source/test-suite/test-scheduler.lisp b/source/test-suite/test-scheduler.lisp index c699e1b95..865d8b12d 100644 --- a/source/test-suite/test-scheduler.lisp +++ b/source/test-suite/test-scheduler.lisp @@ -44,6 +44,26 @@ (let ((bp (caten/codegen/blueprint:lower-schedule-item item (avm-graph avm) schedule))) (ok (= count (count :FOR bp :key #'node-type)) (format nil "Expected ~a loops, got ~a" count (count :FOR bp :key #'node-type)))))))) +(deftest test-symbolic-loop-merge-dims + (testing "B and 10 is merged:" + (multiple-value-bind (schedule avm) + (with-inference-mode () + (schedule-with-vars (!add (call (Embedding 10 10) (make-tensor `(b 10))) (call (Embedding 10 10) (make-tensor `(b 10)))))) + (check-kernel schedule 1) + (caten/codegen/expr-cache:with-expr-cache () + (loop for item in (gather-kernels schedule) + for count in `(3) do + (let ((bp (caten/codegen/blueprint:lower-schedule-item item (avm-graph avm) schedule))) + (ok (= count (count :FOR bp :key #'node-type)) (format nil "Expected ~a loops, got ~a" count (count :FOR bp :key #'node-type)))))))) + (testing "Elementwise is 1D" + (multiple-value-bind (schedule avm) (schedule-with-vars (!relu (make-tensor `(a b c d)))) + (check-kernel schedule 1) + (caten/codegen/expr-cache:with-expr-cache () + (loop for item in (gather-kernels schedule) + for count in `(1) do + (let ((bp (caten/codegen/blueprint:lower-schedule-item item (avm-graph avm) schedule))) + (ok (= count (count :FOR bp :key #'node-type)) (format nil "Expected ~a loops, got ~a" count (count :FOR bp :key #'node-type))))))))) + (deftest test-serialize-reduction-loop (with-no-grad (testing "Serialized Reductions should belong to the same loop. (not creating a new inner loop!)"