Skip to content

Commit

Permalink
Improve the performance
Browse files Browse the repository at this point in the history
  • Loading branch information
scorebot committed Sep 14, 2024
1 parent 13569bd commit b40f469
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 49 deletions.
7 changes: 5 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,19 @@ scalacOptions := Seq(
) ++ (CrossVersion.partialVersion(scalaVersion.value) match {
case Some((2, scalaMajor)) if scalaMajor <= 11 => Seq.empty
case _ => Seq(
"-optimize",
"-opt:box-unbox",
"-opt:l:method",
"-opt:l:inline",
"-opt-inline-from:**"
)
})

scalacOptions in(Compile, doc) := Seq("-no-link-warnings")

scalaVersion := "2.12.15"
scalaVersion := "2.12.18"

crossScalaVersions := Seq("2.12.15", "2.11.12", "2.13.8", "3.1.3")
crossScalaVersions := Seq("2.12.18", "2.11.12", "2.13.14", "3.1.3")

libraryDependencies ++= {
Seq(
Expand Down
35 changes: 21 additions & 14 deletions src/main/scala/org/pmml4s/metadata/MiningSchema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -131,20 +131,27 @@ trait HasUsageType {
* @param invalidValueTreatment Specifies how invalid input values are handled.
* @param invalidValueReplacement
*/
class MiningField(
val name: String,
val usageType: UsageType,
val opType: Option[OpType],
val importance: Option[Double] = None,
val outliers: OutlierTreatmentMethod = OutlierTreatmentMethod.asIs,
val lowValue: Option[Double] = None,
val highValue: Option[Double] = None,
val missingValueReplacement: Option[Any] = None,
val missingValueTreatment: Option[MissingValueTreatment] = None,
val invalidValueTreatment: InvalidValueTreatment = InvalidValueTreatment.returnInvalid,
val invalidValueReplacement: Option[Any] = None)
extends HasUsageType with PmmlElement

case class MiningField(
  name: String,
  usageType: UsageType,
  opType: Option[OpType],
  importance: Option[Double] = None,
  outliers: OutlierTreatmentMethod = OutlierTreatmentMethod.asIs,
  lowValue: Option[Double] = None,
  highValue: Option[Double] = None,
  missingValueReplacement: Option[Any] = None,
  missingValueTreatment: Option[MissingValueTreatment] = None,
  invalidValueTreatment: InvalidValueTreatment = InvalidValueTreatment.returnInvalid,
  invalidValueReplacement: Option[Any] = None)
  extends HasUsageType with PmmlElement {

  /**
   * Tells whether this field only carries default value-treatment settings.
   *
   * @return true when all of the following hold: outliers are kept as is, no replacement value is configured for
   *         missing or invalid values, the missing value treatment is not `returnInvalid`, and the invalid value
   *         treatment is `returnInvalid`; false as soon as one of them differs.
   */
  def isDefault: Boolean = {
    val outliersUntouched = outliers == OutlierTreatmentMethod.asIs
    val noReplacementValues = missingValueReplacement.isEmpty && invalidValueReplacement.isEmpty
    val missingNotRejected = missingValueTreatment.forall(_ != MissingValueTreatment.returnInvalid)
    val invalidRejected = invalidValueTreatment == InvalidValueTreatment.returnInvalid

    outliersUntouched && noReplacementValues && missingNotRejected && invalidRejected
  }
}
/**
* The MiningSchema is the Gate Keeper for its model element. All data entering a model must pass through the MiningSchema.
* Each model element contains one MiningSchema which lists fields as used in that model. While the MiningSchema contains information
Expand Down
26 changes: 17 additions & 9 deletions src/main/scala/org/pmml4s/model/MiningModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,22 +103,25 @@ class MiningModel(
// Segmentation
if (isSegmentation) {
val seg = segmentation.get
val segments = seg.segments
import MissingPredictionTreatment._
seg.multipleModelMethod match {
case `selectFirst` => {
val first = seg.segments.find(x => Predication.fire(x.eval(series)))
val first = segments.find(x => Predication.fire(x.eval(series)))
if (first.isDefined) first.get.predict(series) else nullSeries
}
case `selectAll` => {
val all = seg.segments.map(x => if (Predication.fire(x.eval(series))) x.predict(series) else
val all = segments.map(x => if (Predication.fire(x.eval(series))) x.predict(series) else
x.model.nullSeries)
Series.merge(all)
}
case `modelChain` => {
var last: Series = nullSeries
var lastOutputFields: Array[OutputField] = Array.empty
var in = series
for (segment <- seg.segments) {
var i = 0
while (i < segments.length) {
val segment = segments(i)
val out = if (Predication.fire(segment.eval(in))) {
last = segment.predict(in)
lastOutputFields = segment.model.outputFields
Expand All @@ -128,6 +131,7 @@ class MiningModel(
}
segment.id.foreach(x => if (segmentOutputs.contains(x)) outputs.putSegment(x, out))
in = Series.merge(in, out)
i += 1
}

if (output.isDefined) {
Expand All @@ -151,7 +155,7 @@ class MiningModel(
} else last
}
case method => {
val selections = seg.segments.filter(x => Predication.fire(x.eval(series)))
val selections = segments.filter(x => Predication.fire(x.eval(series)))
if (selections.isEmpty) return nullSeries

if (isRegression || ((isClassification || isClustering) && (method == majorityVote || method == weightedMajorityVote))) {
Expand Down Expand Up @@ -181,7 +185,7 @@ class MiningModel(
val realPredictions = predictions.map(_.asInstanceOf[Double])
outputs.predictedValue = method match {
case `average` => {
realPredictions.sum / realPredictions.size.toDouble
realPredictions.sum / realPredictions.length.toDouble
}
case `weightedAverage` => {
MathUtils.product(realPredictions, weights) / weights.sum
Expand All @@ -202,7 +206,7 @@ class MiningModel(
// in some cases, the probabilities could all become 0 if the two data types do not match.
val dataTypeWanted = if (classes.nonEmpty) Utils.inferDataType(classes(0)) else UnresolvedDataType
val probabilities = Utils.reduceByKey(predictions.zip(weights)).map(x =>
(Utils.toVal(x._1, dataTypeWanted), x._2 / predictions.size)).withDefaultValue(0.0)
(Utils.toVal(x._1, dataTypeWanted), x._2 / predictions.length)).withDefaultValue(0.0)
outputs.setProbabilities(classes.map(x => (x, probabilities(x))).toMap).evalPredictedValueByProbabilities()
}
} else if (isClassification) {
Expand Down Expand Up @@ -235,7 +239,7 @@ class MiningModel(
val matrix = probabilities.transpose
outputs.probabilities = method match {
case `average` => {
classes.zip(matrix.map(_.sum / selections.size)).toMap
classes.zip(matrix.map(_.sum / selections.length)).toMap
}
case `weightedAverage` => {
val sum = weights.sum
Expand All @@ -245,7 +249,7 @@ class MiningModel(
evalPredictedValue = false
outputs.predictedValue = classes.zip(matrix.map(_.max)).maxBy(_._2)
val contributions = probabilities.filter(x => classes.zip(x).maxBy(_._2)._1 == outputs.predictedValue)
classes.zip(contributions.transpose.map(_.sum / contributions.size)).toMap
classes.zip(contributions.transpose.map(_.sum / contributions.length)).toMap
}
case `median` => {
classes.zip(matrix.map(x => MathUtils.median(x))).toMap
Expand All @@ -269,12 +273,16 @@ class MiningModel(
}
}


/** Returns the field of a given name, None if a field with the given name does not exist. */
override def getField(name: String): Option[Field] = {
if (isSegmentation && segmentation.get.multipleModelMethod == MultipleModelMethod.modelChain) {
for (segment <- segmentation.get.segments) {
var i = 0
while (i < segmentation.get.segments.length) {
val segment = segmentation.get.segments(i)
val f = segment.model.output.flatMap(_.getField(name))
if (f.isDefined) return f
i += 1
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/org/pmml4s/model/Model.scala
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ abstract class Model extends HasParent
} else false

newValues(idx) = if (missing) {
if (mf.missingValueTreatment == Some(MissingValueTreatment.returnInvalid)) {
if (mf.missingValueTreatment.contains(MissingValueTreatment.returnInvalid)) {
return (series, true)
}

Expand Down
59 changes: 37 additions & 22 deletions src/main/scala/org/pmml4s/model/TreeModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,32 @@ class TreeModel(
/** Model element type. */
override def modelElement: ModelElement = ModelElement.TreeModel

// Optimization for ensembles of trees: the data preparation step is skipped when it would repeat what the parent
// model declares. That is only assumed when this tree is a sub-model without local transformations and every input
// mining field with a non-default value treatment is declared identically in the parent's mining schema.
// A field with a non-default treatment that the parent schema does not declare at all also disables the shortcut,
// otherwise the treatment configured on this sub-model would never be applied.
// The flag is computed once, when the model is constructed.
private val ignoreDataPrepare: Boolean = isSubModel && localTransformations.isEmpty && {
  val parentSchema = parent.miningSchema
  miningSchema.inputMiningFields.forall { field =>
    // Fields with default settings are not compared against the parent, as before.
    field.isDefault || parentSchema.get(field.name).contains(field)
  }
}

/** Predicts values for a given data series using the model loaded. */
override def predict(values: Series): Series = {
val (series, returnInvalid) = prepare(values)
val (series, returnInvalid) = if (ignoreDataPrepare) (values, false) else prepare(values)
if (returnInvalid) {
return nullSeries
}

import MissingValueStrategy._
import NoTrueChildStrategy._
import Predication._

val outputs = createOutputs()

// The root node could be leaf
Expand All @@ -66,24 +81,24 @@ class TreeModel(
var done = false
while (!done && selected.isSplit) {
var child: Node = null
var r = FALSE
var r = Predication.FALSE
var hit = false
var unknown = false
var i = 0
while (i < selected.children.length && !hit) {
val c = selected.children(i)
c.eval(series) match {
case TRUE => {
r = TRUE
case Predication.TRUE => {
r = Predication.TRUE
child = c
hit = true
}
case SURROGATE => {
r = SURROGATE
case Predication.SURROGATE => {
r = Predication.SURROGATE
child = c
hit = true
}
case UNKNOWN => {
case Predication.UNKNOWN => {
unknown = true
}
case _ =>
Expand All @@ -92,26 +107,26 @@ class TreeModel(
}

if (!hit) {
r = if (unknown) UNKNOWN else FALSE
r = if (unknown) Predication.UNKNOWN else Predication.FALSE
}

if (r == SURROGATE) {
if (r == Predication.SURROGATE) {
numMissingCount += 1
}

if (r == UNKNOWN) {
if (r == Predication.UNKNOWN) {
missingValueStrategy match {
case `lastPrediction` => {
case MissingValueStrategy.`lastPrediction` => {
finalNode = Some(selected)
done = true
}
case `nullPrediction` =>
case MissingValueStrategy.`nullPrediction` =>
done = true
case `defaultChild` => {
case MissingValueStrategy.`defaultChild` => {
child = selected.defaultChildNode.orNull
numMissingCount += 1
}
case `weightedConfidence` => if (isClassification) {
case MissingValueStrategy.`weightedConfidence` => if (isClassification) {
val total = selected.recordCount.getOrElse(Double.NaN)
val candidates = selected.children.filter { x => x.eval(series) == UNKNOWN }
var max = 0.0
Expand All @@ -130,7 +145,7 @@ class TreeModel(

done = true
}
case `aggregateNodes` => if (isClassification) {
case MissingValueStrategy.`aggregateNodes` => if (isClassification) {
val leaves = mutable.HashSet.empty[Node]
traverseLeaves(selected, series, leaves)
if (leaves.nonEmpty) {
Expand All @@ -157,15 +172,15 @@ class TreeModel(

done = true
}
case `none` =>
case MissingValueStrategy.`none` =>
}
}

// Handling the situation where scoring cannot continue
if (child == null && outputs.predictedValue == null) {
noTrueChildStrategy match {
case `returnNullPrediction` => done = true
case `returnLastPrediction` => {
case NoTrueChildStrategy.`returnNullPrediction` => done = true
case NoTrueChildStrategy.`returnLastPrediction` => {
finalNode = Some(selected)
done = true
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/org/pmml4s/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ object Utils {

/** Negation of `isMissing` for a value of any type: true when `isMissing(value)` is false. */
def nonMissing(value: Any): Boolean = !isMissing(value)

def isMissing(value: Double): Boolean = value != value
/** Tells whether a primitive double is missing, which for doubles means NaN. */
@inline def isMissing(value: Double): Boolean = java.lang.Double.isNaN(value)

/**
 * Tells whether a primitive double is present, i.e. not NaN (NaN is the only double for which `value == value`
 * is false). Marked `@inline` like its counterpart `isMissing(value: Double)`.
 */
@inline def nonMissing(value: Double): Boolean = value == value

Expand Down

0 comments on commit b40f469

Please sign in to comment.