Skip to content

Commit

Permalink
LearningBatchGenerators
Browse files Browse the repository at this point in the history
  • Loading branch information
dickensc committed Jul 6, 2023
1 parent e4892ec commit 573d78f
Show file tree
Hide file tree
Showing 12 changed files with 291 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ public static enum GDExtension {
protected List<float[]> batchMAPAtomValueStates;


protected int numBatches;
protected List<SimpleTermStore<? extends ReasonerTerm>> batchTermStores;
protected LearningBatchGenerator batchGenerator;

protected TermState[] validationMAPTermState;
protected float[] validationMAPAtomValueState;
Expand Down Expand Up @@ -132,10 +131,9 @@ public GradientDescent(List<Rule> rules, Database trainTargetDatabase, Database
trainFullMAPAtomValueState = null;
trainMAPAtomValueState = null;

numBatches = Options.WLA_GRADIENT_DESCENT_NUM_BATCHES.getInt();
batchTermStores = new ArrayList<SimpleTermStore<? extends ReasonerTerm>>(numBatches);
batchMAPTermStates = new ArrayList<>(numBatches);
batchMAPAtomValueStates = new ArrayList<>(numBatches);
batchGenerator = null;
batchMAPTermStates = new ArrayList<TermState[]>();
batchMAPAtomValueStates = new ArrayList<float[]>();

validationMAPTermState = null;
validationMAPAtomValueState = null;
Expand Down Expand Up @@ -210,27 +208,12 @@ protected void postInitGroundModel() {
}
}

// ToDo(Charles): Create non-trivial batches. Currently, each batch is a copy of the full term store.
for (int i = 0; i < numBatches; i++) {
// Create a new term store and atom store for each batch.
AtomStore batchAtomStore = new AtomStore();
for (GroundAtom atom : trainFullTermStore.getAtomStore()) {
// Make a copy of the atom so that the batch atom store can be modified without affecting the full atom store.
batchAtomStore.addAtom(atom.copy());

}

SimpleTermStore<? extends ReasonerTerm> batchTermStore = (SimpleTermStore<? extends ReasonerTerm>)trainInferenceApplication.createTermStore();
for (ReasonerTerm term : trainFullTermStore) {
ReasonerTerm batchTerm = term.copy();

batchTermStore.add(batchTerm);
}
batchTermStore.setAtomStore(batchAtomStore);
batchTermStores.add(batchTermStore);
batchGenerator = new RandomNodeBatchGenerator(trainInferenceApplication);
batchGenerator.generateBatches();

for (SimpleTermStore<? extends ReasonerTerm> batchTermStore : batchGenerator.getBatchTermStores()) {
batchMAPTermStates.add(batchTermStore.saveState());
batchMAPAtomValueStates.add(Arrays.copyOf(batchAtomStore.getAtomValues(), batchAtomStore.getAtomValues().length));
batchMAPAtomValueStates.add(Arrays.copyOf(batchTermStore.getAtomStore().getAtomValues(), batchTermStore.getAtomStore().getAtomValues().length));
}
}

Expand Down Expand Up @@ -322,7 +305,7 @@ protected void doLearn() {
log.debug("MAP State Best Validation Evaluation Metric: {}", bestValidationEvaluationMetric);
}

for (int i = 0; i < numBatches; i++) {
for (int i = 0; i < batchGenerator.getNumBatches(); i++) {
setBatch(i);

computeIterationStatistics();
Expand Down Expand Up @@ -389,13 +372,15 @@ protected void doLearn() {
}

protected void setBatch(int batch) {
trainInferenceApplication.setTermStore(batchTermStores.get(batch));
SimpleTermStore<? extends ReasonerTerm> batchTermStore = batchGenerator.getBatchTermStore(batch);

trainInferenceApplication.setTermStore(batchTermStore);
trainMAPTermState = batchMAPTermStates.get(batch);
trainMAPAtomValueState = batchMAPAtomValueStates.get(batch);

// Set the deep predicate atom store and predict with the deep predicates again to ensure predictions are aligned with the batch.
for (DeepPredicate deepPredicate : deepPredicates) {
deepPredicate.getDeepModel().setAtomStore(batchTermStores.get(batch).getAtomStore(), true);
deepPredicate.getDeepModel().setAtomStore(batchTermStore.getAtomStore(), true);
deepPredicate.predictDeepModel(true);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* This file is part of the PSL software.
* Copyright 2011-2015 University of Maryland
* Copyright 2013-2023 The Regents of the University of California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.linqs.psl.application.learning.weight.gradient;

import org.linqs.psl.application.inference.InferenceApplication;
import org.linqs.psl.reasoner.term.ReasonerTerm;
import org.linqs.psl.reasoner.term.SimpleTermStore;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * A class for generating batches of data for learning.
 * A batch in this case is a set of atoms and terms defining a subgraph of the complete factor graph.
 */
public abstract class LearningBatchGenerator {
    // The term store over the full factor graph; every batch is a subgraph of this store.
    protected SimpleTermStore<? extends ReasonerTerm> fullTermStore;
    // The per-batch term stores. Populated by generateBatches(), emptied by clear().
    protected List<SimpleTermStore<? extends ReasonerTerm>> batchTermStores;

    protected InferenceApplication inferenceApplication;

    /**
     * Construct a batch generator over the given inference application's term store.
     *
     * @param inferenceApplication the application whose current term store defines the full factor graph.
     *         Its term store must be a SimpleTermStore.
     * @throws IllegalArgumentException if the application's term store is not a SimpleTermStore.
     */
    public LearningBatchGenerator(InferenceApplication inferenceApplication) {
        // Explicit check rather than an assert: asserts are disabled at runtime by default,
        // and the cast below would otherwise fail later with a less informative ClassCastException.
        if (!(inferenceApplication.getTermStore() instanceof SimpleTermStore)) {
            throw new IllegalArgumentException(
                    "LearningBatchGenerator requires the inference application's term store to be a SimpleTermStore.");
        }

        this.inferenceApplication = inferenceApplication;
        this.fullTermStore = (SimpleTermStore<? extends ReasonerTerm>)inferenceApplication.getTermStore();

        batchTermStores = new ArrayList<SimpleTermStore<? extends ReasonerTerm>>();
    }

    /**
     * Randomly permute the order of the generated batches in place.
     */
    public void shuffle() {
        Collections.shuffle(batchTermStores);
    }

    /**
     * @return the number of batches currently held (zero before generateBatches() is called).
     */
    public int getNumBatches() {
        return batchTermStores.size();
    }

    /**
     * @return the live list of batch term stores (not a copy; callers must not structurally modify it).
     */
    public List<SimpleTermStore<? extends ReasonerTerm>> getBatchTermStores() {
        return batchTermStores;
    }

    /**
     * @param index the batch index in [0, getNumBatches()).
     * @return the term store for the requested batch.
     */
    public SimpleTermStore<? extends ReasonerTerm> getBatchTermStore(int index) {
        return batchTermStores.get(index);
    }

    /**
     * Build the batches from the full term store.
     * Implementations should call clear() first to release any previously generated batches.
     */
    public abstract void generateBatches();

    /**
     * Release all generated batches: close each batch's atom store, clear each batch term store,
     * and empty the batch list.
     */
    public void clear() {
        for (SimpleTermStore<? extends ReasonerTerm> termStore : batchTermStores) {
            termStore.getAtomStore().close();
            termStore.clear();
        }

        batchTermStores.clear();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* This file is part of the PSL software.
* Copyright 2011-2015 University of Maryland
* Copyright 2013-2023 The Regents of the University of California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.linqs.psl.application.learning.weight.gradient;

import org.linqs.psl.application.inference.InferenceApplication;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.AtomStore;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.reasoner.term.ReasonerTerm;
import org.linqs.psl.reasoner.term.SimpleTermStore;

import java.util.*;

/**
 * A LearningBatchGenerator that builds each batch by sampling random (unvisited) random-variable
 * atoms from the full atom store and then expanding around each sampled atom with a bounded
 * breadth-first search over the factor graph to pull in neighboring terms and atoms.
 */
public class RandomNodeBatchGenerator extends LearningBatchGenerator {
    // Number of batches to create per generateBatches() call (from config).
    private final int numBatches;
    // Target number of atoms per batch: ceil(numRVAtoms / numBatches). The last batch may be smaller.
    private final int batchSize;
    // Depth of the BFS over the factor graph used to collect each batch's terms (from config).
    private final int bfsDepth;

    public RandomNodeBatchGenerator(InferenceApplication inferenceApplication) {
        super(inferenceApplication);

        numBatches = Options.RANDOM_NODE_BATCH_GENERATOR_NUM_BATCHES.getInt();
        batchSize = (int) Math.ceil(((float) fullTermStore.getAtomStore().getNumRVAtoms()) / numBatches);
        bfsDepth = Options.RANDOM_NODE_BATCH_GENERATOR_BFS_DEPTH.getInt();
    }

    @Override
    public void generateBatches() {
        // Clear out any old batches.
        clear();

        // Randomly sample batchSize number of random variable atoms from the full atom store and create a new term store and atom store for each batch.
        // A shuffled copy of all atoms; atoms are consumed (removed from the front) as batches are filled,
        // so each atom is sampled as a seed at most once across all batches.
        ArrayList<GroundAtom> allAtoms = new ArrayList<GroundAtom>(Arrays.asList(Arrays.copyOf(fullTermStore.getAtomStore().getAtoms(), fullTermStore.getAtomStore().size())));
        Collections.shuffle(allAtoms);

        for (int i = 0; i < numBatches; i++) {
            AtomStore batchAtomStore = new AtomStore();
            SimpleTermStore<? extends ReasonerTerm> batchTermStore = (SimpleTermStore<? extends ReasonerTerm>)inferenceApplication.createTermStore();
            batchTermStore.setAtomStore(batchAtomStore);
            batchTermStores.add(batchTermStore);

            // Visited sets are per-batch: terms/atoms reached by the BFS in this batch, keyed by the
            // ORIGINAL (full-store) objects, not the batch copies.
            HashSet<ReasonerTerm> visitedTerms = new HashSet<>();
            HashSet<GroundAtom> visitedAtoms = new HashSet<>();

            // The last batch may be smaller than the rest.
            while ((batchAtomStore.size() < batchSize) && !allAtoms.isEmpty()) {
                GroundAtom originalAtom = allAtoms.remove(0);

                // Fixed (observed) atoms are not used as BFS seeds; they are still consumed from allAtoms.
                if (originalAtom.isFixed()) {
                    continue;
                }

                // Make a copy of the atom so that the batch atom store can be modified without affecting the full atom store.
                RandomVariableAtom newBatchRVAtom = (RandomVariableAtom)originalAtom.copy();
                newBatchRVAtom.clearTerms();

                // NOTE(review): the sampled atom is added here but never inserted into visitedAtoms,
                // so a later BFS step reaching it through a term will treat it as unvisited and add a
                // second copy — confirm whether AtomStore.addAtom deduplicates, or whether
                // visitedAtoms.add(originalAtom) is missing here.
                if (!visitedAtoms.contains(originalAtom)) {
                    batchAtomStore.addAtom(newBatchRVAtom);
                }

                // Perform a bfs on the factor graph starting from the sampled atoms to obtain batch terms.
                ArrayList<ReasonerTerm> bfsCurrentDepthQueue = new ArrayList<ReasonerTerm>(originalAtom.getTerms());
                for (int depth = 0; depth < bfsDepth; depth++) {
                    ArrayList<ReasonerTerm> bfsNextDepthQueue = new ArrayList<ReasonerTerm>();

                    for (ReasonerTerm term : bfsCurrentDepthQueue) {
                        // Each original term is copied into at most one batch term per batch.
                        if (visitedTerms.contains(term)) {
                            continue;
                        }

                        visitedTerms.add(term);

                        // Remap the term's atom indexes from the full atom store to the batch atom store,
                        // copying any atoms not yet present in the batch.
                        int[] originalAtomIndexes = term.getAtomIndexes();
                        int[] newAtomIndexes = new int[term.getAtomIndexes().length];
                        for (int j = 0 ; j < term.size(); j ++) {
                            int atomIndex = originalAtomIndexes[j];
                            GroundAtom atom = fullTermStore.getAtomStore().getAtom(atomIndex);

                            if (visitedAtoms.contains(atom)) {
                                newAtomIndexes[j] = batchAtomStore.getAtomIndex(atom);
                                continue;
                            }

                            visitedAtoms.add(atom);

                            GroundAtom newBatchAtom = atom.copy();
                            newBatchAtom.clearTerms();
                            batchAtomStore.addAtom(newBatchAtom);
                            // NOTE(review): the lookup uses the original atom while the copy was added;
                            // presumably AtomStore.getAtomIndex matches by atom equality rather than
                            // identity — verify.
                            newAtomIndexes[j] = batchAtomStore.getAtomIndex(atom);

                            // Terms hanging off this neighbor atom form the next BFS frontier.
                            bfsNextDepthQueue.addAll(atom.getTerms());
                        }

                        ReasonerTerm newBatchTerm = term.copy();
                        newBatchTerm.setAtomIndexes(newAtomIndexes);
                        batchTermStore.add(newBatchTerm);
                    }

                    bfsCurrentDepthQueue = bfsNextDepthQueue;
                }
            }
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,10 @@
package org.linqs.psl.application.learning.weight.gradient.optimalvalue;

import org.linqs.psl.application.learning.weight.gradient.GradientDescent;
import org.linqs.psl.application.learning.weight.gradient.minimizer.Minimizer;
import org.linqs.psl.database.AtomStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.DeepPredicate;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.reasoner.term.ReasonerTerm;
import org.linqs.psl.reasoner.term.SimpleTermStore;
Expand Down Expand Up @@ -69,8 +66,8 @@ public OptimalValue(List<Rule> rules, Database trainTargetDatabase, Database tra
latentFullInferenceAtomValueState = null;
latentInferenceAtomValueState = null;

batchLatentInferenceTermStates = new ArrayList<>(numBatches);
batchlatentInferenceAtomValueState = new ArrayList<>(numBatches);
batchLatentInferenceTermStates = new ArrayList<TermState[]>();
batchlatentInferenceAtomValueState = new ArrayList<float[]>();
}

@Override
Expand All @@ -88,8 +85,8 @@ protected void postInitGroundModel() {
rvLatentAtomGradient = new float[atomValues.length];
deepLatentAtomGradient = new float[atomValues.length];

for (int i = 0; i < numBatches; i++) {
SimpleTermStore<? extends ReasonerTerm> batchTermStore = batchTermStores.get(i);
for (int i = 0; i < batchGenerator.getNumBatches(); i++) {
SimpleTermStore<? extends ReasonerTerm> batchTermStore = batchGenerator.getBatchTermStore(i);
batchLatentInferenceTermStates.add(batchTermStore.saveState());
batchlatentInferenceAtomValueState.add(Arrays.copyOf(batchTermStore.getAtomStore().getAtomValues(), batchTermStore.getAtomStore().getAtomValues().length));
}
Expand Down Expand Up @@ -140,6 +137,12 @@ protected void fixLabeledRandomVariables() {
ObservedAtom observedAtom = entry.getValue();

int atomIndex = atomStore.getAtomIndex(randomVariableAtom);

if (atomIndex == -1) {
// This atom is not in the current batch.
continue;
}

atomStore.getAtoms()[atomIndex] = observedAtom;
atomStore.getAtomValues()[atomIndex] = observedAtom.getValue();
latentInferenceAtomValueState[atomIndex] = observedAtom.getValue();
Expand All @@ -159,6 +162,12 @@ protected void unfixLabeledRandomVariables() {
RandomVariableAtom randomVariableAtom = entry.getKey();

int atomIndex = atomStore.getAtomIndex(randomVariableAtom);

if (atomIndex == -1) {
// This atom is not in the current batch.
continue;
}

atomStore.getAtoms()[atomIndex] = randomVariableAtom;
}
}
Expand Down
19 changes: 12 additions & 7 deletions psl-core/src/main/java/org/linqs/psl/config/Options.java
Original file line number Diff line number Diff line change
Expand Up @@ -363,13 +363,6 @@ public class Options {
Option.FLAG_POSITIVE
);

public static final Option WLA_GRADIENT_DESCENT_NUM_BATCHES = new Option(
"gradientdescent.numbatches",
1,
"The number of batches to use for gradient descent weight learning."
+ " The default is 1, which means that all training examples are used in each iteration."
);

public static final Option WLA_GRADIENT_DESCENT_NUM_STEPS = new Option(
"gradientdescent.numsteps",
500,
Expand Down Expand Up @@ -708,6 +701,18 @@ public class Options {
Option.FLAG_POSITIVE
);

public static final Option RANDOM_NODE_BATCH_GENERATOR_NUM_BATCHES = new Option(
"randomnodebatchgenerator.numbatches",
10,
"The number of batches to sample for random node batch generator."
);

public static final Option RANDOM_NODE_BATCH_GENERATOR_BFS_DEPTH = new Option(
"randomnodebatchgenerator.bfsdepth",
1,
"The depth of the factor graph bfs search for random node batch generator."
);

public static final Option EVAL_AUC_REPRESENTATIVE = new Option(
"aucevaluator.representative",
AUCEvaluator.RepresentativeMetric.AUROC.toString(),
Expand Down
Loading

0 comments on commit 573d78f

Please sign in to comment.