diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java index 80b2c3dec..6df651a49 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java @@ -24,6 +24,7 @@ import org.linqs.psl.database.Database; import org.linqs.psl.model.predicate.DeepPredicate; import org.linqs.psl.model.predicate.Predicate; +import org.linqs.psl.model.rule.GroundRule; import org.linqs.psl.model.rule.Rule; import org.linqs.psl.model.rule.WeightedRule; import org.linqs.psl.reasoner.InitialValue; @@ -229,7 +230,10 @@ protected void doLearn() { while (!breakGD) { long start = System.currentTimeMillis(); - log.trace("Model: {}", mutableRules); + log.trace("Model:"); + for (WeightedRule weightedRule: mutableRules) { + log.trace("{}", weightedRule); + } gradientStep(iteration); @@ -464,6 +468,8 @@ protected float atomGradientStep() { deepPredicate.fitDeepPredicate(deepAtomGradient); deepPredicateChange += deepPredicate.predictDeepModel(true); } + + log.trace("Deep Predicate Change: {}", deepPredicateChange); return deepPredicateChange; } diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java index 68b91737a..5558cbb30 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java @@ -244,12 +244,12 @@ protected void gradientStep(int iteration) { parameterMovement += internalParameterGradientStep(iteration); parameterMovement += atomGradientStep(); + // Update the penalty coefficients and tolerance. 
+ float totalObjectiveDifference = computeObjectiveDifference(); + if ((iteration > 0) && (parameterMovement < parameterMovementTolerance)) { outerIteration++; - // Update the penalty coefficients and tolerance. - float totalObjectiveDifference = computeObjectiveDifference(); - if (totalObjectiveDifference < constraintTolerance) { if ((totalObjectiveDifference < finalConstraintTolerance) && (parameterMovement < finalParameterMovementTolerance)) { // Learning has converged. @@ -263,11 +263,10 @@ protected void gradientStep(int iteration) { constraintTolerance = (float)(1.0f / Math.pow(squaredPenaltyCoefficient, 0.1)); parameterMovementTolerance = (float)(1.0f / squaredPenaltyCoefficient); } - - log.trace("Outer iteration: {}, Objective Difference: {}, Parameter Movement: {}, Squared Penalty Coefficient: {}, Linear Penalty Coefficient: {}, Constraint Tolerance: {}, parameterMovementTolerance: {}.", - outerIteration, totalObjectiveDifference, parameterMovement, squaredPenaltyCoefficient, linearPenaltyCoefficient, constraintTolerance, parameterMovementTolerance); - } + + log.trace("Outer iteration: {}, Objective Difference: {}, Parameter Movement: {}, Squared Penalty Coefficient: {}, Linear Penalty Coefficient: {}, Constraint Tolerance: {}, parameterMovementTolerance: {}.", + outerIteration, totalObjectiveDifference, parameterMovement, squaredPenaltyCoefficient, linearPenaltyCoefficient, constraintTolerance, parameterMovementTolerance); } @Override @@ -286,6 +285,8 @@ protected float internalParameterGradientStep(int iteration) { augmentedInferenceAtomValueState[proxRuleObservedAtomIndexes[i]] = newProxRuleObservedAtomsValue; } + log.trace("Proximity Rule Observed Atoms Value Movement: {}.", proxRuleObservedAtomsValueMovement); + return proxRuleObservedAtomsValueMovement; } diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/OptimalValue.java 
b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/OptimalValue.java index 81bbdda2c..eb66d9410 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/OptimalValue.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/OptimalValue.java @@ -20,10 +20,12 @@ import org.linqs.psl.application.learning.weight.gradient.GradientDescent; import org.linqs.psl.database.AtomStore; import org.linqs.psl.database.Database; +import org.linqs.psl.model.atom.GroundAtom; import org.linqs.psl.model.atom.ObservedAtom; import org.linqs.psl.model.atom.RandomVariableAtom; import org.linqs.psl.model.rule.Rule; import org.linqs.psl.reasoner.term.TermState; +import org.linqs.psl.util.Logger; import java.util.Arrays; import java.util.List; @@ -39,6 +41,8 @@ * computing the incompatibility of the latent variable inference problem solution are provided in this class. */ public abstract class OptimalValue extends GradientDescent { + private static final Logger log = Logger.getLogger(OptimalValue.class); + protected float[] latentInferenceIncompatibility; protected TermState[] latentInferenceTermState; protected float[] latentInferenceAtomValueState; @@ -81,6 +85,10 @@ protected void computeLatentInferenceIncompatibility() { computeCurrentIncompatibility(latentInferenceIncompatibility); trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), rvLatentAtomGradient, deepLatentAtomGradient); + for (int i = 0; i < mutableRules.size(); i++) { + log.trace("Rule: {} , Latent inference incompatibility: {}", mutableRules.get(i), latentInferenceIncompatibility[i]); + } + unfixLabeledRandomVariables(); } diff --git a/psl-core/src/main/java/org/linqs/psl/reasoner/sgd/term/SGDObjectiveTerm.java b/psl-core/src/main/java/org/linqs/psl/reasoner/sgd/term/SGDObjectiveTerm.java index 8f2cbd265..3543cbba8 100644 ---
a/psl-core/src/main/java/org/linqs/psl/reasoner/sgd/term/SGDObjectiveTerm.java +++ b/psl-core/src/main/java/org/linqs/psl/reasoner/sgd/term/SGDObjectiveTerm.java @@ -87,54 +87,4 @@ public void read(ByteBuffer fixedBuffer) { atomIndexes[i] = fixedBuffer.getInt(); } } - - @Override - public String toString() { - return toString(null); - } - - public String toString(AtomStore atomStore) { - // weight * [max(coeffs^T * x - constant, 0.0)]^2 - - StringBuilder builder = new StringBuilder(); - - builder.append(getWeight()); - builder.append(" * "); - - if (hinge) { - builder.append("max(0.0, "); - } else { - builder.append("("); - } - - for (int i = 0; i < size; i++) { - builder.append("("); - builder.append(coefficients[i]); - - if (atomStore == null) { - builder.append(" * )"); - } else { - builder.append(" * "); - builder.append(atomStore.getAtomValue(atomIndexes[i])); - builder.append(")"); - } - - if (i != size - 1) { - builder.append(" + "); - } - } - - builder.append(" - "); - builder.append(constant); - - builder.append(")"); - - if (squared) { - builder.append(" ^2"); - } - - return builder.toString(); - } } diff --git a/psl-core/src/main/java/org/linqs/psl/reasoner/term/ReasonerTerm.java b/psl-core/src/main/java/org/linqs/psl/reasoner/term/ReasonerTerm.java index ca0ae59ac..effcc7d01 100644 --- a/psl-core/src/main/java/org/linqs/psl/reasoner/term/ReasonerTerm.java +++ b/psl-core/src/main/java/org/linqs/psl/reasoner/term/ReasonerTerm.java @@ -17,6 +17,7 @@ */ package org.linqs.psl.reasoner.term; +import org.linqs.psl.database.AtomStore; import org.linqs.psl.model.atom.GroundAtom; import org.linqs.psl.model.rule.Rule; import org.linqs.psl.model.rule.WeightedRule; @@ -329,4 +330,54 @@ public TermState saveState() { public void saveState(TermState termState) { // Pass. 
} + + @Override + public String toString() { + return toString(null); + } + + public String toString(AtomStore atomStore) { + // weight * [max(coeffs^T * x - constant, 0.0)]^2 + + StringBuilder builder = new StringBuilder(); + + builder.append(getWeight()); + builder.append(" * "); + + if (hinge) { + builder.append("max(0.0, "); + } else { + builder.append("("); + } + + for (int i = 0; i < size; i++) { + builder.append("("); + builder.append(coefficients[i]); + + if (atomStore == null) { + builder.append(" * )"); + } else { + builder.append(" * "); + builder.append(atomStore.getAtomValue(atomIndexes[i])); + builder.append(")"); + } + + if (i != size - 1) { + builder.append(" + "); + } + } + + builder.append(" - "); + builder.append(constant); + + builder.append(")"); + + if (squared) { + builder.append(" ^2"); + } + + return builder.toString(); + } }