Skip to content

Commit 2a200c3

Browse files
author
Cole Scott
committed
Add reduced svd interface and tests (#65)
1 parent 7918824 commit 2a200c3

File tree

4 files changed

+102
-40
lines changed

4 files changed

+102
-40
lines changed

src/high-level/lapack-generics.lisp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
(defgeneric lapack-csd (matrix p q))
1212

13-
(defgeneric lapack-svd (matrix))
13+
(defgeneric lapack-svd (matrix &key reduced))
1414

1515
(defgeneric lapack-ql (matrix)
1616
(:documentation "Find the LAPACK intermediate representation of ql of a matrix"))

src/high-level/lapack-macros.lisp

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -133,43 +133,46 @@
133133

134134
(defmacro def-lapack-svd (class type svd-function &optional real-type)
135135
`(progn
136-
(defmethod svd ((m ,class))
137-
(lapack-svd m))
136+
(defmethod svd ((m ,class) &key reduced)
137+
(lapack-svd m :reduced reduced))
138138

139-
(defmethod lapack-svd ((m ,class))
140-
"Find the SVD of a matrix M. Return (VALUES U SIGMA Vt) where M = U*SIGMA*Vt"
141-
(let ((jobu "A")
142-
(jobvt "A")
143-
(rows (nrows m))
144-
(cols (ncols m))
145-
(a (alexandria:copy-array (storage (if (eql :row-major (order m)) (transpose m) m))))
146-
(lwork -1)
147-
(info 0))
148-
(let ((lda rows)
149-
(s (make-array (min rows cols) :element-type ',(or real-type type)))
150-
(ldu rows)
151-
(ldvt cols)
152-
(work1 (make-array (max 1 lwork) :element-type ',type))
153-
(work nil)
154-
,@(when real-type
155-
`((rwork (make-array (* 5 (min rows cols)) :element-type ',real-type)))))
156-
(let ((u (make-array (* ldu rows) :element-type ',type))
157-
(vt (make-array (* ldvt cols) :element-type ',type)))
158-
;; run it once as a workspace query
159-
(,svd-function jobu jobvt rows cols a lda s u ldu vt ldvt
160-
work1 lwork ,@(when real-type `(rwork)) info)
161-
(setf lwork (round (realpart (aref work1 0))))
162-
(setf work (make-array (max 1 lwork) :element-type ',type))
163-
;; run it again with optimal workspace size
164-
(,svd-function jobu jobvt rows cols a lda s u ldu vt ldvt
165-
work lwork ,@(when real-type `(rwork)) info)
166-
(let ((smat (make-array (* rows cols) :element-type ',(or real-type type))))
167-
(dotimes (i (min rows cols))
168-
(setf (aref smat (column-major-index (list i i) (shape m)))
169-
(aref s i)))
170-
(values (from-array u (list rows rows) :order :column-major)
171-
(from-array smat (list rows cols) :order :column-major)
172-
(from-array vt (list cols cols) :order :column-major)))))))))
139+
(defmethod lapack-svd ((m ,class) &key reduced)
140+
"Find the SVD of a matrix M. Return (VALUES U SIGMA Vt) where M = U*SIGMA*Vt. If REDUCED is non-NIL, return the reduced SVD (where either U or V are just partial isometries and not necessarily unitary matrices)."
141+
(let* ((jobu (if reduced "S" "A"))
142+
(jobvt (if reduced "S" "A"))
143+
(rows (nrows m))
144+
(cols (ncols m))
145+
(a (alexandria:copy-array (storage (if (eql :row-major (order m)) (transpose m) m))))
146+
(lwork -1)
147+
(info 0)
148+
(k (min rows cols))
149+
(u-cols (if reduced k rows))
150+
(vt-rows (if reduced k cols))
151+
(lda rows)
152+
(s (make-array (min rows cols) :element-type ',(or real-type type)))
153+
(ldu rows)
154+
(ldvt vt-rows)
155+
(work1 (make-array (max 1 lwork) :element-type ',type))
156+
(work nil)
157+
,@(when real-type
158+
`((rwork (make-array (* 5 (min rows cols)) :element-type ',real-type)))))
159+
(let ((u (make-array (* ldu rows) :element-type ',type))
160+
(vt (make-array (* ldvt cols) :element-type ',type)))
161+
;; run it once as a workspace query
162+
(,svd-function jobu jobvt rows cols a lda s u ldu vt ldvt
163+
work1 lwork ,@(when real-type `(rwork)) info)
164+
(setf lwork (round (realpart (aref work1 0))))
165+
(setf work (make-array (max 1 lwork) :element-type ',type))
166+
;; run it again with optimal workspace size
167+
(,svd-function jobu jobvt rows cols a lda s u ldu vt ldvt
168+
work lwork ,@(when real-type `(rwork)) info)
169+
(let ((smat (make-array (* u-cols vt-rows) :element-type ',(or real-type type))))
170+
(dotimes (i k)
171+
(setf (aref smat (column-major-index (list i i) (list u-cols vt-rows)))
172+
(aref s i)))
173+
(values (from-array u (list rows u-cols) :order :column-major)
174+
(from-array smat (list u-cols vt-rows) :order :column-major)
175+
(from-array vt (list vt-rows cols) :order :column-major))))))))
173176

174177
;; TODO: This returns only the real parts when with non-complex numbers. Should do something different?
175178
(defmacro def-lapack-eig (class type eig-function &optional real-type)

src/high-level/matrix.lisp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -523,10 +523,10 @@ If fast is t then just change order. Fast can cause problems when you want to mu
523523
(declare (ignore matrix))
524524
(error "INVERSE is not defined for the generic matrix type.")))
525525

526-
(defgeneric svd (matrix)
526+
(defgeneric svd (matrix &key reduced)
527527
(:documentation "Find the SVD of a matrix M. Return (VALUES U SIGMA Vt) where M = U*SIGMA*Vt")
528-
(:method ((matrix matrix))
529-
(declare (ignore matrix))
528+
(:method ((matrix matrix) &key reduced)
529+
(declare (ignore matrix reduced))
530530
(error "SVD is not defined for the generic matrix type.")))
531531

532532
(defgeneric qr (matrix)

tests/high-level-tests.lisp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,62 @@
6767
(is (cl:= (magicl:nrows xxx) (magicl:ncols xxx) (expt matrix-dim 3)))
6868
))
6969

70+
71+
(deftest test-svd ()
72+
"Test the full and reduced SVDs."
73+
(labels ((mul-diag-times-gen (diag matrix)
74+
"Returns a newly allocated matrix resulting from the product of DIAG (a diagonal real matrix) with MATRIX (a complex matrix)."
75+
#+ignore
76+
(declare (type matrix diag matrix)
77+
(values matrix))
78+
(let* ((m (magicl:nrows diag))
79+
(k (magicl:ncols matrix))
80+
(result (magicl:empty (list m k))))
81+
(dotimes (i (min m (magicl:ncols diag)) result)
82+
(let ((dii (magicl:tref diag i i)))
83+
(dotimes (j k)
84+
(setf (magicl:tref result i j)
85+
(* dii (magicl:tref matrix i j))))))))
86+
87+
(norm-inf (matrix)
88+
"Return the infinity norm of vec(MATRIX)."
89+
(let ((data (magicl::storage matrix)))
90+
(reduce #'max data :key #'abs)))
91+
92+
(zero-p (matrix &optional (tolerance 1.0e-14))
93+
"Return T if MATRIX is close to zero (within TOLERANCE)."
94+
(< (norm-inf matrix) tolerance))
95+
96+
(check-full-svd (matrix)
97+
"Validate full SVD of MATRIX."
98+
(let ((m (magicl:nrows matrix))
99+
(n (magicl:ncols matrix)))
100+
(multiple-value-bind (u sigma vh)
101+
(magicl:svd matrix)
102+
(is (= (magicl:nrows u) (magicl:ncols u) m))
103+
(is (and (= (magicl:nrows sigma) m) (= (magicl:ncols sigma) n)))
104+
(is (= (magicl:nrows vh) (magicl:ncols vh) n))
105+
(is (zero-p (magicl:- matrix (magicl:@ u (mul-diag-times-gen sigma vh))))))))
106+
107+
(check-reduced-svd (matrix)
108+
"Validate reduced SVD of MATRIX."
109+
(let* ((m (magicl:nrows matrix))
110+
(n (magicl:ncols matrix))
111+
(k (min m n)))
112+
113+
(multiple-value-bind (u sigma vh)
114+
(magicl:svd matrix :reduced t)
115+
(is (and (= (magicl:nrows u) m)
116+
(= (magicl:ncols u) k)))
117+
(is (= (magicl:nrows sigma) (magicl:ncols sigma) k))
118+
(is (and (= (magicl:nrows vh) k)
119+
(= (magicl:ncols vh) n)))
120+
(is (zero-p (magicl:- matrix (magicl:@ u (mul-diag-times-gen sigma vh)))))))))
121+
122+
(let ((tall-thin-matrix (magicl:rand '(8 2))))
123+
(check-full-svd tall-thin-matrix)
124+
(check-reduced-svd tall-thin-matrix))
125+
126+
(let ((short-fat-matrix (magicl:rand '(2 8))))
127+
(check-full-svd short-fat-matrix)
128+
(check-reduced-svd short-fat-matrix))))

0 commit comments

Comments
 (0)