Commit

Use Adagrad optimiser for Linear regression by default (opensearch-project#3291)

* Use AdaGrad optimiser by default in Linear Regression

Signed-off-by: rithin-pullela-aws <rithinp@amazon.com>

* Added issue link in the code comment as a reference.

Signed-off-by: rithin-pullela-aws <rithinp@amazon.com>

* Apply Spotless

Signed-off-by: rithin-pullela-aws <rithinp@amazon.com>

---------

Signed-off-by: rithin-pullela-aws <rithinp@amazon.com>
rithin-pullela-aws authored Dec 30, 2024
1 parent ff85b12 commit f323141
Showing 1 changed file with 7 additions and 6 deletions.
@@ -52,7 +52,7 @@
 public class LinearRegression implements Trainable, Predictable {
     public static final String VERSION = "1.0.0";
     private static final LinearRegressionParams.ObjectiveType DEFAULT_OBJECTIVE_TYPE = LinearRegressionParams.ObjectiveType.SQUARED_LOSS;
-    private static final LinearRegressionParams.OptimizerType DEFAULT_OPTIMIZER_TYPE = LinearRegressionParams.OptimizerType.SIMPLE_SGD;
+    private static final LinearRegressionParams.OptimizerType DEFAULT_OPTIMIZER_TYPE = LinearRegressionParams.OptimizerType.ADA_GRAD;
     private static final double DEFAULT_LEARNING_RATE = 0.01;
     // Momentum
     private static final double DEFAULT_MOMENTUM_FACTOR = 0;
@@ -134,15 +134,15 @@ private void createOptimiser() {
                 break;
         }
         switch (optimizerType) {
+            case SIMPLE_SGD:
+                optimiser = SGD.getSimpleSGD(learningRate, momentumFactor, momentum);
+                break;
             case LINEAR_DECAY_SGD:
                 optimiser = SGD.getLinearDecaySGD(learningRate, momentumFactor, momentum);
                 break;
             case SQRT_DECAY_SGD:
                 optimiser = SGD.getSqrtDecaySGD(learningRate, momentumFactor, momentum);
                 break;
-            case ADA_GRAD:
-                optimiser = new AdaGrad(learningRate, epsilon);
-                break;
             case ADA_DELTA:
                 optimiser = new AdaDelta(momentumFactor, epsilon);
                 break;
@@ -153,8 +153,9 @@ private void createOptimiser() {
                 optimiser = new RMSProp(learningRate, momentumFactor, epsilon, decayRate);
                 break;
             default:
-                // Use default SGD with a constant learning rate.
-                optimiser = SGD.getSimpleSGD(learningRate, momentumFactor, momentum);
+                // Use AdaGrad by default, reference issue:
+                // https://github.com/opensearch-project/ml-commons/issues/3210#issuecomment-2556119802
+                optimiser = new AdaGrad(learningRate, epsilon);
                 break;
         }
     }
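For reference, the net effect of this change on the fallback path can be sketched with the Tribuo optimiser classes used in the file above. This is a minimal standalone sketch, not part of the commit: the learning rate (0.01) and momentum factor (0) come from the constants in the first hunk, while the epsilon value and the STANDARD momentum choice are assumptions for illustration.

// Standalone sketch (not part of this commit): contrasts the old and new
// fallback optimisers that createOptimiser() builds when no optimizer type is set.
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.optimisers.AdaGrad;
import org.tribuo.math.optimisers.SGD;

public class DefaultOptimiserSketch {
    public static void main(String[] args) {
        double learningRate = 0.01;  // DEFAULT_LEARNING_RATE from the diff above
        double momentumFactor = 0;   // DEFAULT_MOMENTUM_FACTOR from the diff above
        double epsilon = 1e-6;       // assumed value; DEFAULT_EPSILON is not shown in this diff

        // Old fallback: plain SGD with a constant learning rate, the same step size for every parameter.
        StochasticGradientOptimiser oldDefault =
            SGD.getSimpleSGD(learningRate, momentumFactor, SGD.Momentum.STANDARD);

        // New fallback: AdaGrad, which scales each parameter's step size by its
        // accumulated squared gradients, so frequently updated parameters take smaller steps.
        StochasticGradientOptimiser newDefault = new AdaGrad(learningRate, epsilon);

        System.out.println("Old default optimiser: " + oldDefault);
        System.out.println("New default optimiser: " + newDefault);
    }
}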
