Skip to content

Commit

Permalink
Use MTLCodegen Directly (#398)
Browse files Browse the repository at this point in the history
  • Loading branch information
hikettei authored Jan 7, 2025
1 parent 6bb5b59 commit 1c098ee
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 27 deletions.
4 changes: 3 additions & 1 deletion source/byoc/caten.byoc.asd
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
:description "BYOC (Bring Your Own Codegen) implements an extension of caten/codegen targeting multiple devices."
:author "hikettei <ichndm@gmail.com>"
:license "MIT"
:depends-on ("caten.codegen" "caten.runtime" "cffi" "flexi-streams" "float-features")
:defsystem-depends-on ("cffi-grovel")
:depends-on ("caten.codegen" "caten.runtime" "cffi" "flexi-streams" "float-features" "babel" "cl-pack")
:components ((:file "lisp")
(:file "native")
(:file "clang")
(:cffi-wrapper-file "helpers/callback" :soname "callback_helper")
(:file "metal")
(:file "cuda")
(:file "llvm")
Expand Down
26 changes: 26 additions & 0 deletions source/byoc/helpers/callback.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
;; Note1: only macOS should need this
;; Note2: create a closure that calls the stored callback (ensure it is thread-safe!)
;; Note3: this thing can be realized with only using cffi?
(c
"#ifdef __APPLE__
#include <stdio.h>
typedef void (*callback_t)(void);
static callback_t stored_cb = NULL;
void* closure_cffi_callback(void (*lisp_callback)(void)) {
stored_cb = lisp_callback;
return (^{ stored_cb(); });
}
#else
#include <stdio.h>
void* closure_cffi_callback(void (*lisp_callback)(void)) {
fprintf(stderr, \"Error(caten/byoc/helper/callback.lisp): Apple Blocks extension not supported on this platform.\n\");
return NULL;
}
#endif
")
91 changes: 67 additions & 24 deletions source/byoc/metal.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@

(in-package :caten/byoc/metal)

(defconstant +request-type-compile+ 13)

(defun ensure-foreign-library ()
(load-foreign-library "/usr/lib/libobjc.dylib")
(load-foreign-library "/System/Library/Frameworks/Metal.framework/Metal")
(load-foreign-library "/System/Library/PrivateFrameworks/MTLCompiler.framework/MTLCompiler")
(load-foreign-library "/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics")
(load-foreign-library "/usr/lib/libSystem.dylib"))

Expand All @@ -24,28 +27,70 @@
(defmacro msg (ptr selector restype &rest args)
`(foreign-funcall "objc_msgSend" :pointer ,ptr :pointer (sel ,selector) ,@args ,restype))
(defun to-ns-str (str) (with-foreign-string (*str str) (msg (objc-getclass "NSString") "stringWithUTF8String:" :pointer :pointer *str)))
;; [TODO] Use MTLCompiler directly to gain signifcant improvement on the compilation time
(defun mtl-compile-source (source)
(flet ((run-cmd (cmd input)
(let* ((process-info (uiop:launch-program cmd :input :stream :output :stream :error-output :stream))
(error-output (uiop:process-info-error-output process-info))
(input-stream (uiop:process-info-input process-info)))
(unwind-protect
(if (stringp input)
(princ input input-stream)
(loop for i across input do (write-byte i input-stream)))
(close input-stream))
(unless (zerop (uiop:wait-process process-info))
(error "Caten[Metal]: Failed to create a metal library:~%~a~%
Compiled with this command: ~a"
(alexandria:read-stream-content-into-string error-output)
(with-output-to-string (out)
(dolist (c cmd) (princ c out) (princ " " out)))))
(alexandria:read-stream-content-into-byte-vector (uiop:process-info-output process-info)))))
(let* ((air (run-cmd '("xcrun" "-sdk" "macosx" "metal" "-x" "metal" "-c" "-" "-o" "-") source))
(lib (run-cmd '("xcrun" "-sdk" "macosx" "metallib" "-" "-o" "-") air)))
(assert (string= "MTLB" (flexi-streams:octets-to-string (subseq lib 0 4))) () "Invalid Metal library. Corrupt XCode?")
lib)))
;; ~~ MTLCompiler ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
(defcfun "MTLCodeGenServiceCreate" :pointer (service-name :string))
(defcfun "MTLCodeGenServiceBuildRequest" :void (cgs :pointer) (unused :pointer) (request-type :int) (request :pointer) (request-len :size) (callback :pointer))
(defcfun "closure_cffi_callback" :pointer (callback :pointer))

(defvar *callback-handler*)

(defcallback callback :void
((blockptr :pointer) (error :int32) (data :pointer) (datalen :size) (errormsg :pointer))
(declare (ignore blockptr))
(assert (eql :ready *callback-handler*) () "*call-back-handler* is not set to :ready.")
(case error
(0
;; offset from beginning to data = header size + warning size
(let* ((octets (loop for i upfrom 0 below datalen collect (mem-aref data :char i)))
(offsets (cl-pack:unpack "<LL" (with-output-to-string (out) (map 'list #'(lambda (x) (princ (code-char x) out)) (subseq octets 8 16))))))
(setf *callback-handler* (cons :succeed (subseq octets offsets)))))
(otherwise
(setf *callback-handler* (cons :failed (foreign-string-to-lisp errormsg)))))
nil)

(defun round-up (n multiple)
(multiple-value-bind (quotient remainder) (truncate n multiple)
(if (zerop remainder) n (* (1+ quotient) multiple))))

(defun make-request-form (src params)
(let* ((src-encoded (babel:string-to-octets src :encoding :utf-8))
(src-padded-len (round-up (1+ (length src-encoded)) 4))
(src-padding-len (- src-padded-len (length src-encoded)))
(src-padded (concatenate
'(vector (unsigned-byte 8))
src-encoded (make-array src-padding-len :element-type '(unsigned-byte 8) :initial-element 0)))
(params-encoded (babel:string-to-octets params :encoding :utf-8))
(params-padded (concatenate '(vector (unsigned-byte 8)) params-encoded (make-array 1 :element-type '(unsigned-byte 8) :initial-element 0)))
(header (cl-pack:pack "<QQ" (length src-padded) (length params-padded)))
(request (concatenate 'string header (babel:octets-to-string src-padded) (babel:octets-to-string params-padded))))
request))

(defun mtl-compile-source (source
&key
(fmodules-cache-path
#+darwin(progn "~/Library/Caches")
#-darwin(progn "~/.cache"))
&aux
(service (MTLCodeGenServiceCreate "caten"))
(params (format nil "-fno-fast-math -std=metal3.1 --driver-mode=metal -x metal -fmodules-cache-path=~a -fno-caret-diagnostics" fmodules-cache-path))
(*callback-handler* :ready))
(declare (type foreign-pointer service) (type string source params))
(assert (<= (ctx:getenv :PARALLEL) 1) () "METAL does not support parallel compilation.")
(let ((request (make-request-form source params)))
(with-foreign-string (*request request)
(MTLCodeGenServiceBuildRequest
service (null-pointer) +request-type-compile+
*request (length request) (closure-cffi-callback (get-callback 'callback))))
(assert (consp *callback-handler*) () "*callback-handler* did not receive anything!")
(case (car *callback-handler*)
(:succeed
(let* ((len (length (cdr *callback-handler*)))
(octets (make-array len :element-type '(signed-byte 8) :initial-contents (cdr *callback-handler*))))
(assert (string= "MTLB" (flexi-streams:octets-to-string (subseq octets 0 4))) () "Invalid Metal library. Corrupt XCode?")
(assert (string= "ENDT" (flexi-streams:octets-to-string (subseq octets (- len 4)))) () "Invalid Metal library. Corrupt XCode?")
octets))
(:failed
(error "Failed to compile a metallib:~%~a~%Compiled with this command: ~a" (cdr *callback-handler*) params)))))
;; ~~ Extension ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
(defclass MetalBuffer (AbstractBuffer) nil)
(defclass MetalRuntime (GraphRuntime) ((device :accessor metal-runtime-device)))
Expand Down Expand Up @@ -93,9 +138,7 @@ Compiled with this command: ~a"
(mem-aref val (caten/codegen/helpers:->cffi-dtype (buffer-dtype buffer)) idx)))

(defclass Metal-Renderer (CStyle-Renderer) ((device :accessor metal-renderer-device)))

(define-auto-scheduler (Metal-Auto-Scheduler ()) :n-global-loop 3)

(define-backend :metal MetalBuffer MetalRuntime Metal-Renderer Metal-Auto-Scheduler t)
;; ~~~ Renderers ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
(defun dtype->mtype (dtype)
Expand Down
2 changes: 1 addition & 1 deletion source/byoc/native.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
(when (getattr item :rendered-object)
(format t "~a"
(with-output-to-string (tmp)
(format tmp "~%[Blueprint: ~A]:~%~A~%Disassembly for ~a:~%```~%" (getattr item :name) (getattr item :rendered-object) (getattr item :name))
;; (format tmp "~%[Blueprint: ~A]:~%~A~%Disassembly for ~a:~%```~%" (getattr item :name) (getattr item :rendered-object) (getattr item :name))
(disassemble (compile nil (getattr item :rendered-object)) :stream tmp)
(format tmp "~%```~%"))))))
(dolist (item items)
Expand Down
2 changes: 1 addition & 1 deletion source/codegen/exprify.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ Return: (values new-for-replacements if-statement[optional])"
idx

))

;; todo
(defun test ()
(mutate-for-as-space
(make-node :RENDER :FOR nil nil
Expand Down
3 changes: 3 additions & 0 deletions source/codegen/jit.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,9 @@ Applies the JIT compilation for the given Runtime. backend is a keyword defined
(when (null is-jit)
(setf (runtime-buffer-type runtime) buffer-class)
(return-from jit runtime))
;; METAL does not parallel compilation
(when (eql backend :metal)
(assert (<= (ctx:getenv :PARALLEL) 1) () "METAL does not support parallel compilation. Set PARALLEL=0"))
(caten/isl:with-isl-context ;; Note: Need this to ensure isl objected allocated here are not cached and not used by other compiling sessions.
(when (= 2 (ctx:getenv :DOT)) (->dot (runtime-graph runtime) :title "Base Graph"))
(run-type-infer runtime)
Expand Down

0 comments on commit 1c098ee

Please sign in to comment.