Skip to content

Commit

Permalink
fix 0-dim mean and variance
Browse files Browse the repository at this point in the history
  • Loading branch information
John Canny committed Oct 19, 2019
1 parent a0bf328 commit 96bed3a
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<groupId>BIDMat</groupId>
<artifactId>BIDMat</artifactId>
<packaging>jar</packaging>
<version>2.1.9-cuda9.2</version>
<version>2.1.10-cuda9.2</version>
<name>BIDMat</name>
<description>BIDMat performs matrix operations</description>
<properties>
Expand Down
13 changes: 11 additions & 2 deletions src/main/scala/BIDMat/DMat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1411,13 +1411,22 @@ case class DMat(dims0:Array[Int], val data:Array[Double]) extends DenseMat[Doubl

override def sum(inds:IMat):DMat = reduce(inds.data, SciFunctions.sum, "sum")
override def prod(inds:IMat):DMat = reduce(inds.data, SciFunctions.prod, "prod")
override def mean(inds:IMat):DMat = reduce(inds.data, SciFunctions.mean, "mean")
override def variance(inds:IMat):DMat = reduce(inds.data, SciFunctions.variance, "variance")
override def maxi(inds:IMat):DMat = reduce(inds.data, SciFunctions.maxi, "maxi")
override def mini(inds:IMat):DMat = reduce(inds.data, SciFunctions.mini, "mini")
override def amax(inds:IMat):DMat = reduce(inds.data, SciFunctions.maxi, "amax")
override def amin(inds:IMat):DMat = reduce(inds.data, SciFunctions.mini, "amin")

override def mean(inds:IMat):DMat = {val m = this.sum(inds);
m ~ m *@ (1.0/SciFunctions.prod(this.dims(inds)).v);
m}
override def variance(inds:IMat):DMat = {val m = this.sum(inds);
val n = SciFunctions.prod(this.dims(inds)).v;
m ~ m *@ (1.0/n)
val a = this - m;
a ~ a *@ a
val v = a.sum(inds);
v ~ v *@ (1.0/n);
v}

override def * (b : Double) = fDMult(DMat.delem(b), null)
override def + (b : Double) = ddMatOpScalarv(b, DMat.vecAddFun, null)
Expand Down
14 changes: 12 additions & 2 deletions src/main/scala/BIDMat/FMat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -804,8 +804,18 @@ case class FMat(dims0:Array[Int], val data:Array[Float]) extends DenseMat[Float]
override def mini(inds:IMat):FMat = reduce(inds.data, SciFunctions.mini, reduceTensorFloat, FMat.CBLASop.op_min, "mini")
override def amax(inds:IMat):FMat = reduce(inds.data, SciFunctions.maxi, reduceTensorFloat, FMat.CBLASop.op_max, "amax")
override def amin(inds:IMat):FMat = reduce(inds.data, SciFunctions.mini, reduceTensorFloat, FMat.CBLASop.op_min, "amin")
override def mean(inds:IMat):FMat = reduce(inds.data, SciFunctions.mean, "mean")
override def variance(inds:IMat):FMat = reduce(inds.data, SciFunctions.variance, "variance")

override def mean(inds:IMat):FMat = {val m = this.sum(inds);
m ~ m *@ (1f/SciFunctions.prod(this.dims(inds)).v);
m}
override def variance(inds:IMat):FMat = {val m = this.sum(inds);
val n = SciFunctions.prod(this.dims(inds)).v;
m ~ m *@ (1f/n)
val a = this - m;
a ~ a *@ a
val v = a.sum(inds);
v ~ v *@ (1f/n);
v}

def fDMultHelper(a:FMat, out:FMat, istart:Int, iend:Int) = {
var i = istart
Expand Down
14 changes: 12 additions & 2 deletions src/main/scala/BIDMat/GDMat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1187,13 +1187,23 @@ class GDMat(dims0:Array[Int], @transient var pdata:Pointer, val realsize:Long) e

override def sum(inds:IMat):DMat = reduce(inds.data, (a:GDMat, dir:Int) => GDFunctions.sum(a,dir,null), "sum");
override def prod(inds:IMat):DMat = reduce(inds.data, (a:GDMat, dir:Int) => GDFunctions.prod(a,dir,null), "prod");
override def mean(inds:IMat):DMat = reduce(inds.data, (a:GDMat, dir:Int) => SciFunctions.mean(a,dir), "mean")
override def variance(inds:IMat):DMat = reduce(inds.data, (a:GDMat, dir:Int) => SciFunctions.variance(a,dir), "variance")
override def maxi(inds:IMat):DMat = reduce(inds.data, (a:GDMat, dir:Int) => GDFunctions.maxi(a,dir,null), "maxi")
override def mini(inds:IMat):DMat = reduce(inds.data, (a:GDMat, dir:Int) => GDFunctions.mini(a,dir,null), "mini")
override def amax(inds:IMat):DMat = reduce(inds.data, (a:GDMat, dir:Int) => GDFunctions.maxi(a,dir,null), "amax")
override def amin(inds:IMat):DMat = reduce(inds.data, (a:GDMat, dir:Int) => GDFunctions.mini(a,dir,null), "amin")

override def mean(inds:IMat):DMat = {val m = this.sum(inds);
m ~ m *@ (1.0/SciFunctions.prod(this.dims(inds)).v);
m}
override def variance(inds:IMat):DMat = {val m = this.sum(inds);
val n = SciFunctions.prod(this.dims(inds)).v;
m ~ m *@ (1.0/n)
val a = this - m;
a ~ a *@ a
val v = a.sum(inds);
v ~ v *@ (1.0/n);
v}

override def * (a : DMat) = GMult(GDMat(a), null)
override def * (a : SDMat) = GSMult(GSDMat(a), null)
override def *^ (a : DMat) = GMultT(GDMat(a), null)
Expand Down
14 changes: 12 additions & 2 deletions src/main/scala/BIDMat/GMat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1727,8 +1727,18 @@ class GMat(dims0:Array[Int], @transient var pdata:Pointer, val realsize:Long) ex
override def mini(inds:IMat):FMat = reduce(inds.data, (a:GMat, dir:Int) => GFunctions.mini(a,dir,null), CUMAT.minTensor, "mini")
override def amax(inds:IMat):FMat = reduce(inds.data, (a:GMat, dir:Int) => GFunctions.maxi(a,dir,null), CUMAT.maxTensor, "amax")
override def amin(inds:IMat):FMat = reduce(inds.data, (a:GMat, dir:Int) => GFunctions.mini(a,dir,null), CUMAT.minTensor,"amin")
override def mean(inds:IMat):FMat = reduce(inds.data, (a:GMat, dir:Int) => SciFunctions.mean(a,dir), "mean")
override def variance(inds:IMat):FMat = reduce(inds.data, (a:GMat, dir:Int) => SciFunctions.variance(a,dir), "variance")

override def mean(inds:IMat):FMat = {val m = this.sum(inds);
m ~ m *@ (1f/SciFunctions.prod(this.dims(inds)).v);
m}
override def variance(inds:IMat):FMat = {val m = this.sum(inds);
val n = SciFunctions.prod(this.dims(inds)).v;
m ~ m *@ (1f/n)
val a = this - m;
a ~ a *@ a
val v = a.sum(inds);
v ~ v *@ (1f/n);
v}

override def * (a : FMat) = GMult(GMat(a), null)
override def * (a : SMat) = GSMult(GSMat(a), null)
Expand Down

0 comments on commit 96bed3a

Please sign in to comment.