Skip to content

Commit

Permalink
Merge branch 'main' into opt-gemm
Browse files Browse the repository at this point in the history
  • Loading branch information
hikettei authored Dec 20, 2024
2 parents 610316a + 1e742ad commit 0cd518a
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 10 deletions.
17 changes: 17 additions & 0 deletions source/apis/helpers.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,20 @@ Reads and binds attributes from module.
(setf (nth idx list) value)))

(defmethod permute-list ((op list) list) (loop for nth in op collect (nth nth list)))

(defun sym-eql (a b)
(if (and (tensor-p a) (tensor-p b))
(or
(eql a b)
(let* ((g1 (with-no-grad (tensor-lowered-graph a)))
(g2 (with-no-grad (tensor-lowered-graph b)))
(g1 (caten/codegen/expr:make-expr :graph g1 :out (car (last (graph-nodes g1)))))
(g2 (caten/codegen/expr:make-expr :graph g2 :out (car (last (graph-nodes g2))))))
;; Note(hikettei) this could be ridiculously slow if the shape is determined by the tensor!
;; Especially in the ViT Graph
(caten/codegen/expr:expr-scalar-equivalent-p g1 g2)))
(equal a b)))

(defun sym-equal (a b)
(declare (type list a b))
(and (= (length a) (length b)) (every #'sym-eql a b)))
4 changes: 2 additions & 2 deletions source/apis/merge-views.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ for i in range(3):
(loop for m in (make-broadcast-mask (tr-shape tracker) new-shape)
for s in new-shape
if m collect s)))
(when (equal shape-w/o-one (tr-shape tracker))
(when (sym-equal shape-w/o-one (tr-shape tracker))
(return-from tr-reshapeable-p t)))
(when (apply-masked-reshape tracker new-shape)
(return-from tr-reshapeable-p t))
Expand All @@ -351,7 +351,7 @@ for i in range(3):
(loop for m in mask
for s in new-shape
if m collect s)))
(when (equal shape-w/o-one (tr-shape tracker))
(when (sym-equal shape-w/o-one (tr-shape tracker))
(return-from tr-apply-reshape (tr-apply-uprank tracker mask))))
(assert (equal (tr-permute tracker) (range 0 (length (tr-shape tracker)))) () "Trying to reshape the permuted tracker!")
(let ((r (apply-masked-reshape tracker new-shape)))
Expand Down
13 changes: 7 additions & 6 deletions source/codegen/shape-inference.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,14 @@ If the shape inference is successfully done and properly deployed to the target
(when (null (getattr n :_type_relay :allow-undefined t))
(setf (getattr n :_type_relay) type))))))
;; ~~ Loop Collase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
(defun mergeable-view-p (view shape &aux (shape (if (typep shape 'Expr) shape (expr-const (reveal-buffer shape) :int64))))
(defun mergeable-view-p (g view shape &aux (shape (if (typep shape 'Expr) shape (expr-const (reveal-buffer shape) :int64))))
"Mergeable axis = view is not created."
(when (null view) (return-from mergeable-view-p t))
(when (expr-equal-to shape 1) (return-from mergeable-view-p (fourth view))) ;; Always collapse one as long as they are broadcasted.
(trivia:ematch view
;; antyhing for broadcast, because the strides of broadcasted axes are replaced w/ 0
((list (eql 0) (trivia:guard x (expr-scalar-equivalent-p (expr-const x :int64) shape)) (eql 1) _) t)
;; considering the case: X = |val_15|, shape=a*b (a little heavy, so separated)
((list (eql 0) (trivia:guard x (expr-scalar-equivalent-p (%expr-const g x :int64) shape)) (eql 1) _) t)
(_ nil)))

(defun gather-only-scalars (nodes)
Expand All @@ -288,7 +289,7 @@ If the shape inference is successfully done and properly deployed to the target
(if (or (numberp val) (null (id->value graph val)))
(expr-const val dtype)
;; Merge only scalar path!
(expr-from-graph val (apply #'caten/air:make-graph (graph-nodes graph))))))
(expr-from-graph val (apply #'caten/air:make-graph (gather-only-scalars (graph-nodes graph)))))))

(defstruct Iteration-Space
"
Expand Down Expand Up @@ -353,8 +354,8 @@ gids corresponds for the loop idx in the kernel.
(multiple-value-bind (last-size last-stride last-view last-pd) (apply #'values (car (last ret)))
(if (and
(null no-collapse)
(mergeable-view-p last-view last-size)
(mergeable-view-p view size)
(mergeable-view-p g last-view last-size)
(mergeable-view-p g view size)
(or
(when (expr-equal-to last-stride 0) (eql stride 0))
(expr-scalar-equivalent-p
Expand All @@ -365,7 +366,7 @@ gids corresponds for the loop idx in the kernel.
(setf ret
(append
ret
(list (list (%expr-const g size :int64) (%expr-const g stride :int64) (if (mergeable-view-p view size) nil view) (list nth))))))))
(list (list (%expr-const g size :int64) (%expr-const g stride :int64) (if (mergeable-view-p g view size) nil view) (list nth))))))))
(iteration-space-sync-broadcast
(make-iteration-space
:shape
Expand Down
4 changes: 2 additions & 2 deletions source/test-suite/test-schedule-cache.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,11 @@ Code2:
(deftest transformer-schedule-cache-count-test
(with-protect-jit
(loop for i upfrom 1 below 6
for expected in `(16 20 23 26 29)
for expected in `(14 18 21 24 27)
for tf = (avm-graph (ctx:with-contextvar (:NO_SCHEDULE_CACHE 0) (compile-transformer i)))
;; [TODO] The number of kernels should be a constant regardless of layers!!
do (ok (<= (count-compiled-kernels tf) expected)
(format nil "(Currently Failing ...) Compiled ~a kernels (expecting ~a)" (count-compiled-kernels tf) expected)))))
(format nil "Compiled ~a kernels (expecting ~a)" (count-compiled-kernels tf) expected)))))

(deftest transformer-schedule-cache-consistency-test
(with-protect-jit
Expand Down
16 changes: 16 additions & 0 deletions source/test-suite/test-scheduler.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,22 @@
(caten/codegen/expr-cache:with-expr-cache ()
(loop for item in (gather-kernels schedule)
for count in `(1) do
(let ((bp (caten/codegen/blueprint:lower-schedule-item item (avm-graph avm) schedule)))
(ok (= count (count :FOR bp :key #'node-type)) (format nil "Expected ~a loops, got ~a" count (count :FOR bp :key #'node-type))))))))
(testing "Softmax Batch=Tensor"
(multiple-value-bind (schedule avm) (schedule-with-vars (!softmax (make-tensor `(,(!add (iconst 'n) (iconst 's)) n s))))
(check-kernel schedule 1)
(caten/codegen/expr-cache:with-expr-cache ()
(loop for item in (gather-kernels schedule)
for count in `(4) do
(let ((bp (caten/codegen/blueprint:lower-schedule-item item (avm-graph avm) schedule)))
(ok (= count (count :FOR bp :key #'node-type)) (format nil "Expected ~a loops, got ~a" count (count :FOR bp :key #'node-type))))))))
(testing "Matmul Batch=Tensor"
(multiple-value-bind (schedule avm) (schedule-with-vars (!matmul (make-tensor `(,(!add (iconst 'a) (iconst 'b)) 512 256)) (make-tensor `(256 1024))))
(check-kernel schedule 1)
(caten/codegen/expr-cache:with-expr-cache ()
(loop for item in (gather-kernels schedule)
for count in `(3) do
(let ((bp (caten/codegen/blueprint:lower-schedule-item item (avm-graph avm) schedule)))
(ok (= count (count :FOR bp :key #'node-type)) (format nil "Expected ~a loops, got ~a" count (count :FOR bp :key #'node-type)))))))))

Expand Down

0 comments on commit 0cd518a

Please sign in to comment.