diff --git a/README.adoc b/README.adoc index dd4c454..3c0cce2 100644 --- a/README.adoc +++ b/README.adoc @@ -33,11 +33,9 @@ regarding the estimated model with respect to a set of given input values for th `value`:: The predicted value for the response variable computed using the estimated linear hypothesis function ``h(x)`` with `x` given by `C` input values for the explanatory variables `x = [x~1~, x~2~,...,x~C~]`. -`coefficients`:: Estimated slope coefficients - image:http://latex.codecogs.com/gif.latex?\theta_1,%20\theta_2,%20\theta_3,.%20.%20.,%20\theta_C%20[] +`coefficients`:: Estimated coefficients + image:http://latex.codecogs.com/gif.latex?\theta_0,%20\theta_1,%20\theta_2,%20\theta_3,.%20.%20.,%20\theta_C%20[] of the linear linear hypothesis function ``h(x)``. -`intercept`:: Estimated intercept coefficient image:http://latex.codecogs.com/gif.latex?\theta_0%20[] - of the linear hypothesis function ``h(x)``. Assuming the data consists of documents representing sold house prices with features like number of bedrooms, bathrooms and size etc. we can let predict or validate @@ -80,11 +78,11 @@ And the following may be the response with the estimated price of around $ 581,4 "my_house_price": { "value": 581458.3087492324, "coefficients": [ + 227990.63952712028, 248.92285661317254, -68297.7720278421, 64406.52205356777 - ], - "intercept": 227990.63952712028 + ] } } } @@ -99,11 +97,9 @@ The `linreg_stats` aggregation computes statistics for the estimated linear regr `rss`:: Residual sum of squares as a measure of the discrepancy between the data and the estimated model. The lower the `rss` number, the smaller the error of the prediction, and the better the model. `mse`:: Mean squared error or rather `rss` divided by the number of documents consumed for model estimation. -`coefficients`:: Slope coefficients - image:http://latex.codecogs.com/gif.latex?\theta_1,%20\theta_2,%20\theta_3,.%20.%20.,%20\theta_C%20[] +`coefficients`:: Estimated coefficients + image:http://latex.codecogs.com/gif.latex?\theta_0,%20\theta_1,%20\theta_2,%20\theta_3,.%20.%20.,%20\theta_C%20[] of the linear linear hypothesis function ``h(x)``. -`intercept`:: Intercept coefficient image:http://latex.codecogs.com/gif.latex?\theta_0%20[] - of the linear hypothesis function ``h(x)``. Assuming the data consists of documents representing house prices we can compute statistics for the estimated best fitting linear hypothesis function which predicts house prices based on number of @@ -135,11 +131,11 @@ and the last for the response variable. The above request returns the following "rss": 49523788338938.734, "mse": 63410740510.80504, "coefficients": [ + 47553.18737564783, -100544.0725894584, 45981.15827544966, 309.6013051477475 - ], - "intercept": 47553.18737564783 + ] } } } @@ -180,7 +176,8 @@ Do not forget to restart the node after installing. [frame="all"] |=== | Plugin version | Elasticsearch version | Release date -| https://github.com/scaleborn/elasticsearch-linear-regression/releases/download/5.3.0.1/elasticsearch-linear-regression-5.3.0.1.zip[5.3.0.1] | 5.3.0 | Jun 1, 2017 +| https://github.com/scaleborn/elasticsearch-linear-regression/releases/download/5.3.0.1/elasticsearch-linear-regression-5.3.0.2.zip[5.3.0.2] | 5.3.0 | Jul 16, 2017 +| https://github.com/scaleborn/elasticsearch-linear-regression/releases/download/5.3.0.1/elasticsearch-linear-regression-5.3.0.1.zip[5.3.0.1] | 5.3.0 | Jun 30, 2017 |=== ## Examples @@ -198,7 +195,7 @@ https://github.com/scaleborn/elasticsearch-linear-regression/tree/master/example ./bin/logstash -f house-prices-import.conf .... -The indexed data will have this form: +The indexed documents will have this form: [source,js] -------------------------------------------------- { @@ -250,16 +247,97 @@ $ 650,000 to pay for the desired house in "Morro Bay". "dream_house_price": { "value": 649918.0709489314, "coefficients": [ + 228318.6161854365, 249.02340193904183, -68314.4830871133, 64248.05007337558 - ], - "intercept": 228318.6161854365 + ] } } } -------------------------------------------------- +By using sub aggregations we are able to find out the estimated prices per location: +[source,js] +-------------------------------------------------- +/houses/_search?size=0 +{ + "aggs": { + "locations": { + "terms": { + "field": "location.keyword", + "size": 15 + }, + "aggs": { + "dream_house_price": { + "linreg_predict": { + "fields": ["size", "bedrooms", "bathrooms", "price"], + "inputs": [2000, 3, 2] + } + } + } + } + } +} +-------------------------------------------------- + +The response uncovers that "Arroyo Grande" would be +the most expensive region for our dream house: + +[source,js] +-------------------------------------------------- +{ + "aggregations": { + "locations": { + "buckets": [ + { + "key": "Santa Maria-Orcutt", + "doc_count": 265, + "dream_house_price": { + "value": 256251.9105297585, + "coefficients": [ + 26437.192829649313, + 81.19071633227178, + 6825.9128627023265, + 23477.773223729317 + ] + } + }, + { + "key": "Paso Robles", + "doc_count": 85, + "dream_house_price": { + "value": 365620.0386191703, + "coefficients": [ + 42958.257094706176, + 151.7000907380368, + 6486.477078139843, + -98.91559301451247 + ] + } + }, + ... + { + "key": " Arroyo Grande", + "doc_count": 12, + "dream_house_price": { + "value": 1140196.791331573, + "coefficients": [ + 728566.7474390095, + 1956.6474540196602, + -706891.620925945, + -690495.0006844609 + ] + } + } + ... + ] + } + } +} +-------------------------------------------------- + + ## License Copyright 2017 Scaleborn UG (haftungsbeschränkt). diff --git a/gradle.properties b/gradle.properties index a1b3326..f5b1455 100644 --- a/gradle.properties +++ b/gradle.properties @@ -4,4 +4,4 @@ wagon-ssh-external.version=2.10 commons-math3.version=3.6.1 group=org.scaleborn.elasticsearch.plugin name=elasticsearch-linear-regression -version=5.3.0.1 +version=5.3.0.2 diff --git a/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/predict/InternalPrediction.java b/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/predict/InternalPrediction.java index a1eb17d..36a2fa6 100644 --- a/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/predict/InternalPrediction.java +++ b/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/predict/InternalPrediction.java @@ -25,7 +25,7 @@ import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; import org.scaleborn.elasticsearch.linreg.aggregation.support.BaseInternalAggregation; -import org.scaleborn.linereg.evaluation.SlopeCoefficients; +import org.scaleborn.linereg.estimation.SlopeCoefficients; /** * Created by mbok on 11.04.17. diff --git a/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/predict/PredictionResults.java b/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/predict/PredictionResults.java index 8444358..80fd34d 100644 --- a/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/predict/PredictionResults.java +++ b/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/predict/PredictionResults.java @@ -22,7 +22,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.search.aggregations.InternalAggregation.CommonFields; import org.scaleborn.elasticsearch.linreg.aggregation.support.ModelResults; -import org.scaleborn.linereg.evaluation.SlopeCoefficients; +import org.scaleborn.linereg.estimation.SlopeCoefficients; /** * Created by mbok on 11.04.17. diff --git a/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/stats/InternalStats.java b/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/stats/InternalStats.java index 2279564..b8f0c98 100644 --- a/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/stats/InternalStats.java +++ b/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/stats/InternalStats.java @@ -27,7 +27,7 @@ import org.scaleborn.linereg.calculation.statistics.Statistics; import org.scaleborn.linereg.calculation.statistics.StatsCalculator; import org.scaleborn.linereg.calculation.statistics.StatsModel; -import org.scaleborn.linereg.evaluation.SlopeCoefficients; +import org.scaleborn.linereg.estimation.SlopeCoefficients; /** * Created by mbok on 21.03.17. diff --git a/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/stats/StatsResults.java b/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/stats/StatsResults.java index 00884b6..aa3324e 100644 --- a/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/stats/StatsResults.java +++ b/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/stats/StatsResults.java @@ -23,7 +23,7 @@ import org.scaleborn.elasticsearch.linreg.aggregation.support.ModelResults; import org.scaleborn.linereg.calculation.statistics.Statistics; import org.scaleborn.linereg.calculation.statistics.Statistics.DefaultStatistics; -import org.scaleborn.linereg.evaluation.SlopeCoefficients; +import org.scaleborn.linereg.estimation.SlopeCoefficients; /** * Created by mbok on 07.04.17. diff --git a/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/support/BaseInternalAggregation.java b/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/support/BaseInternalAggregation.java index f3f5922..e52eb19 100644 --- a/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/support/BaseInternalAggregation.java +++ b/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/support/BaseInternalAggregation.java @@ -30,11 +30,12 @@ import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; import org.scaleborn.linereg.calculation.intercept.InterceptCalculator; -import org.scaleborn.linereg.evaluation.DerivationEquation; -import org.scaleborn.linereg.evaluation.DerivationEquationBuilder; -import org.scaleborn.linereg.evaluation.DerivationEquationSolver; -import org.scaleborn.linereg.evaluation.SlopeCoefficients; -import org.scaleborn.linereg.evaluation.commons.CommonsMathSolver; +import org.scaleborn.linereg.estimation.DerivationEquation; +import org.scaleborn.linereg.estimation.DerivationEquationBuilder; +import org.scaleborn.linereg.estimation.DerivationEquationSolver; +import org.scaleborn.linereg.estimation.DerivationEquationSolver.EstimationException; +import org.scaleborn.linereg.estimation.SlopeCoefficients; +import org.scaleborn.linereg.estimation.commons.CommonsMathSolver; /** * Created by mbok on 07.04.17. @@ -142,9 +143,7 @@ public InternalAggregation doReduce(final List aggregations // return empty result if all samples are null if (aggs.isEmpty()) { - return buildInternalAggregation(this.name, this.featuresCount, null, null, - pipelineAggregators(), - getMetaData()); + return buildEmptyInternalAggregation(); } final S composedSampling = buildSampling(this.featuresCount); @@ -154,7 +153,21 @@ public InternalAggregation doReduce(final List aggregations composedSampling.merge((S) ((BaseInternalAggregation) aggs.get(i)).sampling); } - final M evaluatedResults = evaluateResults(composedSampling); + if (composedSampling.getCount() <= composedSampling.getFeaturesCount()) { + LOGGER.debug( + "Insufficient amount of training data for model estimation, at least {} are required, given {}", + composedSampling.getFeaturesCount() + 1, composedSampling.getCount()); + return buildEmptyInternalAggregation(); + } + + M evaluatedResults = null; + try { + evaluatedResults = evaluateResults(composedSampling); + } catch (final EstimationException e) { + LOGGER.debug( + "Failed to estimate model", e); + return buildEmptyInternalAggregation(); + } LOGGER.debug("Evaluated results: {}", evaluatedResults); return buildInternalAggregation(this.name, this.featuresCount, composedSampling, @@ -162,6 +175,12 @@ public InternalAggregation doReduce(final List aggregations pipelineAggregators(), getMetaData()); } + private InternalAggregation buildEmptyInternalAggregation() { + return buildInternalAggregation(this.name, this.featuresCount, null, null, + pipelineAggregators(), + getMetaData()); + } + protected abstract A buildInternalAggregation(final String name, final int featuresCount, final S linRegSampling, final M results, @@ -171,12 +190,12 @@ protected abstract M buildResults(S composedSampling, SlopeCoefficients slopeCoe double intercept); - private M evaluateResults(final S composedSampling) { - // Linear regression evaluation + private M evaluateResults(final S composedSampling) throws EstimationException { + // Linear regression estimation final DerivationEquation derivationEquation = derivationEquationBuilder .buildDerivationEquation(composedSampling); final SlopeCoefficients slopeCoefficients = derivationEquationSolver - .solveCoefficients(derivationEquation); + .estimateCoefficients(derivationEquation); final M buildResults = buildResults(composedSampling, slopeCoefficients, interceptCalculator.calculate(slopeCoefficients, composedSampling, composedSampling)); return buildResults; diff --git a/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/support/BaseSampling.java b/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/support/BaseSampling.java index 0774ca9..8c56d6c 100644 --- a/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/support/BaseSampling.java +++ b/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/support/BaseSampling.java @@ -17,7 +17,7 @@ package org.scaleborn.elasticsearch.linreg.aggregation.support; import java.io.IOException; -import org.scaleborn.linereg.evaluation.SlopeCoefficientsSampling.SlopeCoefficientsSamplingProxy; +import org.scaleborn.linereg.estimation.SlopeCoefficientsSampling.SlopeCoefficientsSamplingProxy; import org.scaleborn.linereg.sampling.Sampling.InterceptSampling; import org.scaleborn.linereg.sampling.io.StateInputStream; import org.scaleborn.linereg.sampling.io.StateOutputStream; diff --git a/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/support/ModelResults.java b/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/support/ModelResults.java index 2d575fb..4de3156 100644 --- a/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/support/ModelResults.java +++ b/src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/support/ModelResults.java @@ -17,69 +17,52 @@ package org.scaleborn.elasticsearch.linreg.aggregation.support; import java.io.IOException; +import java.util.Arrays; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; -import org.scaleborn.linereg.evaluation.SlopeCoefficients; -import org.scaleborn.linereg.evaluation.SlopeCoefficients.DefaultSlopeCoefficients; +import org.scaleborn.linereg.estimation.SlopeCoefficients; /** * Created by mbok on 07.04.17. */ public class ModelResults implements Writeable, ToXContent { - private SlopeCoefficients slopeCoefficients; - - private double intercept; + private final double[] coefficients; public ModelResults(final SlopeCoefficients slopeCoefficients, final double intercept) { - this.slopeCoefficients = slopeCoefficients; - this.intercept = intercept; + final int slopeLen = slopeCoefficients.getCoefficients().length; + this.coefficients = new double[slopeLen + 1]; + System.arraycopy(slopeCoefficients.getCoefficients(), 0, this.coefficients, 1, slopeLen); + this.coefficients[0] = intercept; } public ModelResults(final StreamInput in) throws IOException { - this.slopeCoefficients = new DefaultSlopeCoefficients(in.readDoubleArray()); - this.intercept = in.readDouble(); + this.coefficients = in.readDoubleArray(); } @Override public void writeTo(final StreamOutput out) throws IOException { - out.writeDoubleArray(this.slopeCoefficients.getCoefficients()); - out.writeDouble(this.intercept); - } - - public SlopeCoefficients getSlopeCoefficients() { - return this.slopeCoefficients; - } - - public void setSlopeCoefficients(final SlopeCoefficients slopeCoefficients) { - this.slopeCoefficients = slopeCoefficients; - } - - public double getIntercept() { - return this.intercept; + out.writeDoubleArray(this.coefficients); } - public void setIntercept(final double intercept) { - this.intercept = intercept; + public double[] getCoefficients() { + return this.coefficients; } - @Override public String toString() { return "ModelResults{" + - "slopeCoefficients=" + this.slopeCoefficients + - ", intercept=" + this.intercept + + "coefficients=" + Arrays.toString(this.coefficients) + '}'; } @Override public XContentBuilder toXContent(final XContentBuilder builder, final Params params) throws IOException { - builder.array("coefficients", this.getSlopeCoefficients().getCoefficients()); - builder.field("intercept", this.getIntercept()); + builder.array("coefficients", this.coefficients); return builder; } diff --git a/src/main/java/org/scaleborn/linereg/calculation/intercept/InterceptCalculator.java b/src/main/java/org/scaleborn/linereg/calculation/intercept/InterceptCalculator.java index 7ebb1b2..a308912 100644 --- a/src/main/java/org/scaleborn/linereg/calculation/intercept/InterceptCalculator.java +++ b/src/main/java/org/scaleborn/linereg/calculation/intercept/InterceptCalculator.java @@ -16,7 +16,7 @@ package org.scaleborn.linereg.calculation.intercept; -import org.scaleborn.linereg.evaluation.SlopeCoefficients; +import org.scaleborn.linereg.estimation.SlopeCoefficients; import org.scaleborn.linereg.sampling.Sampling.InterceptSampling; import org.scaleborn.linereg.sampling.Sampling.SamplingContext; diff --git a/src/main/java/org/scaleborn/linereg/calculation/statistics/StatsModel.java b/src/main/java/org/scaleborn/linereg/calculation/statistics/StatsModel.java index 6bbc2df..d6d2b8b 100644 --- a/src/main/java/org/scaleborn/linereg/calculation/statistics/StatsModel.java +++ b/src/main/java/org/scaleborn/linereg/calculation/statistics/StatsModel.java @@ -16,7 +16,7 @@ package org.scaleborn.linereg.calculation.statistics; -import org.scaleborn.linereg.evaluation.SlopeCoefficients; +import org.scaleborn.linereg.estimation.SlopeCoefficients; /** * Bean for evaluated linear model fitting best the sampled data regarding diff --git a/src/main/java/org/scaleborn/linereg/calculation/statistics/StatsSampling.java b/src/main/java/org/scaleborn/linereg/calculation/statistics/StatsSampling.java index 50f1c7f..96a71ac 100644 --- a/src/main/java/org/scaleborn/linereg/calculation/statistics/StatsSampling.java +++ b/src/main/java/org/scaleborn/linereg/calculation/statistics/StatsSampling.java @@ -17,7 +17,7 @@ package org.scaleborn.linereg.calculation.statistics; import java.io.IOException; -import org.scaleborn.linereg.evaluation.SlopeCoefficientsSampling; +import org.scaleborn.linereg.estimation.SlopeCoefficientsSampling; import org.scaleborn.linereg.sampling.Sampling.ResponseVarianceTermSampling; import org.scaleborn.linereg.sampling.io.StateInputStream; import org.scaleborn.linereg.sampling.io.StateOutputStream; diff --git a/src/main/java/org/scaleborn/linereg/evaluation/DerivationEquation.java b/src/main/java/org/scaleborn/linereg/estimation/DerivationEquation.java similarity index 96% rename from src/main/java/org/scaleborn/linereg/evaluation/DerivationEquation.java rename to src/main/java/org/scaleborn/linereg/estimation/DerivationEquation.java index d0c55da..f47b989 100644 --- a/src/main/java/org/scaleborn/linereg/evaluation/DerivationEquation.java +++ b/src/main/java/org/scaleborn/linereg/estimation/DerivationEquation.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.scaleborn.linereg.evaluation; +package org.scaleborn.linereg.estimation; /** * Represents the derivation equation (divided by 2) build up from the sampled data. diff --git a/src/main/java/org/scaleborn/linereg/evaluation/DerivationEquationBuilder.java b/src/main/java/org/scaleborn/linereg/estimation/DerivationEquationBuilder.java similarity index 91% rename from src/main/java/org/scaleborn/linereg/evaluation/DerivationEquationBuilder.java rename to src/main/java/org/scaleborn/linereg/estimation/DerivationEquationBuilder.java index bacbba8..72dce35 100644 --- a/src/main/java/org/scaleborn/linereg/evaluation/DerivationEquationBuilder.java +++ b/src/main/java/org/scaleborn/linereg/estimation/DerivationEquationBuilder.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.scaleborn.linereg.evaluation; +package org.scaleborn.linereg.estimation; /** * Created by mbok on 17.03.17. @@ -22,7 +22,7 @@ public class DerivationEquationBuilder { public DerivationEquation buildDerivationEquation( - SlopeCoefficientsSampling slopeCoefficientsSampling) { + final SlopeCoefficientsSampling slopeCoefficientsSampling) { return new DerivationEquation() { @Override diff --git a/src/main/java/org/scaleborn/linereg/evaluation/DerivationEquationSolver.java b/src/main/java/org/scaleborn/linereg/estimation/DerivationEquationSolver.java similarity index 63% rename from src/main/java/org/scaleborn/linereg/evaluation/DerivationEquationSolver.java rename to src/main/java/org/scaleborn/linereg/estimation/DerivationEquationSolver.java index 0e0de95..5f56506 100644 --- a/src/main/java/org/scaleborn/linereg/evaluation/DerivationEquationSolver.java +++ b/src/main/java/org/scaleborn/linereg/estimation/DerivationEquationSolver.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.scaleborn.linereg.evaluation; +package org.scaleborn.linereg.estimation; /** * Solves the derivation equation and returns the coefficients for the best fit prediction @@ -24,12 +24,26 @@ public interface DerivationEquationSolver { /** - * Solves the derivation equation and returns the coefficients for the best fit prediction - * equation. + * Solves the derivation equation and returns the estimated coefficients for the best fit + * prediction equation. * * @param eq the derivation equation to solve. The input parameter should remain unchanged for * further calculations. * @return the slope coefficients (excluding the intercept) for the best fit prediction equation */ - SlopeCoefficients solveCoefficients(DerivationEquation eq); + SlopeCoefficients estimateCoefficients(DerivationEquation eq) throws EstimationException; + + /** + * Thrown when estimation fails, usually due to linearly dependent data. + */ + public class EstimationException extends Exception { + + public EstimationException(final String message) { + super(message); + } + + public EstimationException(final String message, final Throwable cause) { + super(message, cause); + } + } } diff --git a/src/main/java/org/scaleborn/linereg/evaluation/SlopeCoefficients.java b/src/main/java/org/scaleborn/linereg/estimation/SlopeCoefficients.java similarity index 89% rename from src/main/java/org/scaleborn/linereg/evaluation/SlopeCoefficients.java rename to src/main/java/org/scaleborn/linereg/estimation/SlopeCoefficients.java index 3791e08..556632c 100644 --- a/src/main/java/org/scaleborn/linereg/evaluation/SlopeCoefficients.java +++ b/src/main/java/org/scaleborn/linereg/estimation/SlopeCoefficients.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.scaleborn.linereg.evaluation; +package org.scaleborn.linereg.estimation; import java.util.Arrays; @@ -35,13 +35,13 @@ public DefaultSlopeCoefficients(final double[] coefficients) { @Override public double[] getCoefficients() { - return coefficients; + return this.coefficients; } @Override public String toString() { return "DefaultSlopeCoefficients{" + - "coefficients=" + Arrays.toString(coefficients) + + "coefficients=" + Arrays.toString(this.coefficients) + '}'; } } diff --git a/src/main/java/org/scaleborn/linereg/evaluation/SlopeCoefficientsSampling.java b/src/main/java/org/scaleborn/linereg/estimation/SlopeCoefficientsSampling.java similarity index 99% rename from src/main/java/org/scaleborn/linereg/evaluation/SlopeCoefficientsSampling.java rename to src/main/java/org/scaleborn/linereg/estimation/SlopeCoefficientsSampling.java index 48ae091..7be3947 100644 --- a/src/main/java/org/scaleborn/linereg/evaluation/SlopeCoefficientsSampling.java +++ b/src/main/java/org/scaleborn/linereg/estimation/SlopeCoefficientsSampling.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.scaleborn.linereg.evaluation; +package org.scaleborn.linereg.estimation; import java.io.IOException; import org.scaleborn.linereg.sampling.Sampling.CoefficientLinearTermSampling; diff --git a/src/main/java/org/scaleborn/linereg/evaluation/commons/CommonsMathSolver.java b/src/main/java/org/scaleborn/linereg/estimation/commons/CommonsMathSolver.java similarity index 57% rename from src/main/java/org/scaleborn/linereg/evaluation/commons/CommonsMathSolver.java rename to src/main/java/org/scaleborn/linereg/estimation/commons/CommonsMathSolver.java index 2f8ccae..26fabe8 100644 --- a/src/main/java/org/scaleborn/linereg/evaluation/commons/CommonsMathSolver.java +++ b/src/main/java/org/scaleborn/linereg/estimation/commons/CommonsMathSolver.java @@ -14,18 +14,19 @@ * limitations under the License. */ -package org.scaleborn.linereg.evaluation.commons; +package org.scaleborn.linereg.estimation.commons; import org.apache.commons.math3.linear.Array2DRowRealMatrix; import org.apache.commons.math3.linear.ArrayRealVector; import org.apache.commons.math3.linear.CholeskyDecomposition; import org.apache.commons.math3.linear.DecompositionSolver; +import org.apache.commons.math3.linear.NonPositiveDefiniteMatrixException; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealVector; -import org.scaleborn.linereg.evaluation.DerivationEquation; -import org.scaleborn.linereg.evaluation.DerivationEquationSolver; -import org.scaleborn.linereg.evaluation.SlopeCoefficients; -import org.scaleborn.linereg.evaluation.SlopeCoefficients.DefaultSlopeCoefficients; +import org.scaleborn.linereg.estimation.DerivationEquation; +import org.scaleborn.linereg.estimation.DerivationEquationSolver; +import org.scaleborn.linereg.estimation.SlopeCoefficients; +import org.scaleborn.linereg.estimation.SlopeCoefficients.DefaultSlopeCoefficients; /** * Solves the coefficient derivation equation using math-commons library and the Cholesky @@ -34,16 +35,17 @@ public class CommonsMathSolver implements DerivationEquationSolver { @Override - public SlopeCoefficients solveCoefficients(final DerivationEquation eq) { - double[][] sourceTriangleMatrix = eq.getCovarianceLowerTriangularMatrix(); + public SlopeCoefficients estimateCoefficients(final DerivationEquation eq) + throws EstimationException { + final double[][] sourceTriangleMatrix = eq.getCovarianceLowerTriangularMatrix(); // Copy matrix and enhance it to a full matrix as expected by CholeskyDecomposition // FIXME: Avoid copy job to speed-up the solving process e.g. by extending the CholeskyDecomposition constructor - int length = sourceTriangleMatrix.length; - double[][] matrix = new double[length][]; + final int length = sourceTriangleMatrix.length; + final double[][] matrix = new double[length][]; for (int i = 0; i < length; i++) { matrix[i] = new double[length]; - double[] s = sourceTriangleMatrix[i]; - double[] t = matrix[i]; + final double[] s = sourceTriangleMatrix[i]; + final double[] t = matrix[i]; for (int j = 0; j <= i; j++) { t[j] = s[j]; } @@ -51,11 +53,15 @@ public SlopeCoefficients solveCoefficients(final DerivationEquation eq) { t[j] = sourceTriangleMatrix[j][i]; } } - RealMatrix coefficients = + final RealMatrix coefficients = new Array2DRowRealMatrix(matrix, false); - DecompositionSolver solver = new CholeskyDecomposition(coefficients).getSolver(); - RealVector constants = new ArrayRealVector(eq.getConstraints(), true); - final RealVector solution = solver.solve(constants); - return new DefaultSlopeCoefficients(solution.toArray()); + try { + final DecompositionSolver solver = new CholeskyDecomposition(coefficients).getSolver(); + final RealVector constants = new ArrayRealVector(eq.getConstraints(), true); + final RealVector solution = solver.solve(constants); + return new DefaultSlopeCoefficients(solution.toArray()); + } catch (final NonPositiveDefiniteMatrixException e) { + throw new EstimationException("Matrix inversion error due to data is linearly dependent", e); + } } } diff --git a/src/test/java/org/scaleborn/linereg/TestModels.java b/src/test/java/org/scaleborn/linereg/TestModels.java index 1715a54..44a3adc 100644 --- a/src/test/java/org/scaleborn/linereg/TestModels.java +++ b/src/test/java/org/scaleborn/linereg/TestModels.java @@ -25,10 +25,11 @@ import org.scaleborn.linereg.calculation.statistics.StatsModel; import org.scaleborn.linereg.calculation.statistics.StatsSampling; import org.scaleborn.linereg.calculation.statistics.StatsSampling.StatsSamplingProxy; -import org.scaleborn.linereg.evaluation.DerivationEquation; -import org.scaleborn.linereg.evaluation.DerivationEquationBuilder; -import org.scaleborn.linereg.evaluation.SlopeCoefficients; -import org.scaleborn.linereg.evaluation.commons.CommonsMathSolver; +import org.scaleborn.linereg.estimation.DerivationEquation; +import org.scaleborn.linereg.estimation.DerivationEquationBuilder; +import org.scaleborn.linereg.estimation.DerivationEquationSolver.EstimationException; +import org.scaleborn.linereg.estimation.SlopeCoefficients; +import org.scaleborn.linereg.estimation.commons.CommonsMathSolver; import org.scaleborn.linereg.sampling.exact.ExactModelSamplingFactory; import org.scaleborn.linereg.sampling.exact.ExactSamplingContext; @@ -105,8 +106,8 @@ public DerivationEquation getEquation() { return this.equation; } - public StatsModel evaluateModel() { - final SlopeCoefficients coefficients = new CommonsMathSolver().solveCoefficients( + public StatsModel evaluateModel() throws EstimationException { + final SlopeCoefficients coefficients = new CommonsMathSolver().estimateCoefficients( this.equation); return new StatsModel(this.statsSampling, coefficients); } diff --git a/src/test/java/org/scaleborn/linereg/calculation/statistics/StatsBuilderTests.java b/src/test/java/org/scaleborn/linereg/calculation/statistics/StatsBuilderTests.java index 3b3a3e5..e56db05 100644 --- a/src/test/java/org/scaleborn/linereg/calculation/statistics/StatsBuilderTests.java +++ b/src/test/java/org/scaleborn/linereg/calculation/statistics/StatsBuilderTests.java @@ -20,6 +20,7 @@ import org.junit.Test; import org.scaleborn.linereg.TestModels; import org.scaleborn.linereg.TestModels.TestModel; +import org.scaleborn.linereg.estimation.DerivationEquationSolver.EstimationException; /** * Tests for {@link StatsCalculator}. @@ -28,14 +29,14 @@ public class StatsBuilderTests extends ESTestCase { @Test - public void testStats() { + public void testStats() throws EstimationException { testStatsForModel(TestModels.SIMPLE_MODEL_1); testStatsForModel(TestModels.MULTI_FEATURES_2_MODEL_1); testStatsForModel(TestModels.MULTI_FEATURES_3_MODEL_1); testStatsForModel(TestModels.MULTI_FEATURES_6_LONGLEY); } - private void testStatsForModel(final TestModel testModel) { + private void testStatsForModel(final TestModel testModel) throws EstimationException { final StatsModel linearModel = testModel.evaluateModel(); final Statistics statistics = new StatsCalculator().calculate(linearModel); testModel.assertStatistics(statistics); diff --git a/src/test/java/org/scaleborn/linereg/estimation/commons/CommonsMathSolverTests.java b/src/test/java/org/scaleborn/linereg/estimation/commons/CommonsMathSolverTests.java new file mode 100644 index 0000000..66a7441 --- /dev/null +++ b/src/test/java/org/scaleborn/linereg/estimation/commons/CommonsMathSolverTests.java @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2017 Scaleborn UG, www.scaleborn.com + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.scaleborn.linereg.estimation.commons; + +import org.elasticsearch.test.ESTestCase; +import org.junit.Test; +import org.scaleborn.linereg.TestModels; +import org.scaleborn.linereg.TestModels.TestModel; +import org.scaleborn.linereg.estimation.DerivationEquation; +import org.scaleborn.linereg.estimation.DerivationEquationSolver.EstimationException; + +/** + * Created by mbok on 18.03.17. + */ +public class CommonsMathSolverTests extends ESTestCase { + + /** + * Tests coefficient estimation with one feature variable. + */ + @Test + public void testSimpleRegression() throws EstimationException { + final TestModel model = TestModels.SIMPLE_MODEL_1; + final DerivationEquation equation = model.getEquation(); + final double[] coefficients = new CommonsMathSolver().estimateCoefficients(equation) + .getCoefficients(); + this.logger.info("Evaluated linreg coefficients: {}", coefficients); + model.assertCoefficients(coefficients, 0.0000001); + } + + /** + * Tests coefficient estimation with two feature variables. + */ + @Test + public void testMultipleRegressionWith2Features() throws EstimationException { + final TestModel testModel = TestModels.MULTI_FEATURES_2_MODEL_1; + final double[] coefficients = new CommonsMathSolver() + .estimateCoefficients(testModel.getEquation()) + .getCoefficients(); + this.logger.info("Evaluated linreg coefficients: {}", coefficients); + testModel.assertCoefficients(coefficients, 0.0000001); + } + + /** + * Tests coefficient estimation with three feature variables. + */ + @Test + public void testMultipleRegressionWith3Features() throws EstimationException { + final TestModel testModel = TestModels.MULTI_FEATURES_3_MODEL_1; + final DerivationEquation equation = testModel.getEquation(); + final double[] coefficients = new CommonsMathSolver().estimateCoefficients(equation) + .getCoefficients(); + this.logger.info("Evaluated linreg coefficients: {}", coefficients); + testModel.assertCoefficients(coefficients, 0.0000001); + } + + /** + * Tests coefficient estimation with the reference Longley data set. + */ + @Test + public void testMultipleRegressionWithLongleyDataSet() throws EstimationException { + final TestModel testModel = TestModels.MULTI_FEATURES_6_LONGLEY; + final DerivationEquation equation = testModel.getEquation(); + final double[] coefficients = new CommonsMathSolver().estimateCoefficients(equation) + .getCoefficients(); + this.logger.info("Evaluated linreg coefficients for longley data set: {}", coefficients); + testModel.assertCoefficients(coefficients, 0.0000001); + } +} diff --git a/src/test/java/org/scaleborn/linereg/evaluation/commons/CommonsMathSolverTests.java b/src/test/java/org/scaleborn/linereg/evaluation/commons/CommonsMathSolverTests.java deleted file mode 100644 index 20706a7..0000000 --- a/src/test/java/org/scaleborn/linereg/evaluation/commons/CommonsMathSolverTests.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright (c) 2017 Scaleborn UG, www.scaleborn.com - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.scaleborn.linereg.evaluation.commons; - -import org.elasticsearch.test.ESTestCase; -import org.junit.Test; -import org.scaleborn.linereg.TestModels; -import org.scaleborn.linereg.TestModels.TestModel; -import org.scaleborn.linereg.evaluation.DerivationEquation; - -/** - * Created by mbok on 18.03.17. - */ -public class CommonsMathSolverTests extends ESTestCase { - - /** - * Tests coefficient evaluation with one feature variable. - */ - @Test - public void testSimpleRegression() { - TestModel model = TestModels.SIMPLE_MODEL_1; - DerivationEquation equation = model.getEquation(); - double[] coefficients = new CommonsMathSolver().solveCoefficients(equation).getCoefficients(); - logger.info("Evaluated linreg coefficients: {}", coefficients); - model.assertCoefficients(coefficients, 0.0000001); - } - - /** - * Tests coefficient evaluation with two feature variables. - */ - @Test - public void testMultipleRegressionWith2Features() { - TestModel testModel = TestModels.MULTI_FEATURES_2_MODEL_1; - double[] coefficients = new CommonsMathSolver().solveCoefficients(testModel.getEquation()) - .getCoefficients(); - logger.info("Evaluated linreg coefficients: {}", coefficients); - testModel.assertCoefficients(coefficients, 0.0000001); - } - - /** - * Tests coefficient evaluation with three feature variables. - */ - @Test - public void testMultipleRegressionWith3Features() { - TestModel testModel = TestModels.MULTI_FEATURES_3_MODEL_1; - DerivationEquation equation = testModel.getEquation(); - double[] coefficients = new CommonsMathSolver().solveCoefficients(equation).getCoefficients(); - logger.info("Evaluated linreg coefficients: {}", coefficients); - testModel.assertCoefficients(coefficients, 0.0000001); - } - - /** - * Tests coefficient evaluation with the reference Longley data set. - */ - @Test - public void testMultipleRegressionWithLongleyDataSet() { - TestModel testModel = TestModels.MULTI_FEATURES_6_LONGLEY; - DerivationEquation equation = testModel.getEquation(); - double[] coefficients = new CommonsMathSolver().solveCoefficients(equation).getCoefficients(); - logger.info("Evaluated linreg coefficients for longley data set: {}", coefficients); - testModel.assertCoefficients(coefficients, 0.0000001); - } -}