Skip to content

Commit

Permalink
Merge pull request gonum#8 from dane-unltd/master
Browse files Browse the repository at this point in the history
Implemented LQ factorization, questions about API
  • Loading branch information
kortschak committed Jan 19, 2014
2 parents 61aa81a + 28cacaa commit 1943cdb
Show file tree
Hide file tree
Showing 2 changed files with 300 additions and 0 deletions.
178 changes: 178 additions & 0 deletions mat64/lq.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
// Copyright ©2013 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package mat64

import (
"github.com/gonum/blas"
"math"
)

type LQFactor struct {
LQ *Dense
lDiag []float64
}

// LQ computes a LQ Decomposition for an m-by-n matrix a with m <= n by Householder
// reflections, the LQ decomposition is an m-by-n orthogonal matrix q and an n-by-n
// upper triangular matrix r so that a = q.r. LQ will panic with ErrShape if m > n.
//
// The LQ decomposition always exists, even if the matrix does not have full rank,
// so LQ will never fail unless m > n. The primary use of the LQ decomposition is
// in the least squares solution of non-square systems of simultaneous linear equations.
// This will fail if LQIsFullRank() returns false. The matrix a is overwritten by the
// decomposition.
func LQ(a *Dense) LQFactor {
// Initialize.
m, n := a.Dims()
if m > n {
panic(ErrShape)
}

lq := &Dense{}
*lq = *a

lDiag := make([]float64, m)
projs := make(Vec, m)

// Main loop.
for k := 0; k < m; k++ {
hh := Vec(lq.RowView(k))[k:]
norm := blasEngine.Dnrm2(len(hh), hh, 1)
lDiag[k] = norm

if norm != 0 {
hhNorm := (norm * math.Sqrt(1-hh[0]/norm))
if hhNorm == 0 {
hh[0] = 0
} else {
// Form k-th Householder vector.
s := 1 / hhNorm
hh[0] -= norm
blasEngine.Dscal(len(hh), s, hh, 1)

// Apply transformation to remaining columns.
if k < m-1 {
*a = *lq
a.View(k+1, k, m-k-1, n-k)
projs = projs[0 : m-k-1]
projs.Mul(a, &hh)

for j := 0; j < m-k-1; j++ {
dst := a.RowView(j)
blasEngine.Daxpy(len(dst), -projs[j], hh, 1, dst, 1)
}
}
}
}
}

return LQFactor{lq, lDiag}
}

// IsFullRank returns whether the L matrix and hence a has full rank.
func (f LQFactor) IsFullRank() bool {
for _, v := range f.lDiag {
if v == 0 {
return false
}
}
return true
}

// L returns the lower triangular factor for the LQ decomposition.
func (f LQFactor) L() *Dense {
lq, lDiag := f.LQ, f.lDiag
m, _ := lq.Dims()
l := NewDense(m, m, nil)
for i, v := range lDiag {
for j := 0; j < m; j++ {
if i < j {
l.Set(j, i, lq.At(j, i))
} else if i == j {
l.Set(j, i, v)
}
}
}
return l
}

// replaces x with Q.x
func (f LQFactor) ApplyQ(x *Dense, trans bool) {
nh, nc := f.LQ.Dims()
m, n := x.Dims()
if m != nc {
panic(ErrShape)
}
proj := make([]float64, n)

if trans {
for k := nh - 1; k >= 0; k-- {
sub := &Dense{}
*sub = *x
hh := f.LQ.RowView(k)[k:]

sub.View(k, 0, m-k, n)

blasEngine.Dgemv(blas.ColMajor, blas.NoTrans, n, m-k, 1,
sub.mat.Data, sub.mat.Stride, hh, 1, 0, proj, 1)
for i := k; i < m; i++ {
row := x.RowView(i)
blasEngine.Daxpy(n, -hh[i-k], proj, 1, row, 1)
}
}
} else {
for k := 0; k < nh; k++ {
sub := &Dense{}
*sub = *x
hh := f.LQ.RowView(k)[k:]

sub.View(k, 0, m-k, n)

blasEngine.Dgemv(blas.ColMajor, blas.NoTrans, n, m-k, 1,
sub.mat.Data, sub.mat.Stride, hh, 1, 0, proj, 1)
for i := k; i < m; i++ {
row := x.RowView(i)
blasEngine.Daxpy(n, -hh[i-k], proj, 1, row, 1)
}
}
}
}

// Solve a computes minimum norm least squares solution of a.x = b where b has as many rows as a.
// A matrix x is returned that minimizes the two norm of Q*R*X-B. Solve will panic
// if a is not full rank.
func (f LQFactor) Solve(b *Dense) (x *Dense) {
lq := f.LQ
lDiag := f.lDiag
m, n := lq.Dims()
bm, bn := b.Dims()
if bm != m {
panic(ErrShape)
}
if !f.IsFullRank() {
panic("mat64: matrix is rank deficient")
}

x = NewDense(n, bn, nil)
xv := new(Dense)
*xv = *x
xv.View(0, 0, bm, bn)
xv.Copy(b)

tau := make([]float64, m)
for i := range tau {
tau[i] = lq.At(i, i)
lq.Set(i, i, lDiag[i])
}
blasEngine.Dtrsm(blas.RowMajor, blas.Left, blas.Lower, blas.NoTrans, blas.NonUnit,
bm, bn, 1, lq.mat.Data, lq.mat.Stride, x.mat.Data, x.mat.Stride)

for i := range tau {
lq.Set(i, i, tau[i])
}
f.ApplyQ(x, true)

return x
}
122 changes: 122 additions & 0 deletions mat64/lq_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright ©2013 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package mat64

import (
check "launchpad.net/gocheck"
"math"
)

func isLowerTriangular(a *Dense) bool {
rows, cols := a.Dims()
for r := 0; r < rows; r++ {
for c := r + 1; c < cols; c++ {
if math.Abs(a.At(r, c)) > 1e-14 {
return false
}
}
}
return true
}

func (s *S) TestLQD(c *check.C) {
for _, test := range []struct {
a [][]float64
name string
}{
{
name: "Square",
a: [][]float64{
{1.3, 2.4, 8.9},
{-2.6, 8.7, 9.1},
{5.6, 5.8, 2.1},
},
},
{
name: "Skinny",
a: [][]float64{
{1.3, 2.4, 8.9},
{-2.6, 8.7, 9.1},
{5.6, 5.8, 2.1},
{19.4, 5.2, -26.1},
},
},
{
name: "Id",
a: [][]float64{
{1, 0, 0},
{0, 1, 0},
{0, 0, 1},
},
},
{
name: "Id",
a: [][]float64{
{0, 0, 2},
{0, 1, 0},
{3, 0, 0},
},
},
{
name: "small",
a: [][]float64{
{1, 1},
{1, 2},
},
},
} {

a := NewDense(flatten(test.a))

at := new(Dense)
at.TCopy(a)

lq := LQ(DenseCopyOf(at))

rows, cols := a.Dims()

Q := NewDense(rows, cols, nil)
for i := 0; i < cols; i++ {
Q.Set(i, i, 1)
}
lq.ApplyQ(Q, true)
l := lq.L()

lt := NewDense(rows, cols, nil)
ltview := new(Dense)
*ltview = *lt
ltview.View(0, 0, cols, cols)
ltview.TCopy(l)
lq.ApplyQ(lt, true)

c.Check(isOrthogonal(Q), check.Equals, true, check.Commentf("Test %v: Q not orthogonal", test.name))
c.Check(a.EqualsApprox(lt, 1e-13), check.Equals, true, check.Commentf("Test %v: Q*R != A", test.name))
c.Check(isLowerTriangular(l), check.Equals, true,
check.Commentf("Test %v: L not lower triangular", test.name))

nrhs := 2
barr := make([]float64, nrhs*cols)
for i := range barr {
barr[i] = float64(i)
}
b := NewDense(cols, nrhs, barr)

x := lq.Solve(b)

bProj := new(Dense)
bProj.Mul(at, x)

c.Check(b.EqualsApprox(bProj, 1e-13), check.Equals, true, check.Commentf("Test %v: A*X != B", test.name))

qr := QR(DenseCopyOf(a))
lambda := qr.Solve(DenseCopyOf(x))

xCheck := new(Dense)
xCheck.Mul(a, lambda)

c.Check(xCheck.EqualsApprox(x, 1e-13), check.Equals, true,
check.Commentf("Test %v: A*lambda != X", test.name))
}
}

0 comments on commit 1943cdb

Please sign in to comment.