Skip to content

Commit

Permalink
Distributed dual bcd
Browse files Browse the repository at this point in the history
  • Loading branch information
dickensc committed Sep 7, 2023
1 parent 5954f6f commit 37fa46d
Show file tree
Hide file tree
Showing 7 changed files with 363 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* This file is part of the PSL software.
* Copyright 2011-2015 University of Maryland
* Copyright 2013-2023 The Regents of the University of California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.linqs.psl.application.inference.mpe;

import org.linqs.psl.database.Database;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.duallcqp.DistributedDualBCDReasoner;
import org.linqs.psl.reasoner.duallcqp.term.DualLCQPTermStore;
import org.linqs.psl.reasoner.term.TermStore;

import java.util.List;

/**
 * Use a distributed DualBCD reasoner to perform MPE inference.
 */
public class DistributedDualBCDInference extends MPEInference {
    /**
     * @param rules the rules handed to the parent MPE inference application.
     * @param db the database handed to the parent MPE inference application.
     */
    public DistributedDualBCDInference(List<Rule> rules, Database db) {
        super(rules, db);
    }

    /**
     * @return a new DistributedDualBCDReasoner.
     */
    @Override
    protected Reasoner createReasoner() {
        Reasoner reasoner = new DistributedDualBCDReasoner();
        return reasoner;
    }

    /**
     * @return a new DualLCQPTermStore constructed from this application's database atom store.
     */
    @Override
    public TermStore createTermStore() {
        TermStore store = new DualLCQPTermStore(database.getAtomStore());
        return store;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* This file is part of the PSL software.
* Copyright 2011-2015 University of Maryland
* Copyright 2013-2023 The Regents of the University of California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.linqs.psl.reasoner.duallcqp;

import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.evaluation.EvaluationInstance;
import org.linqs.psl.reasoner.duallcqp.term.DualLCQPObjectiveTerm;
import org.linqs.psl.reasoner.duallcqp.term.DualLCQPTermStore;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.Parallel;

import java.util.List;

/**
* A distributed variant of the DualBCDReasoner.
* Note that unlike the DualBCDReasoner, this reasoner does not guarantee an increase in the dual
* objective at every iteration as the solution to the stepsize subproblem may be inexact.
* Practically, this means that this reasoner may not converge to the optimal solution in as
* few iterations as the DualBCDReasoner.
* However, this reasoner does have a lower per iteration runtime.
*/
public class DistributedDualBCDReasoner extends DualBCDReasoner {
    private static final Logger log = Logger.getLogger(DistributedDualBCDReasoner.class);

    // Both values are -1 until initForOptimization() assigns them.
    // Number of consecutive terms handled by one parallel work item.
    private int blockSize;
    // Number of work items handed to Parallel.count() per iteration.
    private int numTermBlocks;

    public DistributedDualBCDReasoner() {
        super();

        numTermBlocks = -1;
        blockSize = -1;
    }

    /**
     * Repeatedly run a parallel sweep of block updates over all term blocks until
     * breakOptimization() signals a stop.
     * The stopping check, logging, and evaluation only happen on iterations where
     * (iteration - 1) is a multiple of computePeriod.
     *
     * @return the accumulated wall-clock milliseconds spent in the parallel block updates.
     */
    @Override
    protected long internalOptimize(DualLCQPTermStore termStore, List<EvaluationInstance> evaluations, TrainingMap trainingMap) {
        ObjectiveResult currentPrimal = null;
        ObjectiveResult previousPrimal = null;

        long totalTime = 0;
        boolean done = false;

        for (int iteration = 1; !done; iteration++) {
            // Only the parallel block updates are timed.
            long sweepStart = System.currentTimeMillis();
            Parallel.count(numTermBlocks, new BlockUpdateWorker(termStore, blockSize));
            long sweepTime = System.currentTimeMillis() - sweepStart;
            totalTime += sweepTime;

            // Skip the bookkeeping on iterations that are off the compute period.
            if ((iteration - 1) % computePeriod != 0) {
                continue;
            }

            float variableMovement = primalVariableUpdate(termStore);

            previousPrimal = currentPrimal;
            currentPrimal = parallelComputeObjective(termStore);
            ObjectiveResult currentDual = parallelComputeDualObjective(termStore);

            done = breakOptimization(iteration, currentPrimal, previousPrimal, currentDual,
                    maxIterations, runFullIterations, objectiveBreak, objectiveTolerance,
                    variableMovementBreak, variableMovementTolerance, variableMovement,
                    primalDualBreak, primalDualTolerance);

            log.trace("Iteration {} -- Primal Objective: {}, Violated Constraints: {}, Dual Objective: {}, Primal-dual gap: {}, Iteration Time: {}, Total Optimization Time: {}.",
                    iteration, currentPrimal.objective, currentPrimal.violatedConstraints,
                    currentDual.objective, currentPrimal.objective - currentDual.objective,
                    sweepTime, totalTime);

            evaluate(termStore, iteration, evaluations, trainingMap);
        }

        return totalTime;
    }

    /**
     * Run the parent initialization and then size the term blocks from the
     * term store size and the number of available threads.
     */
    @Override
    protected void initForOptimization(TermStore<DualLCQPObjectiveTerm> termStore) {
        super.initForOptimization(termStore);

        int blocksPerSweepTarget = Parallel.getNumThreads() * 4;
        blockSize = (int) (termStore.size() / blocksPerSweepTarget + 1);
        numTermBlocks = (int) Math.ceil(termStore.size() / (double)blockSize);
    }

    /**
     * Worker that applies dualBlockUpdate() to every active term in one block of the term store.
     */
    private static class BlockUpdateWorker extends Parallel.Worker<Long> {
        private final DualLCQPTermStore termStore;
        private final int blockSize;

        public BlockUpdateWorker(DualLCQPTermStore termStore, int blockSize) {
            super();

            this.blockSize = blockSize;
            this.termStore = termStore;
        }

        @Override
        public Object clone() {
            return new BlockUpdateWorker(termStore, blockSize);
        }

        @Override
        public void work(long blockIndex, Long ignore) {
            long numTerms = termStore.size();
            long blockStart = blockIndex * blockSize;

            for (int offset = 0; offset < blockSize; offset++) {
                int termIndex = (int) (blockStart + offset);

                // The last block may extend past the end of the term store.
                if (termIndex >= numTerms) {
                    break;
                }

                DualLCQPObjectiveTerm term = termStore.get(termIndex);

                if (term.isActive()) {
                    dualBlockUpdate(term, termStore);
                }
            }
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ public class DualBCDReasoner extends Reasoner<DualLCQPObjectiveTerm> {

public static final double regularizationParameter = Options.DUAL_LCQP_REGULARIZATION.getDouble();

private final boolean primalDualBreak;
private final double primalDualTolerance;
protected final boolean primalDualBreak;
protected final double primalDualTolerance;

protected final int computePeriod;

Expand All @@ -73,6 +73,16 @@ public double optimize(TermStore<DualLCQPObjectiveTerm> baseTermStore, List<Eval
termStore.initForOptimization();
initForOptimization(termStore);

long totalTime = internalOptimize(termStore, evaluations, trainingMap);

optimizationComplete(termStore, parallelComputeObjective(termStore), totalTime);

// Return the un-regularized quantification of the objective for consistency with
// weight learning objectives and test assertions.
return super.parallelComputeObjective(termStore).objective;
}

protected long internalOptimize(DualLCQPTermStore termStore, List<EvaluationInstance> evaluations, TrainingMap trainingMap) {
log.trace("Starting optimization. Number of connected components: {}.", termStore.getConnectedComponents().size());

long start = System.currentTimeMillis();
Expand All @@ -84,11 +94,33 @@ public double optimize(TermStore<DualLCQPObjectiveTerm> baseTermStore, List<Eval

evaluate(termStore, 1, evaluations, trainingMap);

optimizationComplete(termStore, parallelComputeObjective(termStore), totalTime);
return totalTime;
}

// Return the un-regularized quantification of the objective for consistency with
// weight learning objectives and test assertions.
return super.parallelComputeObjective(termStore).objective;
/**
 * Map the current setting of the dual variables to primal variables.
 * Every non-fixed atom value is overwritten with the primal value derived from its dual LCQP atom.
 *
 * @return the largest absolute change observed in any updated atom value.
 */
protected float primalVariableUpdate(DualLCQPTermStore termStore) {
    AtomStore atomStore = termStore.getAtomStore();
    float[] atomValues = atomStore.getAtomValues();
    GroundAtom[] atoms = atomStore.getAtoms();

    float largestMovement = 0.0f;
    int index = 0;
    while (index < atomStore.size()) {
        if (!atoms[index].isFixed()) {
            float previousValue = atomValues[index];
            atomValues[index] = termStore.getDualLCQPAtom(index).getPrimal(regularizationParameter);

            // Track the largest absolute change in any variable.
            float movement = Math.abs(atomValues[index] - previousValue);
            if (movement > largestMovement) {
                largestMovement = movement;
            }
        }

        index++;
    }

    return largestMovement;
}

/**
Expand Down Expand Up @@ -120,7 +152,7 @@ protected static float primalVariableComponentUpdate(DualLCQPTermStore termStore
return variableMovement;
}

private static boolean breakOptimization(int iteration,
protected static boolean breakOptimization(int iteration,
ObjectiveResult primalObjectiveResult, ObjectiveResult oldPrimalObjectiveResult,
ObjectiveResult dualObjectiveResult,
int maxIterations, boolean runFullIterations,
Expand Down Expand Up @@ -459,13 +491,7 @@ protected static ObjectiveResult computeObjective(TermStore<? extends ReasonerTe
/**
 * Compute the objective using the parent implementation and then add the
 * primal variable regularization to it.
 *
 * @param termStore the term store passed to both the parent objective computation and the regularization computation.
 * @return the parent's ObjectiveResult with its objective field increased by the regularization.
 */
@Override
protected ObjectiveResult parallelComputeObjective(TermStore<DualLCQPObjectiveTerm> termStore) {
ObjectiveResult objectiveResult = super.parallelComputeObjective(termStore);

log.trace("Unregularized Objective: {}", objectiveResult.objective);

// The result object returned by the parent is modified in place.
objectiveResult.objective += computePrimalVariableRegularization(termStore);

// NOTE(review): the regularization is computed a second time here just to log it,
// and at info level while the line above logs at trace -- confirm this is intended.
log.info("Regularization: {}", computePrimalVariableRegularization(termStore));

return objectiveResult;
}

Expand Down Expand Up @@ -519,7 +545,36 @@ private static double computePrimalVariableComponentRegularization(TermStore<Dua
return atomValueRegularization;
}

private static ObjectiveResult computeComponentDualObjective(DualLCQPTermStore termStore, int componentIndex) {
/**
 * Compute the dual objective by summing per-block partial objectives computed in parallel
 * and then adding the contribution of the atom bound constraints.
 *
 * @param termStore the term store whose terms and dual LCQP atoms are evaluated.
 * @return an ObjectiveResult holding -0.5 times the accumulated value, with a violated constraint count of 0.
 */
protected ObjectiveResult parallelComputeDualObjective(DualLCQPTermStore termStore) {
    int blockSize = (int)(termStore.size() / (Parallel.getNumThreads() * 4) + 1);
    // Divide in double precision: a float cannot represent every term count above 2^24,
    // so a float quotient could round down to a whole number and leave the trailing terms without a block.
    int numTermBlocks = (int)Math.ceil(termStore.size() / (double)blockSize);

    // One slot per block so that workers never write to the same element.
    double[] workerObjectives = new double[numTermBlocks];

    Parallel.count(numTermBlocks, new DualTermObjectiveWorker(termStore, workerObjectives, blockSize));

    double objectiveValue = 0.0;
    for (double workerObjective : workerObjectives) {
        objectiveValue += workerObjective;
    }

    // The upper and lower bound constraints on the atoms are not stored in the term store,
    // so we need to compute their objective contribution here.
    GroundAtom[] atoms = termStore.getAtomStore().getAtoms();
    DualLCQPAtom[] dualLCQPAtoms = termStore.getDualLCQPAtoms();
    for (int i = 0; i < dualLCQPAtoms.length; i++) {
        // Skip empty slots and fixed atoms.
        if (atoms[i] == null || atoms[i].isFixed()) {
            continue;
        }

        objectiveValue += dualLCQPAtoms[i].getLowerBoundObjective(regularizationParameter);
        objectiveValue += dualLCQPAtoms[i].getUpperBoundObjective(regularizationParameter);
    }

    return new ObjectiveResult((float)(-0.5 * objectiveValue), 0);
}

protected static ObjectiveResult computeComponentDualObjective(DualLCQPTermStore termStore, int componentIndex) {
Map<Integer, List<DualLCQPObjectiveTerm>> connectedComponents = ((SimpleTermStore<DualLCQPObjectiveTerm>)termStore).getConnectedComponents();
List<DualLCQPObjectiveTerm> component = connectedComponents.get(componentIndex);

Expand Down Expand Up @@ -617,4 +672,47 @@ public void work(long blockIndex, Long ignore) {
}
}
}

/**
 * Worker that sums the dual objective contribution of every active term in one block
 * and stores the sum in the slot of the shared array indexed by the block.
 */
private static class DualTermObjectiveWorker extends Parallel.Worker<Long> {
    private final DualLCQPTermStore termStore;
    private final int blockSize;
    private final double[] objectives;

    public DualTermObjectiveWorker(DualLCQPTermStore termStore, double[] objectives, int blockSize) {
        super();

        this.blockSize = blockSize;
        this.objectives = objectives;
        this.termStore = termStore;
    }

    @Override
    public Object clone() {
        return new DualTermObjectiveWorker(termStore, objectives, blockSize);
    }

    @Override
    public void work(long blockIndex, Long ignore) {
        long numTerms = termStore.size();
        int slot = (int)blockIndex;
        int blockStart = slot * blockSize;

        // Accumulate locally and publish once at the end.
        double blockObjective = 0.0;
        for (int offset = 0; offset < blockSize; offset++) {
            int termIndex = blockStart + offset;

            // The last block may extend past the end of the term store.
            if (termIndex >= numTerms) {
                break;
            }

            DualLCQPObjectiveTerm term = termStore.get(termIndex);

            if (term.isActive()) {
                blockObjective += evaluateDualTerm(term, termStore);
                blockObjective += evaluateDualSlackLowerBound(term);
            }
        }

        objectives[slot] = blockObjective;
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* This file is part of the PSL software.
* Copyright 2011-2015 University of Maryland
* Copyright 2013-2023 The Regents of the University of California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.linqs.psl.application.inference.mpe;

import org.linqs.psl.application.inference.InferenceApplication;
import org.linqs.psl.application.inference.InferenceTest;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.rule.Rule;

import java.util.List;

/**
 * Run the shared InferenceTest suite against DistributedDualBCDInference.
 */
public class DistributedDualBCDInferenceTest extends InferenceTest {
    /**
     * @return a new DistributedDualBCDInference over the given rules and database.
     */
    @Override
    protected InferenceApplication getInference(List<Rule> rules, Database db) {
        InferenceApplication inference = new DistributedDualBCDInference(rules, db);
        return inference;
    }
}
Loading

0 comments on commit 37fa46d

Please sign in to comment.