Skip to content

Commit de9a190

Browse files
committed
fix lapack's EIG on real matrices
Fixes quil-lang#177. Thanks to @jcguu95 for the investigation.
1 parent bddabb4 commit de9a190

File tree

2 files changed

+68
-11
lines changed

2 files changed

+68
-11
lines changed

examples.lisp

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,15 +140,26 @@
140140
(format t "e^iH~%~a~%" expih)
141141
(format t "det(e^iH) = ~D~%" d)))
142142

143+
(defun ensure-complex (x)
144+
(etypecase x
145+
(complex x)
146+
(real (complex x))))
147+
143148
(defun eig-printing (m)
144-
(multiple-value-bind (vals vects)
145-
(magicl:eig m)
146-
(let ((val-diag (funcall #'magicl:from-diag vals)))
147-
(format t "M~%~a~%" m)
148-
(format t "Eigenvalues LAMBDA~%~a~%" val-diag)
149-
(format t "Eigenvectors V~%~a~%" vects)
150-
(format t "M*V~%~a~%" (magicl:@ m vects))
151-
(format t "V*LAMBDA~%~a~%" (magicl:@ vects val-diag)))))
149+
(let ((m-complex (magicl:.complex m (magicl:zeros (magicl:shape m)
150+
:type (magicl:element-type m)))))
151+
(multiple-value-bind (vals vects)
152+
(magicl:eig m)
153+
(let ((val-diag (magicl:from-diag (mapcar #'ensure-complex vals))))
154+
(format t "M~%~a~%" m)
155+
(format t "Eigenvalues LAMBDA~%~a~%" val-diag)
156+
(format t "Eigenvectors V~%~a~%" vects)
157+
(format t "M*V~%~a~%" (magicl:@ m-complex vects))
158+
(format t "V*LAMBDA~%~a~%" (magicl:@ vects val-diag))
159+
(format t "M*V = V*LAMBDA (within 10^-10)? ~:[no~;yes~]~%"
160+
(magicl:= (magicl:@ m-complex vects)
161+
(magicl:@ vects val-diag)
162+
1d-10))))))
152163

153164
(defun eig-example ()
154165
(let ((m (magicl:from-list (list -2 -1 1 -2 1 1 -9 -3 4) '(3 3) :type 'double-float)))

src/extensions/lapack/lapack-templates.lisp

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,6 @@
241241
(from-array smat (list u-cols vt-rows) :input-layout :column-major)
242242
(from-array vt (list vt-rows cols) :input-layout :column-major)))))))
243243

244-
;; TODO: This returns only the real parts when with non-complex
245-
;; numbers. Should do something different?
246244
(defun generate-lapack-eig-for-type (class type eig-function &optional real-type)
247245
` (defmethod lapack-eig ((m ,class))
248246
(policy-cond:with-expectations (> speed safety)
@@ -273,7 +271,55 @@
273271
;; run it again with optimal workspace size
274272
(,eig-function jobvl jobvr rows a rows ,@(if real-type `(w) `(wr wi))
275273
vl 1 vr rows work lwork ,@(when real-type `(rwork)) info)
276-
(values (coerce ,@(if real-type `(w) `(wr)) 'list) (from-array vr (list rows cols) :input-layout :column-major))))))))
274+
,(if real-type
275+
`(values (coerce w 'list)
276+
(from-array vr (list rows cols) :input-layout :column-major))
277+
`(values (cl:map 'list (lambda (a b)
278+
(if (zerop b)
279+
a
280+
(complex a b)))
281+
wr wi)
282+
(let* ((evecs (magicl:zeros (list rows cols) :type '(complex ,type)))
283+
(storage (magicl::storage evecs)))
284+
;; square matrix
285+
(loop :with col-lapack := 0
286+
:with col-result := 0
287+
:with skip := nil
288+
:for zr :across wr
289+
:for zi :across wi
290+
:do (cond
291+
;; real eigenvalue
292+
((zerop zi)
293+
(when skip
294+
(warn "SKIP is T when we reached a real eigenvalue."))
295+
(dotimes (r rows)
296+
;; column-major
297+
(setf (aref storage (+ r (* col-result rows)))
298+
(complex (aref vr (+ r (* col-lapack rows))))))
299+
(incf col-result)
300+
(incf col-lapack))
301+
;; complex eigenvalue with conjugate
302+
(skip
303+
(unless (cl:= skip (- zi))
304+
(error "Reached a non-conjugate eigenvalue"))
305+
(setf skip nil))
306+
;; New complex eigenvalue
307+
(t
308+
;; expect a conjugate in the next iteration
309+
(setf skip zi)
310+
(dotimes (r rows)
311+
;; column-major
312+
(setf (aref storage (+ r (* col-result rows)))
313+
(complex
314+
(aref vr (+ r (* col-lapack rows)))
315+
(aref vr (+ r (* (1+ col-lapack) rows)))))
316+
(setf (aref storage (+ r (* (1+ col-result) rows)))
317+
(complex
318+
(aref vr (+ r (* col-lapack rows)))
319+
(- (aref vr (+ r (* (1+ col-lapack) rows)))))))
320+
(incf col-result 2)
321+
(incf col-lapack 2)))
322+
:finally (return evecs)))))))))))
277323

278324
(defun generate-lapack-hermitian-eig-for-type (class type eig-function real-type)
279325
`(defmethod lapack-hermitian-eig ((m ,class))

0 commit comments

Comments
 (0)