-
Notifications
You must be signed in to change notification settings - Fork 100
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
291 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
73 changes: 73 additions & 0 deletions
73
.../main/java/org/linqs/psl/application/learning/weight/gradient/LearningBatchGenerator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
/* | ||
* This file is part of the PSL software. | ||
* Copyright 2011-2015 University of Maryland | ||
* Copyright 2013-2023 The Regents of the University of California | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.linqs.psl.application.learning.weight.gradient; | ||
|
||
import org.linqs.psl.application.inference.InferenceApplication; | ||
import org.linqs.psl.reasoner.term.ReasonerTerm; | ||
import org.linqs.psl.reasoner.term.SimpleTermStore; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Collections; | ||
import java.util.List; | ||
|
||
/** | ||
* A class for generating batches of data for learning. | ||
* A batch in this case is a set of atoms and terms defining a subgraph of the complete factor graph. | ||
*/ | ||
public abstract class LearningBatchGenerator { | ||
protected SimpleTermStore<? extends ReasonerTerm> fullTermStore; | ||
protected List<SimpleTermStore<? extends ReasonerTerm>> batchTermStores; | ||
|
||
protected InferenceApplication inferenceApplication; | ||
|
||
|
||
public LearningBatchGenerator(InferenceApplication inferenceApplication) { | ||
assert inferenceApplication.getTermStore() instanceof SimpleTermStore; | ||
this.inferenceApplication = inferenceApplication; | ||
this.fullTermStore = (SimpleTermStore<? extends ReasonerTerm>)inferenceApplication.getTermStore(); | ||
|
||
batchTermStores = new ArrayList<SimpleTermStore<? extends ReasonerTerm>>(); | ||
} | ||
|
||
public void shuffle() { | ||
Collections.shuffle(batchTermStores); | ||
} | ||
|
||
public int getNumBatches() { | ||
return batchTermStores.size(); | ||
} | ||
|
||
public List<SimpleTermStore<? extends ReasonerTerm>> getBatchTermStores() { | ||
return batchTermStores; | ||
} | ||
|
||
public SimpleTermStore<? extends ReasonerTerm> getBatchTermStore(int index) { | ||
return batchTermStores.get(index); | ||
} | ||
|
||
public abstract void generateBatches(); | ||
|
||
public void clear() { | ||
for (SimpleTermStore<? extends ReasonerTerm> termStore : batchTermStores) { | ||
termStore.getAtomStore().close(); | ||
termStore.clear(); | ||
} | ||
|
||
batchTermStores.clear(); | ||
} | ||
} |
120 changes: 120 additions & 0 deletions
120
...ain/java/org/linqs/psl/application/learning/weight/gradient/RandomNodeBatchGenerator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
/* | ||
* This file is part of the PSL software. | ||
* Copyright 2011-2015 University of Maryland | ||
* Copyright 2013-2023 The Regents of the University of California | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.linqs.psl.application.learning.weight.gradient; | ||
|
||
import org.linqs.psl.application.inference.InferenceApplication; | ||
import org.linqs.psl.config.Options; | ||
import org.linqs.psl.database.AtomStore; | ||
import org.linqs.psl.model.atom.GroundAtom; | ||
import org.linqs.psl.model.atom.RandomVariableAtom; | ||
import org.linqs.psl.reasoner.term.ReasonerTerm; | ||
import org.linqs.psl.reasoner.term.SimpleTermStore; | ||
|
||
import java.util.*; | ||
|
||
public class RandomNodeBatchGenerator extends LearningBatchGenerator { | ||
private final int numBatches; | ||
private final int batchSize; | ||
private final int bfsDepth; | ||
|
||
public RandomNodeBatchGenerator(InferenceApplication inferenceApplication) { | ||
super(inferenceApplication); | ||
|
||
numBatches = Options.RANDOM_NODE_BATCH_GENERATOR_NUM_BATCHES.getInt(); | ||
batchSize = (int) Math.ceil(((float) fullTermStore.getAtomStore().getNumRVAtoms()) / numBatches); | ||
bfsDepth = Options.RANDOM_NODE_BATCH_GENERATOR_BFS_DEPTH.getInt(); | ||
} | ||
|
||
@Override | ||
public void generateBatches() { | ||
// Clear out any old batches. | ||
clear(); | ||
|
||
// Randomly sample batchSize number of random variable atoms from the full atom store and create a new term store and atom store for each batch. | ||
ArrayList<GroundAtom> allAtoms = new ArrayList<GroundAtom>(Arrays.asList(Arrays.copyOf(fullTermStore.getAtomStore().getAtoms(), fullTermStore.getAtomStore().size()))); | ||
Collections.shuffle(allAtoms); | ||
|
||
for (int i = 0; i < numBatches; i++) { | ||
AtomStore batchAtomStore = new AtomStore(); | ||
SimpleTermStore<? extends ReasonerTerm> batchTermStore = (SimpleTermStore<? extends ReasonerTerm>)inferenceApplication.createTermStore(); | ||
batchTermStore.setAtomStore(batchAtomStore); | ||
batchTermStores.add(batchTermStore); | ||
|
||
HashSet<ReasonerTerm> visitedTerms = new HashSet<>(); | ||
HashSet<GroundAtom> visitedAtoms = new HashSet<>(); | ||
|
||
// The last batch may be smaller than the rest. | ||
while ((batchAtomStore.size() < batchSize) && !allAtoms.isEmpty()) { | ||
GroundAtom originalAtom = allAtoms.remove(0); | ||
|
||
if (originalAtom.isFixed()) { | ||
continue; | ||
} | ||
|
||
// Make a copy of the atom so that the batch atom store can be modified without affecting the full atom store. | ||
RandomVariableAtom newBatchRVAtom = (RandomVariableAtom)originalAtom.copy(); | ||
newBatchRVAtom.clearTerms(); | ||
|
||
if (!visitedAtoms.contains(originalAtom)) { | ||
batchAtomStore.addAtom(newBatchRVAtom); | ||
} | ||
|
||
// Perform a bfs on the factor graph starting from the sampled atoms to obtain batch terms. | ||
ArrayList<ReasonerTerm> bfsCurrentDepthQueue = new ArrayList<ReasonerTerm>(originalAtom.getTerms()); | ||
for (int depth = 0; depth < bfsDepth; depth++) { | ||
ArrayList<ReasonerTerm> bfsNextDepthQueue = new ArrayList<ReasonerTerm>(); | ||
|
||
for (ReasonerTerm term : bfsCurrentDepthQueue) { | ||
if (visitedTerms.contains(term)) { | ||
continue; | ||
} | ||
|
||
visitedTerms.add(term); | ||
|
||
int[] originalAtomIndexes = term.getAtomIndexes(); | ||
int[] newAtomIndexes = new int[term.getAtomIndexes().length]; | ||
for (int j = 0 ; j < term.size(); j ++) { | ||
int atomIndex = originalAtomIndexes[j]; | ||
GroundAtom atom = fullTermStore.getAtomStore().getAtom(atomIndex); | ||
|
||
if (visitedAtoms.contains(atom)) { | ||
newAtomIndexes[j] = batchAtomStore.getAtomIndex(atom); | ||
continue; | ||
} | ||
|
||
visitedAtoms.add(atom); | ||
|
||
GroundAtom newBatchAtom = atom.copy(); | ||
newBatchAtom.clearTerms(); | ||
batchAtomStore.addAtom(newBatchAtom); | ||
newAtomIndexes[j] = batchAtomStore.getAtomIndex(atom); | ||
|
||
bfsNextDepthQueue.addAll(atom.getTerms()); | ||
} | ||
|
||
ReasonerTerm newBatchTerm = term.copy(); | ||
newBatchTerm.setAtomIndexes(newAtomIndexes); | ||
batchTermStore.add(newBatchTerm); | ||
} | ||
|
||
bfsCurrentDepthQueue = bfsNextDepthQueue; | ||
} | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.