Skip to content

Commit

Permalink
Implement Stack and Augment
Browse files Browse the repository at this point in the history
  • Loading branch information
kortschak committed Jan 20, 2014
1 parent 1f6eb2e commit 5a12150
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 2 deletions.
54 changes: 52 additions & 2 deletions mat64/dense.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ var (
_ Uer = matrix
_ Ler = matrix

// _ Stacker = matrix
// _ Augmenter = matrix
_ Stacker = matrix
_ Augmenter = matrix

_ Equaler = matrix
_ ApproxEqualer = matrix
Expand Down Expand Up @@ -863,6 +863,56 @@ func (m *Dense) TCopy(a Matrix) {
*m = w
}

func (m *Dense) Stack(a, b Matrix) {
ar, ac := a.Dims()
br, bc := b.Dims()
if ac != bc || m == a || m == b {
panic(ErrShape)
}

if m.isZero() {
m.mat = RawMatrix{
Order: BlasOrder,
Rows: ar + br,
Cols: ac,
Stride: ac,
Data: use(m.mat.Data, (ar+br)*ac),
}
} else if ar+br != m.mat.Rows || ac != m.mat.Cols {
panic(ErrShape)
}

m.Copy(a)
w := *m
w.View(ar, 0, br, bc)
w.Copy(b)
}

func (m *Dense) Augment(a, b Matrix) {
ar, ac := a.Dims()
br, bc := b.Dims()
if ar != br || m == a || m == b {
panic(ErrShape)
}

if m.isZero() {
m.mat = RawMatrix{
Order: BlasOrder,
Rows: ar,
Cols: ac + bc,
Stride: ac + bc,
Data: use(m.mat.Data, ar*(ac+bc)),
}
} else if ar != m.mat.Rows || ac+bc != m.mat.Cols {
panic(ErrShape)
}

m.Copy(a)
w := *m
w.View(0, ac, br, bc)
w.Copy(b)
}

func (m *Dense) Sum() float64 {
l := m.mat.Cols
var s float64
Expand Down
60 changes: 60 additions & 0 deletions mat64/dense_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,66 @@ func (s *S) TestApply(c *check.C) {
}
}

func (s *S) TestStack(c *check.C) {
for i, test := range []struct {
a, b, e [][]float64
}{
{
[][]float64{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
[][]float64{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
[][]float64{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
},
{
[][]float64{{1, 1, 1}, {1, 1, 1}, {1, 1, 1}},
[][]float64{{1, 1, 1}, {1, 1, 1}, {1, 1, 1}},
[][]float64{{1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}},
},
{
[][]float64{{1, 0, 0}, {0, 1, 0}, {0, 0, 1}},
[][]float64{{0, 1, 0}, {0, 0, 1}, {1, 0, 0}},
[][]float64{{1, 0, 0}, {0, 1, 0}, {0, 0, 1}, {0, 1, 0}, {0, 0, 1}, {1, 0, 0}},
},
} {
a := NewDense(flatten(test.a))
b := NewDense(flatten(test.b))

var s Dense
s.Stack(a, b)

c.Check(s.Equals(NewDense(flatten(test.e))), check.Equals, true, check.Commentf("Test %d: %v stack %v = %v", i, a, b, s))
}
}

func (s *S) TestAugment(c *check.C) {
for i, test := range []struct {
a, b, e [][]float64
}{
{
[][]float64{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
[][]float64{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
[][]float64{{0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0}},
},
{
[][]float64{{1, 1, 1}, {1, 1, 1}, {1, 1, 1}},
[][]float64{{1, 1, 1}, {1, 1, 1}, {1, 1, 1}},
[][]float64{{1, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}},
},
{
[][]float64{{1, 0, 0}, {0, 1, 0}, {0, 0, 1}},
[][]float64{{0, 1, 0}, {0, 0, 1}, {1, 0, 0}},
[][]float64{{1, 0, 0, 0, 1, 0}, {0, 1, 0, 0, 0, 1}, {0, 0, 1, 1, 0, 0}},
},
} {
a := NewDense(flatten(test.a))
b := NewDense(flatten(test.b))

var s Dense
s.Augment(a, b)

c.Check(s.Equals(NewDense(flatten(test.e))), check.Equals, true, check.Commentf("Test %d: %v stack %v = %v", i, a, b, s))
}
}

var (
wd *Dense
)
Expand Down

0 comments on commit 5a12150

Please sign in to comment.