Commit 6d12e64
Improved logging.
dickensc committed Jul 12, 2023
1 parent: 2107d70
Showing 5 changed files with 74 additions and 58 deletions.
@@ -24,6 +24,7 @@
 import org.linqs.psl.database.Database;
 import org.linqs.psl.model.predicate.DeepPredicate;
 import org.linqs.psl.model.predicate.Predicate;
+import org.linqs.psl.model.rule.GroundRule;
 import org.linqs.psl.model.rule.Rule;
 import org.linqs.psl.model.rule.WeightedRule;
 import org.linqs.psl.reasoner.InitialValue;
@@ -229,7 +230,10 @@ protected void doLearn() {
         while (!breakGD) {
             long start = System.currentTimeMillis();
 
-            log.trace("Model: {}", mutableRules);
+            log.trace("Model:");
+            for (WeightedRule weightedRule : mutableRules) {
+                log.trace("{}", weightedRule);
+            }
 
             gradientStep(iteration);
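
Note: the new loop logs the model one weighted rule per line instead of relying on the collection's single-line toString. A sketch of the resulting trace output, with illustrative rules (not from this commit):

    Model:
    1.0: SIMILAR(A, B) & SAME(B, C) >> SAME(A, C) ^2
    10.0: PRIOR(A, B) >> !SAME(A, B) ^2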

@@ -464,6 +468,8 @@ protected float atomGradientStep() {
             deepPredicate.fitDeepPredicate(deepAtomGradient);
             deepPredicateChange += deepPredicate.predictDeepModel(true);
         }
+
+        log.trace("Deep Predicate Change: {}", deepPredicateChange);
         return deepPredicateChange;
     }

@@ -244,12 +244,12 @@ protected void gradientStep(int iteration) {
         parameterMovement += internalParameterGradientStep(iteration);
         parameterMovement += atomGradientStep();
 
+        // Update the penalty coefficients and tolerance.
+        float totalObjectiveDifference = computeObjectiveDifference();
+
         if ((iteration > 0) && (parameterMovement < parameterMovementTolerance)) {
             outerIteration++;
 
-            // Update the penalty coefficients and tolerance.
-            float totalObjectiveDifference = computeObjectiveDifference();
-
             if (totalObjectiveDifference < constraintTolerance) {
                 if ((totalObjectiveDifference < finalConstraintTolerance) && (parameterMovement < finalParameterMovementTolerance)) {
                     // Learning has converged.
@@ -263,11 +263,10 @@ protected void gradientStep(int iteration) {
                 constraintTolerance = (float)(1.0f / Math.pow(squaredPenaltyCoefficient, 0.1));
                 parameterMovementTolerance = (float)(1.0f / squaredPenaltyCoefficient);
             }
-
-            log.trace("Outer iteration: {}, Objective Difference: {}, Parameter Movement: {}, Squared Penalty Coefficient: {}, Linear Penalty Coefficient: {}, Constraint Tolerance: {}, parameterMovementTolerance: {}.",
-                    outerIteration, totalObjectiveDifference, parameterMovement, squaredPenaltyCoefficient, linearPenaltyCoefficient, constraintTolerance, parameterMovementTolerance);
-
         }
+
+        log.trace("Outer iteration: {}, Objective Difference: {}, Parameter Movement: {}, Squared Penalty Coefficient: {}, Linear Penalty Coefficient: {}, Constraint Tolerance: {}, parameterMovementTolerance: {}.",
+                outerIteration, totalObjectiveDifference, parameterMovement, squaredPenaltyCoefficient, linearPenaltyCoefficient, constraintTolerance, parameterMovementTolerance);
     }
 
     @Override
@@ -286,6 +285,8 @@ protected float internalParameterGradientStep(int iteration) {
             augmentedInferenceAtomValueState[proxRuleObservedAtomIndexes[i]] = newProxRuleObservedAtomsValue;
         }
 
+        log.trace("Proximity Rule Observed Atoms Value Movement: {}.", proxRuleObservedAtomsValueMovement);
+
         return proxRuleObservedAtomsValueMovement;
     }
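
The tolerance updates in gradientStep above follow the standard augmented Lagrangian schedule: constraintTolerance = 1 / squaredPenaltyCoefficient^0.1 and parameterMovementTolerance = 1 / squaredPenaltyCoefficient. As a worked example with an assumed penalty coefficient of 100: constraintTolerance = 1 / 100^0.1 ≈ 0.63 and parameterMovementTolerance = 1 / 100 = 0.01, so both tolerances tighten as the penalty grows. Moving computeObjectiveDifference and the log.trace call out of the if block means the objective difference is computed and reported on every gradient step, not only when an outer iteration completes.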

@@ -20,10 +20,12 @@
 import org.linqs.psl.application.learning.weight.gradient.GradientDescent;
 import org.linqs.psl.database.AtomStore;
 import org.linqs.psl.database.Database;
+import org.linqs.psl.model.atom.GroundAtom;
 import org.linqs.psl.model.atom.ObservedAtom;
 import org.linqs.psl.model.atom.RandomVariableAtom;
 import org.linqs.psl.model.rule.Rule;
 import org.linqs.psl.reasoner.term.TermState;
+import org.linqs.psl.util.Logger;
 
 import java.util.Arrays;
 import java.util.List;
@@ -39,6 +41,8 @@
  * computing the incompatibility of the latent variable inference problem solution are provided in this class.
  */
 public abstract class OptimalValue extends GradientDescent {
+    private static final Logger log = Logger.getLogger(OptimalValue.class);
+
     protected float[] latentInferenceIncompatibility;
     protected TermState[] latentInferenceTermState;
     protected float[] latentInferenceAtomValueState;
Expand Down Expand Up @@ -81,6 +85,10 @@ protected void computeLatentInferenceIncompatibility() {
computeCurrentIncompatibility(latentInferenceIncompatibility);
trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), rvLatentAtomGradient, deepLatentAtomGradient);

for (int i = 0; i < mutableRules.size(); i++) {
log.trace("Rule: {} , Latent inference incompatibility: {}", mutableRules.get(i), latentInferenceIncompatibility[i]);
}

unfixLabeledRandomVariables();
}

@@ -87,54 +87,4 @@ public void read(ByteBuffer fixedBuffer) {
             atomIndexes[i] = fixedBuffer.getInt();
         }
     }
-
-    @Override
-    public String toString() {
-        return toString(null);
-    }
-
-    public String toString(AtomStore atomStore) {
-        // weight * [max(coeffs^T * x - constant, 0.0)]^2
-
-        StringBuilder builder = new StringBuilder();
-
-        builder.append(getWeight());
-        builder.append(" * ");
-
-        if (hinge) {
-            builder.append("max(0.0, ");
-        } else {
-            builder.append("(");
-        }
-
-        for (int i = 0; i < size; i++) {
-            builder.append("(");
-            builder.append(coefficients[i]);
-
-            if (atomStore == null) {
-                builder.append(" * <index:");
-                builder.append(atomIndexes[i]);
-                builder.append(">)");
-            } else {
-                builder.append(" * ");
-                builder.append(atomStore.getAtomValue(atomIndexes[i]));
-                builder.append(")");
-            }
-
-            if (i != size - 1) {
-                builder.append(" + ");
-            }
-        }
-
-        builder.append(" - ");
-        builder.append(constant);
-
-        builder.append(")");
-
-        if (squared) {
-            builder.append(" ^2");
-        }
-
-        return builder.toString();
-    }
 }
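
Note: the toString implementation removed here is not lost; it is added verbatim to the term class in the final file of this diff, presumably so the string rendering can be shared rather than duplicated.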
@@ -17,6 +17,7 @@
  */
 package org.linqs.psl.reasoner.term;
 
+import org.linqs.psl.database.AtomStore;
 import org.linqs.psl.model.atom.GroundAtom;
 import org.linqs.psl.model.rule.Rule;
 import org.linqs.psl.model.rule.WeightedRule;
@@ -329,4 +330,54 @@ public TermState saveState() {
     public void saveState(TermState termState) {
         // Pass.
     }
+
+    @Override
+    public String toString() {
+        return toString(null);
+    }
+
+    public String toString(AtomStore atomStore) {
+        // weight * [max(coeffs^T * x - constant, 0.0)]^2
+
+        StringBuilder builder = new StringBuilder();
+
+        builder.append(getWeight());
+        builder.append(" * ");
+
+        if (hinge) {
+            builder.append("max(0.0, ");
+        } else {
+            builder.append("(");
+        }
+
+        for (int i = 0; i < size; i++) {
+            builder.append("(");
+            builder.append(coefficients[i]);
+
+            if (atomStore == null) {
+                builder.append(" * <index:");
+                builder.append(atomIndexes[i]);
+                builder.append(">)");
+            } else {
+                builder.append(" * ");
+                builder.append(atomStore.getAtomValue(atomIndexes[i]));
+                builder.append(")");
+            }
+
+            if (i != size - 1) {
+                builder.append(" + ");
+            }
+        }
+
+        builder.append(" - ");
+        builder.append(constant);
+
+        builder.append(")");
+
+        if (squared) {
+            builder.append(" ^2");
+        }
+
+        return builder.toString();
+    }
 }
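
To make the format concrete, here is the string this toString would build for a hypothetical squared hinge term (all values illustrative): weight 2.0, coefficients {1.0, -1.0}, atomIndexes {3, 7}, constant 0.5, with no AtomStore supplied:

    2.0 * max(0.0, (1.0 * <index:3>) + (-1.0 * <index:7>) - 0.5) ^2

When an AtomStore is supplied, each <index:i> slot is instead rendered as the atom's current value from atomStore.getAtomValue(atomIndexes[i]), not the atom's name.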
