|
133 | 133 |
|
134 | 134 | (defmacro def-lapack-svd (class type svd-function &optional real-type)
|
135 | 135 | `(progn
|
136 |
| - (defmethod svd ((m ,class)) |
137 |
| - (lapack-svd m)) |
| 136 | + (defmethod svd ((m ,class) &key reduced) |
| 137 | + (lapack-svd m :reduced reduced)) |
138 | 138 |
|
139 |
| - (defmethod lapack-svd ((m ,class)) |
140 |
| - "Find the SVD of a matrix M. Return (VALUES U SIGMA Vt) where M = U*SIGMA*Vt" |
141 |
| - (let ((jobu "A") |
142 |
| - (jobvt "A") |
143 |
| - (rows (nrows m)) |
144 |
| - (cols (ncols m)) |
145 |
| - (a (alexandria:copy-array (storage (if (eql :row-major (order m)) (transpose m) m)))) |
146 |
| - (lwork -1) |
147 |
| - (info 0)) |
148 |
| - (let ((lda rows) |
149 |
| - (s (make-array (min rows cols) :element-type ',(or real-type type))) |
150 |
| - (ldu rows) |
151 |
| - (ldvt cols) |
152 |
| - (work1 (make-array (max 1 lwork) :element-type ',type)) |
153 |
| - (work nil) |
154 |
| - ,@(when real-type |
155 |
| - `((rwork (make-array (* 5 (min rows cols)) :element-type ',real-type))))) |
156 |
| - (let ((u (make-array (* ldu rows) :element-type ',type)) |
157 |
| - (vt (make-array (* ldvt cols) :element-type ',type))) |
158 |
| - ;; run it once as a workspace query |
159 |
| - (,svd-function jobu jobvt rows cols a lda s u ldu vt ldvt |
160 |
| - work1 lwork ,@(when real-type `(rwork)) info) |
161 |
| - (setf lwork (round (realpart (aref work1 0)))) |
162 |
| - (setf work (make-array (max 1 lwork) :element-type ',type)) |
163 |
| - ;; run it again with optimal workspace size |
164 |
| - (,svd-function jobu jobvt rows cols a lda s u ldu vt ldvt |
165 |
| - work lwork ,@(when real-type `(rwork)) info) |
166 |
| - (let ((smat (make-array (* rows cols) :element-type ',(or real-type type)))) |
167 |
| - (dotimes (i (min rows cols)) |
168 |
| - (setf (aref smat (column-major-index (list i i) (shape m))) |
169 |
| - (aref s i))) |
170 |
| - (values (from-array u (list rows rows) :order :column-major) |
171 |
| - (from-array smat (list rows cols) :order :column-major) |
172 |
| - (from-array vt (list cols cols) :order :column-major))))))))) |
| 139 | + (defmethod lapack-svd ((m ,class) &key reduced) |
| 140 | + "Find the SVD of a matrix M. Return (VALUES U SIGMA Vt) where M = U*SIGMA*Vt. If REDUCED is non-NIL, return the reduced SVD (where either U or V are just partial isometries and not necessarily unitary matrices)." |
| 141 | + (let* ((jobu (if reduced "S" "A")) |
| 142 | + (jobvt (if reduced "S" "A")) |
| 143 | + (rows (nrows m)) |
| 144 | + (cols (ncols m)) |
| 145 | + (a (alexandria:copy-array (storage (if (eql :row-major (order m)) (transpose m) m)))) |
| 146 | + (lwork -1) |
| 147 | + (info 0) |
| 148 | + (k (min rows cols)) |
| 149 | + (u-cols (if reduced k rows)) |
| 150 | + (vt-rows (if reduced k cols)) |
| 151 | + (lda rows) |
| 152 | + (s (make-array (min rows cols) :element-type ',(or real-type type))) |
| 153 | + (ldu rows) |
| 154 | + (ldvt vt-rows) |
| 155 | + (work1 (make-array (max 1 lwork) :element-type ',type)) |
| 156 | + (work nil) |
| 157 | + ,@(when real-type |
| 158 | + `((rwork (make-array (* 5 (min rows cols)) :element-type ',real-type))))) |
| 159 | + (let ((u (make-array (* ldu rows) :element-type ',type)) |
| 160 | + (vt (make-array (* ldvt cols) :element-type ',type))) |
| 161 | + ;; run it once as a workspace query |
| 162 | + (,svd-function jobu jobvt rows cols a lda s u ldu vt ldvt |
| 163 | + work1 lwork ,@(when real-type `(rwork)) info) |
| 164 | + (setf lwork (round (realpart (aref work1 0)))) |
| 165 | + (setf work (make-array (max 1 lwork) :element-type ',type)) |
| 166 | + ;; run it again with optimal workspace size |
| 167 | + (,svd-function jobu jobvt rows cols a lda s u ldu vt ldvt |
| 168 | + work lwork ,@(when real-type `(rwork)) info) |
| 169 | + (let ((smat (make-array (* u-cols vt-rows) :element-type ',(or real-type type)))) |
| 170 | + (dotimes (i k) |
| 171 | + (setf (aref smat (column-major-index (list i i) (list u-cols vt-rows))) |
| 172 | + (aref s i))) |
| 173 | + (values (from-array u (list rows u-cols) :order :column-major) |
| 174 | + (from-array smat (list u-cols vt-rows) :order :column-major) |
| 175 | + (from-array vt (list vt-rows cols) :order :column-major)))))))) |
173 | 176 |
|
174 | 177 | ;; TODO: This returns only the real parts when with non-complex numbers. Should do something different?
|
175 | 178 | (defmacro def-lapack-eig (class type eig-function &optional real-type)
|
|
0 commit comments