
code cleaning and fix typos in variable names #3

Merged · 78 commits · Jun 28, 2017

Commits
8cb9da0
Merge remote-tracking branch 'csirobigdata/master'
lynnlangit May 19, 2017
e7a5ca9
updated local copy
lynnlangit May 19, 2017
b24ee23
fix typo in variable name
lynnlangit May 19, 2017
3aee2ee
run code formatting tool
lynnlangit May 19, 2017
a7be1a3
fixed spelling errors
lynnlangit May 19, 2017
92d8e51
update .gitignore for intellij
lynnlangit May 19, 2017
805d272
ignore
lynnlangit May 19, 2017
45039db
ignore
lynnlangit May 19, 2017
b26c38c
ignore
lynnlangit May 19, 2017
38bcb5c
ignore
lynnlangit May 19, 2017
9dba715
ignore
lynnlangit May 19, 2017
c42ef97
ignore
lynnlangit May 19, 2017
fc21e5a
ignore
lynnlangit May 19, 2017
d3823a7
ignore
lynnlangit May 19, 2017
a5413d7
clean code per warnings convert nulls to underscores
lynnlangit May 19, 2017
45ac81d
ignore
lynnlangit May 19, 2017
e42b2a0
i
lynnlangit May 19, 2017
b4eb5d6
i
lynnlangit May 19, 2017
d52aa96
i
lynnlangit May 19, 2017
63b20eb
i
lynnlangit May 19, 2017
b866236
f
lynnlangit May 19, 2017
9206ab8
f
lynnlangit May 19, 2017
6876692
f
lynnlangit May 19, 2017
fee741e
f
lynnlangit May 19, 2017
7c59017
f
lynnlangit May 19, 2017
598788f
f
lynnlangit May 19, 2017
684b58a
f
lynnlangit May 19, 2017
407ba6a
f
lynnlangit May 19, 2017
777ab1c
f
lynnlangit May 19, 2017
bce2e2a
f
lynnlangit May 19, 2017
367fcb5
f
lynnlangit May 19, 2017
874cc66
f
lynnlangit May 19, 2017
56f4293
f
lynnlangit May 19, 2017
38422b6
f
lynnlangit May 19, 2017
698867b
f
lynnlangit May 19, 2017
b3e5250
f
lynnlangit May 19, 2017
f66680d
f
lynnlangit May 19, 2017
715a778
f
lynnlangit May 19, 2017
f987f69
f
lynnlangit May 19, 2017
0c4d031
f
lynnlangit May 19, 2017
99eed1d
f
lynnlangit May 19, 2017
0104064
f
lynnlangit May 19, 2017
535af47
f
lynnlangit May 19, 2017
0ff3bed
f
lynnlangit May 19, 2017
b424db1
f
lynnlangit May 19, 2017
2dc1809
f
lynnlangit May 19, 2017
6171f67
f
lynnlangit May 19, 2017
4d54d6c
f
lynnlangit May 19, 2017
55795ae
f
lynnlangit May 19, 2017
1c9d64f
f
lynnlangit May 19, 2017
342a273
f
lynnlangit May 19, 2017
4ace82c
f
lynnlangit May 19, 2017
0767c32
remove untracked
lynnlangit May 19, 2017
683888e
reformatter
lynnlangit May 19, 2017
04bdbde
remove .iml files
lynnlangit May 19, 2017
54208b8
removed untracked
lynnlangit May 19, 2017
2e5a64f
added method Scaladoc info in widekmeans.scala
lynnlangit May 19, 2017
481d16d
Merge remote-tracking branch 'origin/master'
lynnlangit May 19, 2017
bae0709
removed unused imports
lynnlangit May 19, 2017
543a097
refactor CochranAmeritageTest
lynnlangit May 20, 2017
bc2a3c3
remove comments and refactor method names
lynnlangit May 20, 2017
2afb31f
fixed spelling errors in variable names
lynnlangit May 20, 2017
173a744
removing comments
lynnlangit May 22, 2017
b32afda
remove unused imports project-wide
lynnlangit May 22, 2017
53a7362
fixed spelling of length variable
lynnlangit May 22, 2017
706ff2f
remove unused imports
lynnlangit May 23, 2017
2e8353b
fix more spelling errors
lynnlangit May 23, 2017
9ec0e38
fix spelling and add first new unit test
lynnlangit May 23, 2017
72b0428
remove more unused imports
lynnlangit May 23, 2017
eb11c6f
Worked on readability for the Wide K-Means section
plyte May 26, 2017
73cf2d2
Added more comments
plyte May 26, 2017
93c3a84
temp fix to broken build on k-means
lynnlangit May 28, 2017
bfec349
refactoring ml files for human readability
lynnlangit May 29, 2017
1e30159
fixed variable name
lynnlangit May 31, 2017
c2b0580
Reorganization and Scala Docs added
plyte Jun 1, 2017
4c3a9cf
Continued Additions to Scala docs
plyte Jun 1, 2017
71cddd3
update to fix import issues
lynnlangit Jun 12, 2017
a45fafd
fixed typos from last refactoring in DecisionTrees
lynnlangit Jun 15, 2017
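The commit message "clean code per warnings convert nulls to underscores" is terse; one plausible reading (an assumption, not confirmed by the diff below) is Scala's default-initializer syntax, where a `var` field initialized with `= _` gets its type's default value rather than an explicit `null` literal:

```scala
// Hedged sketch: assumes the "convert nulls to underscores" commit refers to
// Scala 2's default-initializer syntax for var fields. `= _` initializes a
// field to its type's default value (null for reference types, 0 for Int),
// avoiding an explicit `null` literal and the warnings some linters raise.
// The class and field names are illustrative, not the variant-spark API.
class MemberState {
  var oobIndexes: Array[Int] = _ // default-initialized to null, not written as `= null`
  var treeCount: Int = _         // default-initialized to 0
}
```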
Changes from commit a7be1a3fef2fc7bc9eb97a52f30bbe082f630d90 ("fixed spelling errors")
lynnlangit committed May 19, 2017
src/main/scala/au/csiro/variantspark/algo/RandomForest.scala: 64 changes (32 additions, 32 deletions)
@@ -66,18 +66,18 @@ case class VotingAggregator(val nLabels:Int, val nSamples:Int) {


@SerialVersionUID(2l)
case class RandomForestMember[V](val predictor:PredictiveModelWithImportance[V],
-val oobIndexs:Array[Int] = null, val oobPred:Array[Int] = null) {
+val oobIndexes:Array[Int] = null, val oobPred:Array[Int] = null) {
}

@SerialVersionUID(2l)
case class RandomForestModel[V](val members: List[RandomForestMember[V]], val labelCount:Int, val oobErrors:List[Double] = List.empty)(implicit canSplit:CanSplit[V]) {

def size = members.size
def trees = members.map(_.predictor)

def oobError:Double = oobErrors.last

def printout() {
trees.zipWithIndex.foreach {
case (tree, index) =>
@@ -87,40 +87,40 @@ case class RandomForestModel[V](val members: List[RandomForestMember[V]], val la
}

def normalizedVariableImportance(norm:VarImportanceNormalizer = To100ImportanceNormalizer): Map[Long, Double] = norm.normalize(variableImportance)

def variableImportance: Map[Long, Double] = {
// average the importance of each variable over all trees
// if a variable is not used in a tree it's importance for this tree is assumed to be 0
trees.map(_.variableImportanceAsFastMap).foldLeft(new Long2DoubleOpenHashMap())(_.addAll(_))
.asScala.mapValues(_/size)
}

def predict(data: RDD[V])(implicit ct:ClassTag[V]): Array[Int] = predictIndexed(data.zipWithIndex())

def predictIndexed(indexedData: RDD[(V,Long)])(implicit ct:ClassTag[V]): Array[Int] = predictIndexed(indexedData, indexedData.size)

def predictIndexed(indexedData: RDD[(V,Long)], nSamples:Int)(implicit ct:ClassTag[V]): Array[Int] = {
trees.map(_.predictIndexed(indexedData))
.foldLeft(VotingAggregator(labelCount, nSamples))(_.addVote(_)).predictions
}

}

case class RandomForestParams(
oob:Boolean = true,
nTryFraction:Double = Double.NaN,
bootstrap:Boolean = true,
subsample:Double = Double.NaN,
randomizeEquality:Boolean = false,
seed:Long = defRng.nextLong
) {
def resolveDefaults(nSamples:Int, nVariables:Int):RandomForestParams = {
RandomForestParams(
oob = oob,
nTryFraction = if (!nTryFraction.isNaN) nTryFraction else Math.sqrt(nVariables.toDouble)/nVariables,
bootstrap = bootstrap,
subsample = if (!subsample.isNaN) subsample else if (bootstrap) 1.0 else 0.666,
randomizeEquality = randomizeEquality,
seed = seed
)
}
@@ -135,12 +135,12 @@ trait RandomForestCallback {
// TODO (Design): Avoid using type cast change design
trait BatchTreeModel[V] {
def batchTrain(indexedData: RDD[(V, Long)], dataType:VariableType, labels: Array[Int], nTryFraction: Double, samples:Seq[Sample]): Seq[PredictiveModelWithImportance[V]]
def batchPredict(indexedData: RDD[(V, Long)], models: Seq[PredictiveModelWithImportance[V]], indexes:Seq[Array[Int]]): Seq[Array[Int]]
}

object RandomForest {
type ModelBuilderFactory[V] = (DecisionTreeParams, CanSplit[V]) => BatchTreeModel[V]

def wideDecisionTreeBuilder[V](params:DecisionTreeParams, canSplit:CanSplit[V]): BatchTreeModel[V] = {
val decisionTree = new DecisionTree[V](params)(canSplit)
new BatchTreeModel[V]() {
@@ -149,25 +149,25 @@ object RandomForest {
models.asInstanceOf[Seq[DecisionTreeModel[V]]], indexes)(canSplit)
}
}

val defaultBatchSize = 10
}

class RandomForest[V](params:RandomForestParams=RandomForestParams()
,modelBuilderFactory:RandomForest.ModelBuilderFactory[V] = RandomForest.wideDecisionTreeBuilder[V] _
)(implicit canSplit:CanSplit[V]) extends Logging {

// TODO (Design): This seems like an easiest solution but it make this class
// to keep random state ... perhaps this could be externalised to the implicit random

implicit lazy val rng = new XorShift1024StarRandomGenerator(params.seed)

def train(indexedData: RDD[(V, Long)], dataType: VariableType, labels: Array[Int], nTrees: Int)(implicit callback:RandomForestCallback = null): RandomForestModel[V] =
batchTrain(indexedData, dataType, labels, nTrees, RandomForest.defaultBatchSize)

/**
* TODO (Nice): Make a parameter rather then an extra method
* TODO (Func): Add OOB calculation
@@ -177,19 +177,19 @@ class RandomForest[V](params:RandomForestParams=RandomForestParams()
require(nTrees > 0)
val nSamples = labels.length
val nVariables = indexedData.count().toInt
val nLabels = labels.max + 1
logDebug(s"Data: nSamples:${nSamples}, nVariables: ${nVariables}, nLabels:${nLabels}")
val actualParams = params.resolveDefaults(nSamples, nVariables)
Option(callback).foreach(_.onParamsResolved(actualParams))
logDebug(s"Parameters: ${actualParams}")
-logDebug(s"Batch Traning: ${nTrees} with batch size: ${nBatchSize}")
+logDebug(s"Batch Training: ${nTrees} with batch size: ${nBatchSize}")
val oobAggregator = if (actualParams.oob) Option(new VotingAggregator(nLabels,nSamples)) else None

val builder = modelBuilderFactory(DecisionTreeParams(seed = rng.nextLong, randomizeEquality = actualParams.randomizeEquality), canSplit)
val allSamples = Stream.fill(nTrees)(Sample.fraction(nSamples, actualParams.subsample, actualParams.bootstrap))
val (allTrees, errors) = allSamples
.sliding(nBatchSize, nBatchSize)
.flatMap { samplesStream =>
time {
val samples = samplesStream.toList
val predictors = builder.batchTrain(indexedData, dataType, labels, actualParams.nTryFraction, samples)
@@ -200,7 +200,7 @@ class RandomForest[V](params:RandomForestParams=RandomForestParams()
} else predictors.map(RandomForestMember(_))
val oobError = oobAggregator.map { agg =>
members.map { m =>
-agg.addVote(m.oobPred, m.oobIndexs)
+agg.addVote(m.oobPred, m.oobIndexes)
Metrics.classificatoinError(labels, agg.predictions)
}
}.getOrElse(List.fill(predictors.size)(Double.NaN))
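The predictIndexed methods in the diff above fold each tree's per-sample votes into a VotingAggregator and then read off the winning label per sample. A minimal, self-contained sketch of that majority-vote step (the object and parameter names are illustrative, not the variant-spark API):

```scala
// Hedged sketch of random-forest majority voting: each tree contributes one
// predicted label per sample; the ensemble prediction is the label with the
// most votes. Mirrors the VotingAggregator fold in predictIndexed, but with
// plain arrays instead of the library's aggregator class.
object MajorityVote {
  def predictions(perTreeVotes: Seq[Array[Int]], nLabels: Int, nSamples: Int): Array[Int] = {
    // counts(i)(label) = number of trees that voted `label` for sample i
    val counts = Array.ofDim[Int](nSamples, nLabels)
    for (votes <- perTreeVotes; i <- votes.indices) counts(i)(votes(i)) += 1
    // Winning label per sample; ties resolve to the lowest label index.
    counts.map(c => c.indexOf(c.max))
  }
}
```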
@@ -94,7 +94,7 @@ class AnalyzeRFCmd extends ArgsApp with FeatureSourceArgs with Echoable with Log
val samples = featureSource.sampleNames
LoanUtils.withCloseable(CSVWriter.open(outputOobPerTree)) { writer =>
writer.writeRow(samples)
-rfModel.members.map(m => m.oobIndexs.zip(m.oobPred).toMap)
+rfModel.members.map(m => m.oobIndexes.zip(m.oobPred).toMap)
.map(m => Range(0,samples.size).map(i => m.getOrElse(i, null))).foreach(writer.writeRow)
}
}
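The comment inside variableImportance in the RandomForest.scala diff states the aggregation rule: average each variable's importance over all trees, with a tree that does not use the variable contributing 0. A small sketch of that rule using plain Scala collections (the real code folds into a Long2DoubleOpenHashMap; the object and method names here are illustrative):

```scala
// Hedged sketch of per-variable importance averaging across a forest.
// A variable absent from a tree's map is treated as importance 0 for that
// tree, so the divisor is always the total number of trees.
object ImportanceAveraging {
  def averageImportance(perTree: Seq[Map[Long, Double]]): Map[Long, Double] = {
    val nTrees = perTree.size
    // Sum each variable's importance over all trees (missing entries add 0).
    val summed = perTree.foldLeft(Map.empty[Long, Double]) { (acc, treeMap) =>
      treeMap.foldLeft(acc) { case (a, (variable, imp)) =>
        a.updated(variable, a.getOrElse(variable, 0.0) + imp)
      }
    }
    // Divide by the total tree count, not the count of trees using the variable.
    summed.map { case (variable, total) => variable -> total / nTrees }
  }
}
```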