Skip to content

Commit

Permalink
Validation break.
Browse files Browse the repository at this point in the history
  • Loading branch information
dickensc committed Sep 1, 2023
1 parent ff2b615 commit c70ff92
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ public static enum GDExtension {
protected List<float[]> batchMAPAtomValueStates;

protected int validationEvaluationComputePeriod;
protected boolean validationBreak;
protected int validationPatience;
protected int lastValidationImprovementEpoch;
protected TermState[] validationMAPTermState;
protected float[] validationMAPAtomValueState;
protected boolean saveBestValidationWeights;
Expand Down Expand Up @@ -152,6 +155,9 @@ public GradientDescent(List<Rule> rules, Database trainTargetDatabase, Database
bestValidationWeights = null;
currentValidationEvaluationMetric = Double.NEGATIVE_INFINITY;
bestValidationEvaluationMetric = Double.NEGATIVE_INFINITY;
validationBreak = Options.WLA_GRADIENT_DESCENT_VALIDATION_BREAK.getBoolean();
validationPatience = Options.WLA_GRADIENT_DESCENT_VALIDATION_PATIENCE.getInt();
lastValidationImprovementEpoch = 0;

if (saveBestValidationWeights && (!this.runValidation)) {
throw new IllegalArgumentException("If saveBestValidationWeights is true, then runValidation must also be true.");
Expand Down Expand Up @@ -307,7 +313,7 @@ protected void doLearn() {
}

if (runValidation && (epoch % validationEvaluationComputePeriod == 0)) {
runValidationEvaluation();
runValidationEvaluation(epoch);
log.debug("Current MAP State Validation Evaluation Metric: {}", currentValidationEvaluationMetric);
}

Expand Down Expand Up @@ -389,7 +395,7 @@ protected void doLearn() {
if (saveBestValidationWeights) {
finalMAPStateValidationEvaluation = bestValidationEvaluationMetric;
} else {
runValidationEvaluation();
runValidationEvaluation(epoch);
finalMAPStateValidationEvaluation = currentValidationEvaluationMetric;
}
log.info("Final MAP State Validation Evaluation Metric: {}", finalMAPStateValidationEvaluation);
Expand Down Expand Up @@ -510,7 +516,7 @@ protected void runMAPEvaluation() {
}
}

protected void runValidationEvaluation() {
protected void runValidationEvaluation(int epoch) {
setValidationModel();

log.trace("Running Validation Inference.");
Expand All @@ -520,6 +526,8 @@ protected void runValidationEvaluation() {
currentValidationEvaluationMetric = evaluation.getNormalizedRepMetric();

if (currentValidationEvaluationMetric > bestValidationEvaluationMetric) {
lastValidationImprovementEpoch = epoch;

bestValidationEvaluationMetric = currentValidationEvaluationMetric;

// Save the best rule weights.
Expand All @@ -539,14 +547,19 @@ protected void runValidationEvaluation() {

protected boolean breakOptimization(int epoch) {
if (epoch >= maxNumSteps) {
log.trace("Breaking Weight Learning. Reached maximum number of iterations: {}", maxNumSteps);
log.trace("Breaking Weight Learning. Reached maximum number of epochs: {}", maxNumSteps);
return true;
}

if (runFullIterations) {
return false;
}

if (validationBreak && (epoch - lastValidationImprovementEpoch) > validationPatience) {
log.trace("Breaking Weight Learning. No improvement in validation evaluation metric for {} epochs.", (epoch - lastValidationImprovementEpoch));
return true;
}

if (movementBreak && MathUtils.equals(parameterMovement, 0.0f, movementTolerance)) {
log.trace("Breaking Weight Learning. Parameter Movement: {} is within tolerance: {}", parameterMovement, movementTolerance);
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,16 +310,21 @@ protected void setBatch(int batch) {
}

@Override
protected boolean breakOptimization(int iteration) {
if (iteration >= maxNumSteps) {
log.trace("Breaking Weight Learning. Reached maximum number of iterations: {}", maxNumSteps);
protected boolean breakOptimization(int epoch) {
if (epoch >= maxNumSteps) {
log.trace("Breaking Weight Learning. Reached maximum number of epochs: {}", maxNumSteps);
return true;
}

if (runFullIterations) {
return false;
}

if (validationBreak && (epoch - lastValidationImprovementEpoch) > validationPatience) {
log.trace("Breaking Weight Learning. No improvement in validation evaluation metric for {} epochs.", (epoch - lastValidationImprovementEpoch));
return true;
}

float totalObjectiveDifference = computeTotalObjectiveDifference();
if (totalObjectiveDifference < finalConstraintTolerance) {
log.trace("Breaking Weight Learning. Objective difference {} is less than final constraint tolerance {}.",
Expand Down
45 changes: 13 additions & 32 deletions psl-core/src/main/java/org/linqs/psl/config/Options.java
Original file line number Diff line number Diff line change
Expand Up @@ -350,22 +350,6 @@ public class Options {
Option.FLAG_POSITIVE
);

public static final Option WLA_GRADIENT_DESCENT_NORM_BREAK = new Option(
"gradientdescent.normbreak",
false,
"When the gradient norm is below the tolerance "
+ " set by gradientdescent.normtolerance, gradient descent weight learning is stopped."
);

public static final Option WLA_GRADIENT_DESCENT_NORM_TOLERANCE = new Option(
"gradientdescent.normtolerance",
1.0e-3f,
"If gradientdescent.runfulliterations=false and gradientdescent.normbreak=true,"
+ " then when the norm of the gradient is below this tolerance "
+ " gradient descent weight learning is stopped.",
Option.FLAG_POSITIVE
);

public static final Option WLA_GRADIENT_DESCENT_NUM_STEPS = new Option(
"gradientdescent.numsteps",
500,
Expand All @@ -387,22 +371,6 @@ public class Options {
Option.FLAG_POSITIVE
);

public static final Option WLA_GRADIENT_DESCENT_OBJECTIVE_BREAK = new Option(
"gradientdescent.objectivebreak",
false,
"When the objective change between iterates is below the tolerance "
+ " set by gradientdescent.objectivetolerance, gradient descent weight learning is stopped."
);

public static final Option WLA_GRADIENT_DESCENT_OBJECTIVE_TOLERANCE = new Option(
"gradientdescent.objectivetolerance",
1.0e-5f,
"If gradientdescent.runfulliterations=false and gradientdescent.objectivebreak=true,"
+ " then when the objective change between iterates is below this tolerance"
+ " gradient descent weight learning is stopped.",
Option.FLAG_POSITIVE
);

public static final Option WLA_GRADIENT_DESCENT_RUN_FULL_ITERATIONS = new Option(
"gradientdescent.runfulliterations",
false,
Expand All @@ -417,6 +385,7 @@ public class Options {
+ " If true, then gradientdescent.runvalidation must be true."
);


public static final Option WLA_GRADIENT_DESCENT_SCALE_STEP = new Option(
"gradientdescent.scalestepsize",
true,
Expand Down Expand Up @@ -444,12 +413,24 @@ public class Options {
"Compute training evaluation every this many iterations of gradient descent weight learning."
);

// Flag enabling early stopping ("validation break") of gradient descent weight learning.
// Read into GradientDescent.validationBreak; when true, breakOptimization() compares the
// epochs since the last validation-metric improvement against gradientdescent.validationpatience.
public static final Option WLA_GRADIENT_DESCENT_VALIDATION_BREAK = new Option(
"gradientdescent.validationbreak",
false,
"Break gradient descent weight learning when the validation evaluation stops improving."
);

// How often (in epochs) validation inference and evaluation are run during learning.
// Read into GradientDescent.validationEvaluationComputePeriod and checked via
// (epoch % validationEvaluationComputePeriod == 0) in doLearn().
// NOTE(review): the description says "iterations" while the surrounding code/log messages
// were renamed to "epochs" in this change — consider aligning the wording.
public static final Option WLA_GRADIENT_DESCENT_VALIDATION_COMPUTE_PERIOD = new Option(
"gradientdescent.validationcomputeperiod",
1,
"Compute validation evaluation every this many iterations of gradient descent weight learning."
);

// Patience (in epochs) for the validation break. Only consulted when
// gradientdescent.validationbreak=true: learning stops once
// (epoch - lastValidationImprovementEpoch) > validationPatience, i.e. once the
// validation metric has not improved for more than this many epochs.
// Note: if validation runs less often than every epoch (validationcomputeperiod > 1),
// the effective patience is still measured in epochs, not in evaluations.
public static final Option WLA_GRADIENT_DESCENT_VALIDATION_PATIENCE = new Option(
"gradientdescent.validationpatience",
25,
"Break gradient descent weight learning when the validation evaluation stops improving after this many epochs."
);

public static final Option WLA_GS_POSSIBLE_WEIGHTS = new Option(
"gridsearch.weights",
"0.001:0.01:0.1:1:10",
Expand Down

0 comments on commit c70ff92

Please sign in to comment.