diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java index 0d8701899..f0cfc258f 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java @@ -80,8 +80,7 @@ public static enum GDExtension { protected List batchMAPAtomValueStates; - protected int numBatches; - protected List> batchTermStores; + protected LearningBatchGenerator batchGenerator; protected TermState[] validationMAPTermState; protected float[] validationMAPAtomValueState; @@ -132,10 +131,9 @@ public GradientDescent(List rules, Database trainTargetDatabase, Database trainFullMAPAtomValueState = null; trainMAPAtomValueState = null; - numBatches = Options.WLA_GRADIENT_DESCENT_NUM_BATCHES.getInt(); - batchTermStores = new ArrayList>(numBatches); - batchMAPTermStates = new ArrayList<>(numBatches); - batchMAPAtomValueStates = new ArrayList<>(numBatches); + batchGenerator = null; + batchMAPTermStates = new ArrayList(); + batchMAPAtomValueStates = new ArrayList(); validationMAPTermState = null; validationMAPAtomValueState = null; @@ -210,27 +208,12 @@ protected void postInitGroundModel() { } } - // ToDo(Charles): Create non-trivial batches. Currently, each batch is a copy of the full term store. - for (int i = 0; i < numBatches; i++) { - // Create a new term store and atom store for each batch. - AtomStore batchAtomStore = new AtomStore(); - for (GroundAtom atom : trainFullTermStore.getAtomStore()) { - // Make a copy of the atom so that the batch atom store can be modified without affecting the full atom store. - batchAtomStore.addAtom(atom.copy()); - - } - - SimpleTermStore batchTermStore = (SimpleTermStore)trainInferenceApplication.createTermStore(); - for (ReasonerTerm term : trainFullTermStore) { - ReasonerTerm batchTerm = term.copy(); - - batchTermStore.add(batchTerm); - } - batchTermStore.setAtomStore(batchAtomStore); - batchTermStores.add(batchTermStore); + batchGenerator = new RandomNodeBatchGenerator(trainInferenceApplication); + batchGenerator.generateBatches(); + for (SimpleTermStore batchTermStore : batchGenerator.getBatchTermStores()) { batchMAPTermStates.add(batchTermStore.saveState()); - batchMAPAtomValueStates.add(Arrays.copyOf(batchAtomStore.getAtomValues(), batchAtomStore.getAtomValues().length)); + batchMAPAtomValueStates.add(Arrays.copyOf(batchTermStore.getAtomStore().getAtomValues(), batchTermStore.getAtomStore().getAtomValues().length)); } } @@ -322,7 +305,7 @@ protected void doLearn() { log.debug("MAP State Best Validation Evaluation Metric: {}", bestValidationEvaluationMetric); } - for (int i = 0; i < numBatches; i++) { + for (int i = 0; i < batchGenerator.getNumBatches(); i++) { setBatch(i); computeIterationStatistics(); @@ -389,13 +372,15 @@ protected void doLearn() { } protected void setBatch(int batch) { - trainInferenceApplication.setTermStore(batchTermStores.get(batch)); + SimpleTermStore batchTermStore = batchGenerator.getBatchTermStore(batch); + + trainInferenceApplication.setTermStore(batchTermStore); trainMAPTermState = batchMAPTermStates.get(batch); trainMAPAtomValueState = batchMAPAtomValueStates.get(batch); // Set the deep predicate atom store and predict with the deep predicates again to ensure predictions are aligned with the batch. for (DeepPredicate deepPredicate : deepPredicates) { - deepPredicate.getDeepModel().setAtomStore(batchTermStores.get(batch).getAtomStore(), true); + deepPredicate.getDeepModel().setAtomStore(batchTermStore.getAtomStore(), true); deepPredicate.predictDeepModel(true); } } diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/LearningBatchGenerator.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/LearningBatchGenerator.java new file mode 100644 index 000000000..c99899dcf --- /dev/null +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/LearningBatchGenerator.java @@ -0,0 +1,73 @@ +/* + * This file is part of the PSL software. + * Copyright 2011-2015 University of Maryland + * Copyright 2013-2023 The Regents of the University of California + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.linqs.psl.application.learning.weight.gradient; + +import org.linqs.psl.application.inference.InferenceApplication; +import org.linqs.psl.reasoner.term.ReasonerTerm; +import org.linqs.psl.reasoner.term.SimpleTermStore; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * A class for generating batches of data for learning. + * A batch in this case is a set of atoms and terms defining a subgraph of the complete factor graph. + */ +public abstract class LearningBatchGenerator { + protected SimpleTermStore fullTermStore; + protected List> batchTermStores; + + protected InferenceApplication inferenceApplication; + + + public LearningBatchGenerator(InferenceApplication inferenceApplication) { + assert inferenceApplication.getTermStore() instanceof SimpleTermStore; + this.inferenceApplication = inferenceApplication; + this.fullTermStore = (SimpleTermStore)inferenceApplication.getTermStore(); + + batchTermStores = new ArrayList>(); + } + + public void shuffle() { + Collections.shuffle(batchTermStores); + } + + public int getNumBatches() { + return batchTermStores.size(); + } + + public List> getBatchTermStores() { + return batchTermStores; + } + + public SimpleTermStore getBatchTermStore(int index) { + return batchTermStores.get(index); + } + + public abstract void generateBatches(); + + public void clear() { + for (SimpleTermStore termStore : batchTermStores) { + termStore.getAtomStore().close(); + termStore.clear(); + } + + batchTermStores.clear(); + } +} diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/RandomNodeBatchGenerator.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/RandomNodeBatchGenerator.java new file mode 100644 index 000000000..3b680b29d --- /dev/null +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/RandomNodeBatchGenerator.java @@ -0,0 +1,120 @@ +/* + * This file is part of the PSL software. + * Copyright 2011-2015 University of Maryland + * Copyright 2013-2023 The Regents of the University of California + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.linqs.psl.application.learning.weight.gradient; + +import org.linqs.psl.application.inference.InferenceApplication; +import org.linqs.psl.config.Options; +import org.linqs.psl.database.AtomStore; +import org.linqs.psl.model.atom.GroundAtom; +import org.linqs.psl.model.atom.RandomVariableAtom; +import org.linqs.psl.reasoner.term.ReasonerTerm; +import org.linqs.psl.reasoner.term.SimpleTermStore; + +import java.util.*; + +public class RandomNodeBatchGenerator extends LearningBatchGenerator { + private final int numBatches; + private final int batchSize; + private final int bfsDepth; + + public RandomNodeBatchGenerator(InferenceApplication inferenceApplication) { + super(inferenceApplication); + + numBatches = Options.RANDOM_NODE_BATCH_GENERATOR_NUM_BATCHES.getInt(); + batchSize = (int) Math.ceil(((float) fullTermStore.getAtomStore().getNumRVAtoms()) / numBatches); + bfsDepth = Options.RANDOM_NODE_BATCH_GENERATOR_BFS_DEPTH.getInt(); + } + + @Override + public void generateBatches() { + // Clear out any old batches. + clear(); + + // Randomly sample batchSize number of random variable atoms from the full atom store and create a new term store and atom store for each batch. + ArrayList allAtoms = new ArrayList(Arrays.asList(Arrays.copyOf(fullTermStore.getAtomStore().getAtoms(), fullTermStore.getAtomStore().size()))); + Collections.shuffle(allAtoms); + + for (int i = 0; i < numBatches; i++) { + AtomStore batchAtomStore = new AtomStore(); + SimpleTermStore batchTermStore = (SimpleTermStore)inferenceApplication.createTermStore(); + batchTermStore.setAtomStore(batchAtomStore); + batchTermStores.add(batchTermStore); + + HashSet visitedTerms = new HashSet<>(); + HashSet visitedAtoms = new HashSet<>(); + + // The last batch may be smaller than the rest. + while ((batchAtomStore.size() < batchSize) && !allAtoms.isEmpty()) { + GroundAtom originalAtom = allAtoms.remove(0); + + if (originalAtom.isFixed()) { + continue; + } + + // Make a copy of the atom so that the batch atom store can be modified without affecting the full atom store. + RandomVariableAtom newBatchRVAtom = (RandomVariableAtom)originalAtom.copy(); + newBatchRVAtom.clearTerms(); + + if (!visitedAtoms.contains(originalAtom)) { + batchAtomStore.addAtom(newBatchRVAtom); + } + + // Perform a bfs on the factor graph starting from the sampled atoms to obtain batch terms. + ArrayList bfsCurrentDepthQueue = new ArrayList(originalAtom.getTerms()); + for (int depth = 0; depth < bfsDepth; depth++) { + ArrayList bfsNextDepthQueue = new ArrayList(); + + for (ReasonerTerm term : bfsCurrentDepthQueue) { + if (visitedTerms.contains(term)) { + continue; + } + + visitedTerms.add(term); + + int[] originalAtomIndexes = term.getAtomIndexes(); + int[] newAtomIndexes = new int[term.getAtomIndexes().length]; + for (int j = 0 ; j < term.size(); j ++) { + int atomIndex = originalAtomIndexes[j]; + GroundAtom atom = fullTermStore.getAtomStore().getAtom(atomIndex); + + if (visitedAtoms.contains(atom)) { + newAtomIndexes[j] = batchAtomStore.getAtomIndex(atom); + continue; + } + + visitedAtoms.add(atom); + + GroundAtom newBatchAtom = atom.copy(); + newBatchAtom.clearTerms(); + batchAtomStore.addAtom(newBatchAtom); + newAtomIndexes[j] = batchAtomStore.getAtomIndex(atom); + + bfsNextDepthQueue.addAll(atom.getTerms()); + } + + ReasonerTerm newBatchTerm = term.copy(); + newBatchTerm.setAtomIndexes(newAtomIndexes); + batchTermStore.add(newBatchTerm); + } + + bfsCurrentDepthQueue = bfsNextDepthQueue; + } + } + } + } +} diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/OptimalValue.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/OptimalValue.java index 5cd38f5d5..9f4ee70fc 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/OptimalValue.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/OptimalValue.java @@ -18,13 +18,10 @@ package org.linqs.psl.application.learning.weight.gradient.optimalvalue; import org.linqs.psl.application.learning.weight.gradient.GradientDescent; -import org.linqs.psl.application.learning.weight.gradient.minimizer.Minimizer; import org.linqs.psl.database.AtomStore; import org.linqs.psl.database.Database; -import org.linqs.psl.model.atom.GroundAtom; import org.linqs.psl.model.atom.ObservedAtom; import org.linqs.psl.model.atom.RandomVariableAtom; -import org.linqs.psl.model.predicate.DeepPredicate; import org.linqs.psl.model.rule.Rule; import org.linqs.psl.reasoner.term.ReasonerTerm; import org.linqs.psl.reasoner.term.SimpleTermStore; @@ -69,8 +66,8 @@ public OptimalValue(List rules, Database trainTargetDatabase, Database tra latentFullInferenceAtomValueState = null; latentInferenceAtomValueState = null; - batchLatentInferenceTermStates = new ArrayList<>(numBatches); - batchlatentInferenceAtomValueState = new ArrayList<>(numBatches); + batchLatentInferenceTermStates = new ArrayList(); + batchlatentInferenceAtomValueState = new ArrayList(); } @Override @@ -88,8 +85,8 @@ protected void postInitGroundModel() { rvLatentAtomGradient = new float[atomValues.length]; deepLatentAtomGradient = new float[atomValues.length]; - for (int i = 0; i < numBatches; i++) { - SimpleTermStore batchTermStore = batchTermStores.get(i); + for (int i = 0; i < batchGenerator.getNumBatches(); i++) { + SimpleTermStore batchTermStore = batchGenerator.getBatchTermStore(i); batchLatentInferenceTermStates.add(batchTermStore.saveState()); batchlatentInferenceAtomValueState.add(Arrays.copyOf(batchTermStore.getAtomStore().getAtomValues(), batchTermStore.getAtomStore().getAtomValues().length)); } @@ -140,6 +137,12 @@ protected void fixLabeledRandomVariables() { ObservedAtom observedAtom = entry.getValue(); int atomIndex = atomStore.getAtomIndex(randomVariableAtom); + + if (atomIndex == -1) { + // This atom is not in the current batch. + continue; + } + atomStore.getAtoms()[atomIndex] = observedAtom; atomStore.getAtomValues()[atomIndex] = observedAtom.getValue(); latentInferenceAtomValueState[atomIndex] = observedAtom.getValue(); @@ -159,6 +162,12 @@ protected void unfixLabeledRandomVariables() { RandomVariableAtom randomVariableAtom = entry.getKey(); int atomIndex = atomStore.getAtomIndex(randomVariableAtom); + + if (atomIndex == -1) { + // This atom is not in the current batch. + continue; + } + atomStore.getAtoms()[atomIndex] = randomVariableAtom; } } diff --git a/psl-core/src/main/java/org/linqs/psl/config/Options.java b/psl-core/src/main/java/org/linqs/psl/config/Options.java index a7c9aff21..3057803b9 100644 --- a/psl-core/src/main/java/org/linqs/psl/config/Options.java +++ b/psl-core/src/main/java/org/linqs/psl/config/Options.java @@ -363,13 +363,6 @@ public class Options { Option.FLAG_POSITIVE ); - public static final Option WLA_GRADIENT_DESCENT_NUM_BATCHES = new Option( - "gradientdescent.numbatches", - 1, - "The number of batches to use for gradient descent weight learning." - + " The default is 1, which means that all training examples are used in each iteration." - ); - public static final Option WLA_GRADIENT_DESCENT_NUM_STEPS = new Option( "gradientdescent.numsteps", 500, @@ -708,6 +701,18 @@ public class Options { Option.FLAG_POSITIVE ); + public static final Option RANDOM_NODE_BATCH_GENERATOR_NUM_BATCHES = new Option( + "randomnodebatchgenerator.numbatches", + 10, + "The number of batches to sample for random node batch generator." + ); + + public static final Option RANDOM_NODE_BATCH_GENERATOR_BFS_DEPTH = new Option( + "randomnodebatchgenerator.bfsdepth", + 1, + "The depth of the factor graph bfs search for random node batch generator." + ); + public static final Option EVAL_AUC_REPRESENTATIVE = new Option( "aucevaluator.representative", AUCEvaluator.RepresentativeMetric.AUROC.toString(), diff --git a/psl-core/src/main/java/org/linqs/psl/database/AtomStore.java b/psl-core/src/main/java/org/linqs/psl/database/AtomStore.java index 008e20a68..20d1fa193 100644 --- a/psl-core/src/main/java/org/linqs/psl/database/AtomStore.java +++ b/psl-core/src/main/java/org/linqs/psl/database/AtomStore.java @@ -39,6 +39,7 @@ public class AtomStore implements Iterable { public static final int MIN_ALLOCATION = 100; protected int numAtoms; + protected int numRVAtoms; protected float[] atomValues; protected GroundAtom[] atoms; protected int maxRVAIndex; @@ -48,6 +49,7 @@ public AtomStore() { log.debug("Initializing AtomStore."); numAtoms = 0; + numRVAtoms = 0; maxRVAIndex = -1; double overallocationFactor = Options.ATOM_STORE_OVERALLOCATION_FACTOR.getDouble(); @@ -62,6 +64,10 @@ public int size() { return numAtoms; } + public int getNumRVAtoms() { + return numRVAtoms; + } + public int getMaxRVAIndex() { return maxRVAIndex; } @@ -186,6 +192,7 @@ protected synchronized void addAtomInternal(GroundAtom atom) { if (atom instanceof RandomVariableAtom) { maxRVAIndex = numAtoms; + numRVAtoms++; } numAtoms++; @@ -193,6 +200,7 @@ protected synchronized void addAtomInternal(GroundAtom atom) { public void close() { numAtoms = 0; + numRVAtoms = 0; atomValues = null; atoms = null; maxRVAIndex = -1; diff --git a/psl-core/src/main/java/org/linqs/psl/model/atom/GroundAtom.java b/psl-core/src/main/java/org/linqs/psl/model/atom/GroundAtom.java index 72a8a8f10..c0a3ac8e6 100644 --- a/psl-core/src/main/java/org/linqs/psl/model/atom/GroundAtom.java +++ b/psl-core/src/main/java/org/linqs/psl/model/atom/GroundAtom.java @@ -21,8 +21,12 @@ import org.linqs.psl.model.term.Constant; import org.linqs.psl.model.term.VariableTypeMap; import org.linqs.psl.reasoner.function.FunctionTerm; +import org.linqs.psl.reasoner.term.ReasonerTerm; import org.linqs.psl.util.StringUtils; +import java.util.ArrayList; +import java.util.List; + /** * An Atom with only {@link Constant GroundTerms} for arguments. * @@ -34,6 +38,8 @@ public abstract class GroundAtom extends Atom implements Comparable, protected short partition; protected boolean fixed; + protected List terms; + protected GroundAtom(Predicate predicate, Constant[] args, float value, short partition) { super(predicate, args); @@ -47,10 +53,24 @@ protected GroundAtom(Predicate predicate, Constant[] args, float value, short pa this.index = -1; this.partition = partition; this.fixed = true; + + terms = new ArrayList(); } public abstract GroundAtom copy(); + public List getTerms() { + return terms; + } + + public void addTerm(ReasonerTerm term) { + terms.add(term); + } + + public void clearTerms() { + terms.clear(); + } + @Override public Constant[] getArguments() { return (Constant[])arguments; diff --git a/psl-core/src/main/java/org/linqs/psl/model/atom/ObservedAtom.java b/psl-core/src/main/java/org/linqs/psl/model/atom/ObservedAtom.java index 101957567..7455bda66 100644 --- a/psl-core/src/main/java/org/linqs/psl/model/atom/ObservedAtom.java +++ b/psl-core/src/main/java/org/linqs/psl/model/atom/ObservedAtom.java @@ -41,7 +41,10 @@ public ObservedAtom(Predicate predicate, Constant[] args, float value, short par } public GroundAtom copy() { - return new ObservedAtom(predicate, (Constant[])arguments, value, partition); + GroundAtom observedAtomCopy = new ObservedAtom(predicate, (Constant[])arguments, value, partition); + observedAtomCopy.terms.addAll(terms); + + return observedAtomCopy; } /** diff --git a/psl-core/src/main/java/org/linqs/psl/model/atom/RandomVariableAtom.java b/psl-core/src/main/java/org/linqs/psl/model/atom/RandomVariableAtom.java index abc18645a..a6847f544 100644 --- a/psl-core/src/main/java/org/linqs/psl/model/atom/RandomVariableAtom.java +++ b/psl-core/src/main/java/org/linqs/psl/model/atom/RandomVariableAtom.java @@ -26,7 +26,7 @@ */ public class RandomVariableAtom extends GroundAtom { /** - * Instantiation of GrondAtoms should typically be left to the Database so it can maintain a cache. + * Instantiation of GroundAtoms should typically be left to the Database so it can maintain a cache. */ public RandomVariableAtom(StandardPredicate predicate, Constant[] args, float value, short partition) { super(predicate, args, value, partition); @@ -37,7 +37,10 @@ public RandomVariableAtom(StandardPredicate predicate, Constant[] args, float va @Override public GroundAtom copy() { - return new RandomVariableAtom((StandardPredicate)predicate, (Constant[])arguments, value, partition); + GroundAtom randomVariableAtomCopy = new RandomVariableAtom((StandardPredicate)predicate, (Constant[])arguments, value, partition); + randomVariableAtomCopy.terms.addAll(terms); + + return randomVariableAtomCopy; } @Override diff --git a/psl-core/src/main/java/org/linqs/psl/reasoner/admm/term/ADMMTermStore.java b/psl-core/src/main/java/org/linqs/psl/reasoner/admm/term/ADMMTermStore.java index 407a4102d..2defd75b2 100644 --- a/psl-core/src/main/java/org/linqs/psl/reasoner/admm/term/ADMMTermStore.java +++ b/psl-core/src/main/java/org/linqs/psl/reasoner/admm/term/ADMMTermStore.java @@ -59,7 +59,7 @@ public int getNumLocalVariables() { @Override public synchronized int add(ReasonerTerm term) { - init(); + ensureLocalRecordsCapacity(); long termIndex = size(); super.add(term); @@ -132,6 +132,16 @@ private synchronized void init() { } } + private synchronized void ensureLocalRecordsCapacity() { + init(); + + if (localRecords.length <= atomStore.getMaxRVAIndex()) { + List[] newLocalRecords = new List[2 * (atomStore.getMaxRVAIndex() + 1)]; + System.arraycopy(localRecords, 0, newLocalRecords, 0, localRecords.length); + localRecords = newLocalRecords; + } + } + public static final class LocalRecord { public long termIndex; public short variableIndex; diff --git a/psl-core/src/main/java/org/linqs/psl/reasoner/term/ReasonerTerm.java b/psl-core/src/main/java/org/linqs/psl/reasoner/term/ReasonerTerm.java index e7d2ec83e..f2ee2385c 100644 --- a/psl-core/src/main/java/org/linqs/psl/reasoner/term/ReasonerTerm.java +++ b/psl-core/src/main/java/org/linqs/psl/reasoner/term/ReasonerTerm.java @@ -156,6 +156,11 @@ public int[] getAtomIndexes() { return atomIndexes; } + public void setAtomIndexes(int[] atomIndexes) { + assert (atomIndexes.length == size); + this.atomIndexes = atomIndexes; + } + /** * Get the coefficients of the atoms involved in this term. * The coefficients are aligned with the atomIndexes array, i.e., the i'th entry in the coefficient array diff --git a/psl-core/src/main/java/org/linqs/psl/reasoner/term/SimpleTermStore.java b/psl-core/src/main/java/org/linqs/psl/reasoner/term/SimpleTermStore.java index 942f3e4ff..e2dda776a 100644 --- a/psl-core/src/main/java/org/linqs/psl/reasoner/term/SimpleTermStore.java +++ b/psl-core/src/main/java/org/linqs/psl/reasoner/term/SimpleTermStore.java @@ -43,6 +43,10 @@ public synchronized int add(ReasonerTerm term) { T newTerm = (T)term; terms.add(newTerm); + for (int atomIndex : newTerm.getAtomIndexes()) { + atomStore.getAtom(atomIndex).addTerm(newTerm); + } + return 1; }