diff --git a/psl-core/src/main/java/org/linqs/psl/application/inference/mpe/DistributedDualBCDInference.java b/psl-core/src/main/java/org/linqs/psl/application/inference/mpe/DistributedDualBCDInference.java new file mode 100644 index 000000000..900530c84 --- /dev/null +++ b/psl-core/src/main/java/org/linqs/psl/application/inference/mpe/DistributedDualBCDInference.java @@ -0,0 +1,46 @@ +/* + * This file is part of the PSL software. + * Copyright 2011-2015 University of Maryland + * Copyright 2013-2023 The Regents of the University of California + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.linqs.psl.application.inference.mpe; + +import org.linqs.psl.database.Database; +import org.linqs.psl.model.rule.Rule; +import org.linqs.psl.reasoner.Reasoner; +import org.linqs.psl.reasoner.duallcqp.DistributedDualBCDReasoner; +import org.linqs.psl.reasoner.duallcqp.term.DualLCQPTermStore; +import org.linqs.psl.reasoner.term.TermStore; + +import java.util.List; + +/** + * Use an DualBCD reasoner to perform MPE inference. + */ +public class DistributedDualBCDInference extends MPEInference { + public DistributedDualBCDInference(List rules, Database db) { + super(rules, db); + } + + @Override + protected Reasoner createReasoner() { + return new DistributedDualBCDReasoner(); + } + + @Override + public TermStore createTermStore() { + return new DualLCQPTermStore(database.getAtomStore()); + } +} diff --git a/psl-core/src/main/java/org/linqs/psl/reasoner/duallcqp/DistributedDualBCDReasoner.java b/psl-core/src/main/java/org/linqs/psl/reasoner/duallcqp/DistributedDualBCDReasoner.java new file mode 100644 index 000000000..29a302bda --- /dev/null +++ b/psl-core/src/main/java/org/linqs/psl/reasoner/duallcqp/DistributedDualBCDReasoner.java @@ -0,0 +1,136 @@ +/* + * This file is part of the PSL software. + * Copyright 2011-2015 University of Maryland + * Copyright 2013-2023 The Regents of the University of California + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.linqs.psl.reasoner.duallcqp; + +import org.linqs.psl.application.learning.weight.TrainingMap; +import org.linqs.psl.evaluation.EvaluationInstance; +import org.linqs.psl.reasoner.duallcqp.term.DualLCQPObjectiveTerm; +import org.linqs.psl.reasoner.duallcqp.term.DualLCQPTermStore; +import org.linqs.psl.reasoner.term.TermStore; +import org.linqs.psl.util.Logger; +import org.linqs.psl.util.Parallel; + +import java.util.List; + +/** + * A distributed variant of the DualBCDReasoner. + * Note that unlike the DualBCDReasoner, this reasoner does not guarantee an increase in the dual + * objective at every iteration as the solution to the stepsize subproblem may be inexact. + * Practically, this means that this reasoner may not converge to the optimal solution in as + * few iterations as the DualBCDReasoner. + * However, this reasoner does have a lower per iteration runtime. + */ +public class DistributedDualBCDReasoner extends DualBCDReasoner { + private static final org.linqs.psl.util.Logger log = Logger.getLogger(DistributedDualBCDReasoner.class); + + private int blockSize; + private int numTermBlocks; + + public DistributedDualBCDReasoner() { + super(); + + blockSize = -1; + numTermBlocks = -1; + } + + @Override + protected long internalOptimize(DualLCQPTermStore termStore, List evaluations, TrainingMap trainingMap) { + ObjectiveResult primalObjectiveResult = null; + ObjectiveResult oldPrimalObjectiveResult = null; + + long totalTime = 0; + boolean breakDualBCD = false; + int iteration = 1; + while(!breakDualBCD) { + long start = System.currentTimeMillis(); + Parallel.count(numTermBlocks, new BlockUpdateWorker(termStore, blockSize)); + long end = System.currentTimeMillis(); + totalTime += end - start; + + if ((iteration - 1) % computePeriod == 0) { + float variableMovement = primalVariableUpdate(termStore); + + oldPrimalObjectiveResult = primalObjectiveResult; + primalObjectiveResult = parallelComputeObjective(termStore); + ObjectiveResult dualObjectiveResult = parallelComputeDualObjective(termStore); + + breakDualBCD = breakOptimization(iteration, primalObjectiveResult, oldPrimalObjectiveResult, dualObjectiveResult, + maxIterations, runFullIterations, objectiveBreak, objectiveTolerance, + variableMovementBreak, variableMovementTolerance, variableMovement, + primalDualBreak, primalDualTolerance); + + log.trace("Iteration {} -- Primal Objective: {}, Violated Constraints: {}, Dual Objective: {}, Primal-dual gap: {}, Iteration Time: {}, Total Optimization Time: {}.", + iteration, primalObjectiveResult.objective, primalObjectiveResult.violatedConstraints, + dualObjectiveResult.objective, primalObjectiveResult.objective - dualObjectiveResult.objective, + (end - start), totalTime); + + evaluate(termStore, iteration, evaluations, trainingMap); + } + + iteration++; + } + + return totalTime; + } + + @Override + protected void initForOptimization(TermStore termStore) { + super.initForOptimization(termStore); + + blockSize = (int) (termStore.size() / (Parallel.getNumThreads() * 4) + 1); + numTermBlocks = (int) Math.ceil(termStore.size() / (double)blockSize); + } + + private static class BlockUpdateWorker extends Parallel.Worker { + private final DualLCQPTermStore termStore; + private final int blockSize; + + public BlockUpdateWorker(DualLCQPTermStore termStore, int blockSize) { + super(); + + this.termStore = termStore; + this.blockSize = blockSize; + } + + @Override + public Object clone() { + return new BlockUpdateWorker(termStore, blockSize); + } + + @Override + public void work(long blockIndex, Long ignore) { + long numTerms = termStore.size(); + + for (int innerBlockIndex = 0; innerBlockIndex < blockSize; innerBlockIndex++) { + int termIndex = (int) (blockIndex * blockSize + innerBlockIndex); + + if (termIndex >= numTerms) { + break; + } + + DualLCQPObjectiveTerm term = termStore.get(termIndex); + + if (!term.isActive()) { + continue; + } + + dualBlockUpdate(term, termStore); + } + } + } +} diff --git a/psl-core/src/main/java/org/linqs/psl/reasoner/duallcqp/DualBCDReasoner.java b/psl-core/src/main/java/org/linqs/psl/reasoner/duallcqp/DualBCDReasoner.java index 6ce5cebbb..461930517 100644 --- a/psl-core/src/main/java/org/linqs/psl/reasoner/duallcqp/DualBCDReasoner.java +++ b/psl-core/src/main/java/org/linqs/psl/reasoner/duallcqp/DualBCDReasoner.java @@ -48,8 +48,8 @@ public class DualBCDReasoner extends Reasoner { public static final double regularizationParameter = Options.DUAL_LCQP_REGULARIZATION.getDouble(); - private final boolean primalDualBreak; - private final double primalDualTolerance; + protected final boolean primalDualBreak; + protected final double primalDualTolerance; protected final int computePeriod; @@ -73,6 +73,16 @@ public double optimize(TermStore baseTermStore, List evaluations, TrainingMap trainingMap) { log.trace("Starting optimization. Number of connected components: {}.", termStore.getConnectedComponents().size()); long start = System.currentTimeMillis(); @@ -84,11 +94,33 @@ public double optimize(TermStore baseTermStore, List variableMovement) { + variableMovement = Math.abs(atomValues[i] - oldValue); + } + } + + return variableMovement; } /** @@ -120,7 +152,7 @@ protected static float primalVariableComponentUpdate(DualLCQPTermStore termStore return variableMovement; } - private static boolean breakOptimization(int iteration, + protected static boolean breakOptimization(int iteration, ObjectiveResult primalObjectiveResult, ObjectiveResult oldPrimalObjectiveResult, ObjectiveResult dualObjectiveResult, int maxIterations, boolean runFullIterations, @@ -459,13 +491,7 @@ protected static ObjectiveResult computeObjective(TermStore termStore) { ObjectiveResult objectiveResult = super.parallelComputeObjective(termStore); - - log.trace("Unregularized Objective: {}", objectiveResult.objective); - objectiveResult.objective += computePrimalVariableRegularization(termStore); - - log.info("Regularization: {}", computePrimalVariableRegularization(termStore)); - return objectiveResult; } @@ -519,7 +545,36 @@ private static double computePrimalVariableComponentRegularization(TermStore> connectedComponents = ((SimpleTermStore)termStore).getConnectedComponents(); List component = connectedComponents.get(componentIndex); @@ -617,4 +672,47 @@ public void work(long blockIndex, Long ignore) { } } } + + private static class DualTermObjectiveWorker extends Parallel.Worker { + private final DualLCQPTermStore termStore; + private final int blockSize; + private final double[] objectives; + + public DualTermObjectiveWorker(DualLCQPTermStore termStore, double[] objectives, int blockSize) { + super(); + + this.termStore = termStore; + this.objectives = objectives; + this.blockSize = blockSize; + } + + @Override + public Object clone() { + return new DualTermObjectiveWorker(termStore, objectives, blockSize); + } + + @Override + public void work(long blockIndex, Long ignore) { + long numTerms = termStore.size(); + int blockIntIndex = (int)blockIndex; + + objectives[blockIntIndex] = 0.0; + for (int innerBlockIndex = 0; innerBlockIndex < blockSize; innerBlockIndex++) { + int termIndex = blockIntIndex * blockSize + innerBlockIndex; + + if (termIndex >= numTerms) { + break; + } + + DualLCQPObjectiveTerm term = termStore.get(termIndex); + + if (!term.isActive()) { + continue; + } + + objectives[blockIntIndex] += evaluateDualTerm(term, termStore); + objectives[blockIntIndex] += evaluateDualSlackLowerBound(term); + } + } + } } diff --git a/psl-core/src/test/java/org/linqs/psl/application/inference/mpe/DistributedDualBCDInferenceTest.java b/psl-core/src/test/java/org/linqs/psl/application/inference/mpe/DistributedDualBCDInferenceTest.java new file mode 100644 index 000000000..2b8fd07f5 --- /dev/null +++ b/psl-core/src/test/java/org/linqs/psl/application/inference/mpe/DistributedDualBCDInferenceTest.java @@ -0,0 +1,32 @@ +/* + * This file is part of the PSL software. + * Copyright 2011-2015 University of Maryland + * Copyright 2013-2023 The Regents of the University of California + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.linqs.psl.application.inference.mpe; + +import org.linqs.psl.application.inference.InferenceApplication; +import org.linqs.psl.application.inference.InferenceTest; +import org.linqs.psl.database.Database; +import org.linqs.psl.model.rule.Rule; + +import java.util.List; + +public class DistributedDualBCDInferenceTest extends InferenceTest { + @Override + protected InferenceApplication getInference(List rules, Database db) { + return new DistributedDualBCDInference(rules, db); + } +} diff --git a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/minimizer/BinaryCrossEntropyTest.java b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/minimizer/BinaryCrossEntropyTest.java index f40c6e9ec..5a92f5e04 100644 --- a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/minimizer/BinaryCrossEntropyTest.java +++ b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/minimizer/BinaryCrossEntropyTest.java @@ -17,6 +17,7 @@ */ package org.linqs.psl.application.learning.weight.gradient.minimizer; +import org.junit.Test; import org.linqs.psl.application.inference.mpe.DualBCDInference; import org.linqs.psl.application.learning.weight.WeightLearningApplication; import org.linqs.psl.application.learning.weight.WeightLearningTest; @@ -37,4 +38,18 @@ protected WeightLearningApplication getBaseWLA() { return new BinaryCrossEntropy(info.model.getRules(), trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, false); } + + @Test + public void DualBCDFriendshipRankTest() { + Options.WLA_INFERENCE.set(DualBCDInference.class.getName()); + + super.friendshipRankTest(); + } + + @Test + public void DistributedDualBCDFriendshipRankTest() { + Options.WLA_INFERENCE.set(DualBCDInference.class.getName()); + + super.friendshipRankTest(); + } } diff --git a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/minimizer/SquaredErrorTest.java b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/minimizer/SquaredErrorTest.java index 9d3995011..ffaa43898 100644 --- a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/minimizer/SquaredErrorTest.java +++ b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/minimizer/SquaredErrorTest.java @@ -17,6 +17,7 @@ */ package org.linqs.psl.application.learning.weight.gradient.minimizer; +import org.junit.Test; import org.linqs.psl.application.inference.mpe.DualBCDInference; import org.linqs.psl.application.learning.weight.WeightLearningApplication; import org.linqs.psl.application.learning.weight.WeightLearningTest; @@ -37,4 +38,18 @@ protected WeightLearningApplication getBaseWLA() { return new SquaredError(info.model.getRules(), trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, false); } + + @Test + public void DualBCDFriendshipRankTest() { + Options.WLA_INFERENCE.set(DualBCDInference.class.getName()); + + super.friendshipRankTest(); + } + + @Test + public void DistributedDualBCDFriendshipRankTest() { + Options.WLA_INFERENCE.set(DualBCDInference.class.getName()); + + super.friendshipRankTest(); + } } diff --git a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/EnergyTest.java b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/EnergyTest.java index 623948c98..29ae60798 100644 --- a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/EnergyTest.java +++ b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/EnergyTest.java @@ -65,4 +65,11 @@ public void DualBCDFriendshipRankTest() { super.friendshipRankTest(); } + + @Test + public void DistributedDualBCDFriendshipRankTest() { + Options.WLA_INFERENCE.set(DualBCDInference.class.getName()); + + super.friendshipRankTest(); + } }