Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Prerequisite for AutoScheduler #362

Merged
merged 57 commits into from
Dec 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
6782d42
gemm
hikettei Dec 20, 2024
44a490f
Tile
hikettei Dec 20, 2024
610316a
search tile
hikettei Dec 20, 2024
0cd518a
Merge branch 'main' into opt-gemm
hikettei Dec 20, 2024
f6bcc57
OpenBLAS
hikettei Dec 20, 2024
253cf6f
90GFLops
hikettei Dec 20, 2024
d4c08c4
wip
hikettei Dec 20, 2024
db52790
wip
hikettei Dec 20, 2024
082ae9d
wip
hikettei Dec 20, 2024
9446987
tile param search
hikettei Dec 20, 2024
6a92698
finish an experiment
hikettei Dec 20, 2024
0e435c1
apply tiling without causing segv
hikettei Dec 20, 2024
e32cec2
tiling all bands
hikettei Dec 20, 2024
c4215bc
things are moved
hikettei Dec 20, 2024
4c2de70
move everything into auto-scheduler
hikettei Dec 21, 2024
785d2eb
asd
hikettei Dec 21, 2024
34656d2
fix: tiling-size related gc error
hikettei Dec 21, 2024
6a0c211
An initial attempt to unroll: tile based
hikettei Dec 21, 2024
0ab186d
moved
hikettei Dec 21, 2024
2ce3075
split
hikettei Dec 21, 2024
b22d20f
.
hikettei Dec 21, 2024
874cd17
moved isl things to auto-scheduler
hikettei Dec 21, 2024
b47fd09
refactor
hikettei Dec 21, 2024
e92725d
synchronize baseline
hikettei Dec 21, 2024
e1e5dc7
updt
hikettei Dec 21, 2024
dfca38d
updt
hikettei Dec 21, 2024
0ed2740
typo
hikettei Dec 21, 2024
0adb742
typo
hikettei Dec 21, 2024
378c310
typo
hikettei Dec 21, 2024
d2de5a1
Tweak
hikettei Dec 21, 2024
cd01fe9
nested unroll directive
hikettei Dec 21, 2024
64b8c7c
Unroll for static and simple case
hikettei Dec 21, 2024
3ac9213
Fold MOD
hikettei Dec 21, 2024
0d4095f
nope
hikettei Dec 21, 2024
642acc9
Initial attempt of unrolling
hikettei Dec 21, 2024
ea829fd
copy and remove UNROLL_PARENT for the reminder
hikettei Dec 21, 2024
b7ba19f
Loop Unrolling is worked
hikettei Dec 21, 2024
6df4461
Unrolling softmax
hikettei Dec 21, 2024
994a615
wip: generating a sketch
hikettei Dec 21, 2024
ca032f8
Identify the sketch
hikettei Dec 21, 2024
4cdec95
TODO
hikettei Dec 21, 2024
1c5096e
remove extra unroll reminder
hikettei Dec 21, 2024
bfb2d11
coincident for elwise
hikettei Dec 22, 2024
771fa79
fix: unroll
hikettei Dec 22, 2024
aad075a
quickload
hikettei Dec 22, 2024
1b56342
MemoryPlanner should produce the same result
hikettei Dec 22, 2024
e77703a
Fix for unroll reminder
hikettei Dec 22, 2024
0e9fff8
try unrolling tiled bands
hikettei Dec 22, 2024
6674d7b
ignore tiled band dims
hikettei Dec 22, 2024
ce60a52
base-item
hikettei Dec 22, 2024
5ba4fa3
remove: caten/polyhedral
hikettei Dec 22, 2024
26a4a66
clean up
hikettei Dec 22, 2024
e63ac07
rem dep
hikettei Dec 22, 2024
546d0c4
caten instead
hikettei Dec 22, 2024
f73c49f
updt
hikettei Dec 22, 2024
04cb131
updt
hikettei Dec 22, 2024
c832f5e
updt
hikettei Dec 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions external/backends/metal.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
(:use :cl :caten/air :cffi :caten/codegen/renderer :caten/codegen/helpers
:caten/codegen/shape-inference :caten/avm :caten/codegen/expr :cl-metal)
(:import-from
:caten/polyhedral
#:define-auto-scheduler
#:make-schedule-options))
:caten/codegen/config
#:define-auto-scheduler))

(in-package :caten/metal)

Expand All @@ -17,7 +16,6 @@

(define-auto-scheduler
(Metal-Auto-Scheduler ())
:cost-functions '(:validity :proximity :coincidence)
:n-global-loop 3)
(define-hook-auto-scheduler (Metal-Renderer Metal-Auto-Scheduler))
(defmethod initialize-instance :after ((metal Metal-Renderer) &rest initargs &key &allow-other-keys)
Expand Down
1 change: 0 additions & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ nav:
- caten/air: packages/caten.air.md
- caten/aasm: packages/caten.aasm.md
- caten/codegen: packages/caten.codegen.md
- caten/polyhedral: packages/caten.polyhedral.md
- caten/apis:
- Overview: packages/caten.apis.md
- Tensor: packages/caten.apis.tensor.md
Expand Down
347 changes: 347 additions & 0 deletions scripts/gemm/gemm.c

Large diffs are not rendered by default.

50 changes: 50 additions & 0 deletions scripts/gemm/openblas.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// gcc-14 ./scripts/gemm/openblas.c -O3 -I/opt/homebrew/opt/openblas/include -L/opt/homebrew/opt/openblas/lib -lopenblas
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <cblas.h>

/*
 * Benchmarks cblas_sgemm for M = 1..99 with N = 768 and K = 1024 fixed,
 * printing one throughput line (in GFLOPS) per value of M.
 *
 * All matrices are row-major: A is M x N, B is N x K, C is M x K.
 * Each timed call computes C = 1.0 * A * B + 1.0 * C, so C keeps
 * accumulating across the n_sample repetitions; only the elapsed time is
 * used, never the numeric result.
 *
 * Returns 0 on success, EXIT_FAILURE if a buffer cannot be allocated.
 */
int main(void) {
  const int N = 768;
  const int K = 1024;
  const int n_sample = 100;

  for (int M = 1; M < 100; M++) {
    /* Widen before multiplying so element counts are computed in size_t. */
    const size_t a_len = (size_t)M * (size_t)N;
    const size_t b_len = (size_t)N * (size_t)K;
    const size_t c_len = (size_t)M * (size_t)K;

    float *A = malloc(a_len * sizeof *A);
    float *B = malloc(b_len * sizeof *B);
    float *C = malloc(c_len * sizeof *C);
    if (A == NULL || B == NULL || C == NULL) {
      fprintf(stderr, "M=%d | failed to allocate benchmark buffers\n", M);
      free(A);
      free(B);
      free(C);
      return EXIT_FAILURE;
    }

    for (size_t i = 0; i < a_len; i++) {
      A[i] = 1.0f;
    }
    for (size_t i = 0; i < b_len; i++) {
      B[i] = 1.0f;
    }
    for (size_t i = 0; i < c_len; i++) {
      C[i] = 0.0f;
    }

    struct timespec start, end;
    clock_gettime(CLOCK_MONOTONIC, &start);
    for (int s = 0; s < n_sample; s++) {
      /* Leading dimensions are the row lengths: N for A, K for B and C. */
      cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                  M, K, N,
                  1.0f,
                  A, N,
                  B, K,
                  1.0f,
                  C, K);
    }
    clock_gettime(CLOCK_MONOTONIC, &end);

    const double elapsed = (end.tv_sec - start.tv_sec)
                         + (end.tv_nsec - start.tv_nsec) / 1000000000.0;
    /* One multiply and one add per (m, n, k) triple, per sample. */
    const double ops = 2.0 * M * N * K * n_sample;
    const double gflops = ops / (elapsed * 1e9);
    /* The value printed is a rate, not a duration, so label it as such. */
    printf("M=%d | Throughput (for %d samples): %f GFLOPS\n", M, n_sample, gflops);

    free(A);
    free(B);
    free(C);
  }
  return 0;
}
108 changes: 108 additions & 0 deletions scripts/gemm/search_tile.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Build: gcc-14 ./scripts/gemm/search_tile.c -O3 -ffast-math -fopenmp -fopenmp-simd -march=native -fstrict-aliasing -ftree-vectorize
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <omp.h>
#include <arm_neon.h>
// Optimized by AutoScheduler
/*
 * Tiled single-precision matrix multiply-accumulate: C += A * B.
 *
 * All matrices are row-major:
 *   A  M x N, read-only
 *   B  N x K, read-only
 *   C  M x K, updated in place (the caller must initialise it)
 *   NB tile length along N, the reduction dimension
 *   KB tile length along K, the columns of B and C
 *
 * A tile length that is non-positive or larger than its dimension is
 * replaced by the full dimension, i.e. that axis is left untiled. Without
 * this a zero tile length would make the tiling loops spin forever.
 *
 * Work is distributed over (row, column-tile) pairs; each pair writes a
 * disjoint slice of C, so no synchronisation is needed between iterations.
 */
void gemm(int M, int N, int K,
          const float * __restrict A,
          const float * __restrict B,
          float * __restrict C, int NB, int KB)
{
  /* Columns handled per vector step: four 4-lane float registers. */
  enum { VEC_COLS = 16 };

  if (M <= 0 || N <= 0 || K <= 0) {
    return;
  }
  if (NB <= 0 || NB > N) {
    NB = N;
  }
  if (KB <= 0 || KB > K) {
    KB = K;
  }

  #pragma omp parallel for collapse(2)
  for (int i = 0; i < M; i++) {
    for (int j0 = 0; j0 < K; j0 += KB) {
      /* Row bases are derived here, not between the two loops, because
         collapse(2) requires the loops to be perfectly nested. */
      const float *a_row = A + (size_t)i * (size_t)N;
      float *c_row = C + (size_t)i * (size_t)K;

      const int jMax = (j0 + KB < K) ? (j0 + KB) : K;
      const int kBlocks = (jMax - j0) / VEC_COLS;
      const int jTail = j0 + kBlocks * VEC_COLS;

      for (int n0 = 0; n0 < N; n0 += NB) {
        const int nMax = (n0 + NB < N) ? (n0 + NB) : N;

        /* Vector part: 16 output columns at a time. */
        for (int kb = 0; kb < kBlocks; kb++) {
          const int jBase = j0 + kb * VEC_COLS;
          float32x4_t acc0 = vdupq_n_f32(0.0f);
          float32x4_t acc1 = vdupq_n_f32(0.0f);
          float32x4_t acc2 = vdupq_n_f32(0.0f);
          float32x4_t acc3 = vdupq_n_f32(0.0f);

          for (int n = n0; n < nMax; n++) {
            const float32x4_t a_vec = vdupq_n_f32(a_row[n]);
            const float *b_ptr = B + (size_t)n * (size_t)K + jBase;
            acc0 = vmlaq_f32(acc0, a_vec, vld1q_f32(b_ptr + 0));
            acc1 = vmlaq_f32(acc1, a_vec, vld1q_f32(b_ptr + 4));
            acc2 = vmlaq_f32(acc2, a_vec, vld1q_f32(b_ptr + 8));
            acc3 = vmlaq_f32(acc3, a_vec, vld1q_f32(b_ptr + 12));
          }

          /* Fold the partial sums of this N-tile into C. */
          float *c_ptr = c_row + jBase;
          vst1q_f32(c_ptr + 0, vaddq_f32(vld1q_f32(c_ptr + 0), acc0));
          vst1q_f32(c_ptr + 4, vaddq_f32(vld1q_f32(c_ptr + 4), acc1));
          vst1q_f32(c_ptr + 8, vaddq_f32(vld1q_f32(c_ptr + 8), acc2));
          vst1q_f32(c_ptr + 12, vaddq_f32(vld1q_f32(c_ptr + 12), acc3));
        }

        /* Remainder: columns of this tile that do not fill a 16-wide block. */
        for (int j = jTail; j < jMax; j++) {
          float sum = 0.0f;
          for (int n = n0; n < nMax; n++) {
            sum += a_row[n] * B[(size_t)n * (size_t)K + j];
          }
          c_row[j] += sum;
        }
      }
    }
  }
}

/*
 * Brute-force search over gemm() tile sizes for a fixed 10 x 768 x 1024
 * problem. Every (NB, KB) pair in [1, 32] x [1, 32] is timed over n_sample
 * calls and the pair with the highest throughput is reported.
 *
 * NB is gemm's tile length along N (the reduction dimension) and KB its
 * tile length along K (the output columns); both the per-pair lines and
 * the final summary print them in (NB, KB) order.
 *
 * Returns 0 on success, EXIT_FAILURE if a buffer cannot be allocated.
 */
int main(void) {
  const int M = 10;
  const int N = 768;
  const int K = 1024;
  const int n_sample = 10;

  printf("OMP_GET_MAX_THREADS=%d\n", omp_get_max_threads());

  const size_t a_len = (size_t)M * (size_t)N;
  const size_t b_len = (size_t)N * (size_t)K;
  const size_t c_len = (size_t)M * (size_t)K;

  /* A and B never change, so allocate and fill them once for all runs. */
  float *A = malloc(a_len * sizeof *A);
  float *B = malloc(b_len * sizeof *B);
  float *C = malloc(c_len * sizeof *C);
  if (A == NULL || B == NULL || C == NULL) {
    fprintf(stderr, "failed to allocate benchmark buffers\n");
    free(A);
    free(B);
    free(C);
    return EXIT_FAILURE;
  }
  for (size_t i = 0; i < a_len; i++) {
    A[i] = 1.0f;
  }
  for (size_t i = 0; i < b_len; i++) {
    B[i] = 1.0f;
  }

  double best_gflops = 0.0;
  int best_nb = 0;
  int best_kb = 0;

  for (int nb = 1; nb < 33; nb++) {
    for (int kb = 1; kb < 33; kb++) {
      /* gemm() accumulates into C, so it must start from defined values. */
      for (size_t i = 0; i < c_len; i++) {
        C[i] = 0.0f;
      }

      struct timespec start, end;
      clock_gettime(CLOCK_MONOTONIC, &start);
      for (int s = 0; s < n_sample; s++) {
        gemm(M, N, K, A, B, C, nb, kb);
      }
      clock_gettime(CLOCK_MONOTONIC, &end);

      const double elapsed = (end.tv_sec - start.tv_sec)
                           + (end.tv_nsec - start.tv_nsec) / 1000000000.0;
      /* One multiply and one add per (m, n, k) triple, per sample. */
      const double ops = 2.0 * M * N * K * n_sample;
      const double gflops = ops / (elapsed * 1e9);

      if (gflops > best_gflops) {
        best_gflops = gflops;
        best_nb = nb;
        best_kb = kb;
      }
      printf("(%d, %d) | Throughput (for %d samples): %f GFLOPS\n", nb, kb, n_sample, gflops);
    }
  }

  printf("Max GFLOPS: %f (NB, KB)=(%d, %d)\n", best_gflops, best_nb, best_kb);

  free(A);
  free(B);
  free(C);
  return 0;
}
9 changes: 6 additions & 3 deletions source/aasm/attrs.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@

(defclass JITAble ()
((_type_relay :initarg :_type_relay)
(_read_views :initform nil :initarg :_read_views)
(_output_type :initform nil :initarg :_output_type)
(_read_views :initform nil :initarg :_read_views) ;; [TODO] Removable
(_output_type :initform nil :initarg :_output_type) ;; [TODO] Removable
(declare-type :initarg :declare-type :initform nil)
(iterations :initarg :iterations :initform nil)
(_lowering_history :initform nil :initarg :_lowering_history))
(_lowering_history :initform nil :initarg :_lowering_history)
;; Metadata for Vectorize
(parent-node-id :initform nil :initarg :parent-node-id)
(unroll-history :initform nil :initarg :unroll-history))
(:documentation "This node is jitable.
- declare-type[boolean] When this option is set to T, it is necessary to declare the types of the variables included in. e.g.:
```
Expand Down
1 change: 1 addition & 0 deletions source/aasm/constant-folding.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
;; [TODO] Logical AND/XOR/OR for threefry2x32
(defsimplifier
(apply-fold-constant :speed 1)
((:Mod ((Const x dtype) (Const y _))) -> (Const (mod x y) dtype))
((:Add ((Const x dtype) (Const y _))) -> (Const (+ x y) dtype))
((:Mul ((Const x dtype) (Const y _))) -> (Const (* x y) dtype))
((:Mul ((Const x dtype) (:Recip ((Const y _))))) -> (Const (/ x y) dtype))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,35 +1,16 @@
(defpackage :caten/codegen/polyhedral-ast
(:documentation "ISL Polyhedral IR ==> Caten Blueprint IR")
(:use :cl :caten/codegen/expr :caten/codegen/expr-cache :caten/air :caten/codegen/shape-inference :trivia)
(defpackage :caten/codegen/ast-parser
(:documentation "
Transform the Polyhedral IR into the Blueprint IR.
```
[ISL Polyhedral IR] ==> <Polyhedral-AST> ===> Caten Blueprint IR
```
scop.lisp for the opposite things.
")
(:use :cl :caten/codegen/expr :caten/codegen/expr-cache :caten/air :caten/codegen/shape-inference :trivia :caten/codegen/polyhedral-ast)
(:import-from :caten/codegen/unroll :mark-unroll-parent-p :mark-unroll-body-p :parse-unroll-directive)
(:export #:lower-into-bp-from-polyhedral))

(in-package :caten/codegen/polyhedral-ast)
;; ~~ ISL AST <-> Lisp Intermidate Object ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

(eval-when (:execute :compile-toplevel :load-toplevel)
(defstruct (ASTBlock
(:constructor make-block (body)))
(body body :type list))

(defstruct (User
(:constructor make-user (name args)))
"T_name(index)"
(name name :type string) (args args :type list))

(defstruct (ASTFor
(:constructor make-for (idx from to by body)))
(idx idx :type string)
(from from :type Expr)
(to to :type Expr)
(by by :type Expr)
(body body :type (or ASTBlock User ASTFor ASTIF))
(scope :local :type (member :local :global)))

(defstruct (AstIf
(:constructor make-if (condition then-node else-node)))
(condition condition :type Expr)
(then-node then-node :type (or ASTBlock User ASTFOR ASTIF))
(else-node else-node :type (or ASTBlock User ASTFOR ASTIF null))))
(in-package :caten/codegen/ast-parser)

(declaim (ftype (function (cffi:foreign-pointer) t) parse-isl-ast))
(defun parse-isl-ast (ast)
Expand Down Expand Up @@ -73,10 +54,33 @@
(user (parse-isl-ast (isl::%isl-ast-node-mark-get-node ast))))
(typecase user
(AstFor
(when (string= mark "parallel")
(setf (astfor-scope user) :global)))
(cond
((equalp mark "parallel")
(setf (astfor-scope user) :global))
((mark-unroll-parent-p mark)
(let ((body (astfor-body user)))
(when (or
(not (typep body 'ASTFor))
(null (and (astfor-marks body) (every #'mark-unroll-body-p (astfor-marks body)))))
(return-from parse-isl-ast-mark user))
(let* ((n-unroll (parse-unroll-directive mark))
(user (copy-astfor user))
(unrolled (caten/codegen/directive:make-unrolled-body user body n-unroll))
(reminder (caten/codegen/directive:compute-reminder-for-unroll user body n-unroll)))
(setf (astfor-body user) unrolled)
(return-from parse-isl-ast-mark (make-block (list user reminder))))))
((mark-unroll-body-p mark)
;; UNROLL_BODY is triggered by the UNROLL_PARENT. Without it the form is ignored.
(assert (null (astfor-marks user)) () "UNROLL_BODY should be orthogonal with other directives.")
(setf (astfor-marks user) (list mark)))
((equalp mark "TILE_BAND")
(push "TILE_BAND" (astfor-marks user)))
(T
;(warn "mark: ignored the mark ~a for ~a" mark user)
)))
(otherwise
(warn "mark: ignored the mark ~a for ~a" mark user)))
;(warn "mark: ignored the mark ~a for ~a" mark user)
))
user))

(declaim (ftype (function ((or cffi:foreign-pointer isl:ast-node)) (values Expr &optional)) parse-isl-expr))
Expand Down Expand Up @@ -184,14 +188,15 @@
(defun r/endif ()
  "Returns a fresh node built by (make-node :Render :ENDIF nil nil).
Takes no arguments. NOTE(review): presumably the closing counterpart of the
node produced by r/if; confirm against the renderer."
  (make-node :Render
             :ENDIF
             nil
             nil))

(defun create-rendering-graph-nodes (lisp-ast items)
(defun create-rendering-graph-nodes (lisp-ast items &aux (space))
(let ((new-graph))
(labels ((find-user (node-id args)
(let ((node (find (princ-to-string node-id) items
:key (alexandria:compose #'princ-to-string #'node-id)
:test #'equalp)))
(assert node () "~a is not found in ~a" node-id items)
(assert (eql (node-type node) :EXPR))
(setf node (copy-node node))
(let ((base (getattr node :iterations)))
(setf (getattr node :iterations) args)
(if (and (null args) (> (length base) 0))
Expand All @@ -205,11 +210,16 @@
(ematch object
((ASTBlock :body body) (map 'list #'lower body))
((AstFor :idx idx :from upfrom :to to :by by :body body :scope scope)
(push (r/for idx upfrom to by scope) new-graph)
(lower body)
(push (r/endfor idx) new-graph))
;; remove an empty loop
(let ((is-tile-band (find "TILE_BAND" (astfor-marks object) :test #'equalp)))
(when (not (expr-scalar-equivalent-p upfrom (expr-detach-loop-bound to)))
(when (null is-tile-band) (push idx space))
(push (r/for idx upfrom to by scope) new-graph)
(lower body)
(when (null is-tile-band) (setf space (remove idx space :test #'string=)))
(push (r/endfor idx) new-graph))))
((User :name name :args args)
(push (find-user name args) new-graph))
(push (caten/codegen/directive:unroll-expr (reverse space) (find-user name args) object) new-graph))
((AstIf :condition cond :then-node then :else-node else)
(push (r/if cond) new-graph)
(lower then)
Expand Down
Loading
Loading