Skip to content

Commit

Permalink
Use the same neural model for validation.
Browse files Browse the repository at this point in the history
  • Loading branch information
dickensc committed Aug 14, 2023
1 parent e968c41 commit 5118618
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import org.linqs.psl.config.Options;
import org.linqs.psl.database.Database;
import org.linqs.psl.evaluation.EvaluationInstance;
import org.linqs.psl.model.deep.DeepModelPredicate;
import org.linqs.psl.model.predicate.DeepPredicate;
import org.linqs.psl.model.predicate.Predicate;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.util.Logger;
Expand Down Expand Up @@ -54,6 +57,10 @@ public abstract class WeightLearningApplication implements ModelApplication {
protected Database validationTargetDatabase;
protected Database validationTruthDatabase;

protected List<DeepPredicate> deepPredicates;
protected List<DeepModelPredicate> deepModelPredicates;
protected List<DeepModelPredicate> validationDeepModelPredicates;

protected boolean runValidation;

protected List<Rule> allRules;
Expand Down Expand Up @@ -87,6 +94,10 @@ public WeightLearningApplication(List<Rule> rules, Database trainTargetDatabase,

this.runValidation = runValidation;

deepPredicates = new ArrayList<DeepPredicate>();
deepModelPredicates = new ArrayList<DeepModelPredicate>();
validationDeepModelPredicates = new ArrayList<DeepModelPredicate>();

allRules = new ArrayList<Rule>();
mutableRules = new ArrayList<WeightedRule>();

Expand Down Expand Up @@ -154,11 +165,8 @@ protected void initGroundModel() {

InferenceApplication trainInferenceApplication = InferenceApplication.getInferenceApplication(Options.WLA_INFERENCE.getString(), allRules, trainTargetDatabase);
trainInferenceApplication.loadDeepPredicates("learning");
InferenceApplication validationInferenceApplication = InferenceApplication.getInferenceApplication(Options.WLA_INFERENCE.getString(), allRules, validationTargetDatabase);

if (runValidation) {
validationInferenceApplication.loadDeepPredicates("inference");
}
InferenceApplication validationInferenceApplication = InferenceApplication.getInferenceApplication(Options.WLA_INFERENCE.getString(), allRules, validationTargetDatabase);

initGroundModel(trainInferenceApplication, validationInferenceApplication);
}
Expand Down Expand Up @@ -195,6 +203,17 @@ public void initGroundModel(InferenceApplication trainInferenceApplication, Trai
initRandomWeights();
}

// Register all deep predicates and build per-predicate model copies for validation.
// The training model is registered as-is; the validation side gets a copy of the
// deep model wired to the validation database's atom store.
for (Predicate predicate : Predicate.getAll()) {
    if (predicate instanceof DeepPredicate) {
        deepPredicates.add((DeepPredicate)predicate);
        deepModelPredicates.add(((DeepPredicate)predicate).getDeepModel());

        // BUGFIX: add the configured copy itself. Previously a second, fresh copy()
        // was added instead, discarding the atom-store initialization done here and
        // leaking an extra model copy.
        DeepModelPredicate validationDeepModelPredicate = ((DeepPredicate)predicate).getDeepModel().copy();
        validationDeepModelPredicate.setAtomStore(validationInferenceApplication.getDatabase().getAtomStore(), true);
        validationDeepModelPredicates.add(validationDeepModelPredicate);
    }
}

postInitGroundModel();

groundModelInit = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.linqs.psl.config.Options;
import org.linqs.psl.database.AtomStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.deep.DeepModelPredicate;
import org.linqs.psl.model.predicate.DeepPredicate;
import org.linqs.psl.model.predicate.Predicate;
import org.linqs.psl.model.rule.GroundRule;
Expand Down Expand Up @@ -67,8 +68,6 @@ public static enum GDExtension {
protected float[] MAPRVAtomGradient;
protected float[] MAPDeepAtomGradient;

protected List<DeepPredicate> deepPredicates;

protected TermState[] trainMAPTermState;
protected float[] trainMAPAtomValueState;

Expand Down Expand Up @@ -111,8 +110,6 @@ public GradientDescent(List<Rule> rules, Database trainTargetDatabase, Database
MAPRVAtomGradient = null;
MAPDeepAtomGradient = null;

deepPredicates = new ArrayList<DeepPredicate>();

trainMAPTermState = null;
trainMAPAtomValueState = null;

Expand Down Expand Up @@ -176,12 +173,6 @@ protected void postInitGroundModel() {

MAPRVAtomGradient = new float[trainAtomValues.length];
MAPDeepAtomGradient = new float[trainAtomValues.length];

for (Predicate predicate : Predicate.getAll()) {
if (predicate instanceof DeepPredicate) {
deepPredicates.add((DeepPredicate)predicate);
}
}
}

protected void initForLearning() {
Expand Down Expand Up @@ -238,6 +229,7 @@ protected void doLearn() {
gradientStep(iteration);

if (log.isTraceEnabled() && (evaluation != null)) {
log.trace("Running Inference.");
// Compute the MAP state before evaluating so variables have assigned values.
computeMAPStateWithWarmStart(trainInferenceApplication, trainMAPTermState, trainMAPAtomValueState);
inTrainingMAPState = true;
Expand All @@ -257,17 +249,16 @@ protected void doLearn() {
}

if (runValidation) {
for (DeepPredicate deepPredicate : deepPredicates) {
for (int i = 0; i < deepPredicates.size(); i++) {
DeepPredicate deepPredicate = deepPredicates.get(i);
deepPredicate.setDeepModel(validationDeepModelPredicates.get(i));
deepPredicate.predictDeepModel(false);
}

log.trace("Running Validation Inference.");
computeMAPStateWithWarmStart(validationInferenceApplication, validationMAPTermState, validationMAPAtomValueState);
inValidationMAPState = true;

for (DeepPredicate deepPredicate : deepPredicates) {
deepPredicate.predictDeepModel(true);
}

evaluation.compute(validationMap);
currentValidationEvaluationMetric = evaluation.getNormalizedRepMetric();
log.debug("MAP State Validation Evaluation Metric: {}", currentValidationEvaluationMetric);
Expand All @@ -284,6 +275,12 @@ protected void doLearn() {
}

log.debug("MAP State Best Validation Evaluation Metric: {}", bestValidationEvaluationMetric);

for (int i = 0; i < deepPredicates.size(); i++) {
DeepPredicate deepPredicate = deepPredicates.get(i);
deepPredicate.setDeepModel(deepModelPredicates.get(i));
deepPredicate.predictDeepModel(true);
}
}

computeIterationStatistics();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ protected void computeProxRuleObservedAtomValueGradient() {
* Compute the incompatibility of the mpe state and the gradient of the energy function at the mpe state.
*/
private void computeFullInferenceStatistics() {
log.trace("Running Full Inference.");
log.trace("Running Inference.");
computeMAPStateWithWarmStart(trainInferenceApplication, trainMAPTermState, trainMAPAtomValueState);
inTrainingMAPState = true;

Expand Down
22 changes: 11 additions & 11 deletions psl-core/src/main/java/org/linqs/psl/model/deep/DeepModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,20 @@ public abstract class DeepModel {
private static int startingPort = -1;
private static Map<Integer, DeepModel> usedPorts = null;

private String deepModel;
protected String deepModel;
protected Map<String, String> pythonOptions;
protected String application;

private int port;
private String pythonModule;
private String sharedMemoryPath;
private Process pythonServerProcess;
private RandomAccessFile sharedFile;
protected int port;
protected String pythonModule;
protected String sharedMemoryPath;
protected Process pythonServerProcess;
protected RandomAccessFile sharedFile;
protected MappedByteBuffer sharedBuffer;
private Socket socket;
private BufferedReader socketInput;
private PrintWriter socketOutput;
private boolean serverOpen;
protected Socket socket;
protected BufferedReader socketInput;
protected PrintWriter socketOutput;
protected boolean serverOpen;

protected DeepModel(String deepModel) {
this.deepModel = deepModel;
Expand Down Expand Up @@ -404,7 +404,7 @@ private static synchronized int getOpenPort(DeepModel model) {
return port;
}

private static synchronized void freePort(int port) {
protected static synchronized void freePort(int port) {
usedPorts.remove(Integer.valueOf(port));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,41 @@ public DeepModelPredicate(Predicate predicate) {
this.validDataIndexes = new ArrayList<Integer>();
}

/**
 * Create a shallow copy of this deep model predicate.
 * The copy shares the underlying python server resources (process, socket,
 * shared memory, port) with this instance, so both front the same server.
 * NOTE(review): the collections (atomIndexes, dataIndexes, ...) and gradients
 * are shared by reference, not cloned -- presumably intentional; confirm that
 * callers never mutate them independently per copy.
 */
public DeepModelPredicate copy() {
DeepModelPredicate copy = new DeepModelPredicate(predicate);

copy.pythonOptions = pythonOptions;

copy.application = application;

// The constructor chain allocated a fresh port for the copy; release it and
// share this instance's port so the copy talks to the same python server.
freePort(copy.port);
copy.port = (this.port);

copy.pythonModule = pythonModule;
copy.sharedMemoryPath = sharedMemoryPath;
// Server-side handles are shared, not duplicated.
copy.pythonServerProcess = pythonServerProcess;
copy.sharedFile = sharedFile;
copy.sharedBuffer = sharedBuffer;
copy.socket = socket;
copy.socketInput = socketInput;
copy.socketOutput = socketOutput;
copy.serverOpen = serverOpen;

// Callers typically replace this via setAtomStore() right after copying.
copy.atomStore = atomStore;

copy.classSize = classSize;
copy.atomIndexes = atomIndexes;
copy.dataIndexes = dataIndexes;

copy.validAtomIndexes = validAtomIndexes;
copy.validDataIndexes = validDataIndexes;

copy.gradients = gradients;
copy.symbolicGradients = symbolicGradients;

return copy;
}

public synchronized int init() {
log.debug("Initializing deep model predicate: {}", predicate.getName());

Expand Down Expand Up @@ -160,7 +195,15 @@ public synchronized void close() {
}

/**
 * Attach an atom store to this deep model predicate without re-initializing
 * the underlying model.
 *
 * @param atomStore the atom store to attach.
 */
public void setAtomStore(AtomStore atomStore) {
    this.setAtomStore(atomStore, /* init = */ false);
}

/**
 * Attach an atom store to this deep model predicate, optionally running
 * init() afterwards so the model is rebuilt against the new store.
 *
 * @param atomStore the atom store to attach.
 * @param init true to invoke init() once the store is attached.
 */
public void setAtomStore(AtomStore atomStore, boolean init) {
    this.atomStore = atomStore;

    if (!init) {
        return;
    }

    init();
}

public void setSymbolicGradients(float[] symbolicGradients) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ public void fitDeepPredicate(float[] symbolicGradients) {
deepModel.fitDeepModel();
}

/**
 * @return the deep model predicate currently backing this deep predicate.
 */
public DeepModelPredicate getDeepModel() {
return deepModel;
}

/**
 * Replace the deep model backing this predicate.
 * Callers use this to swap between the training and validation model copies
 * during weight learning.
 */
public void setDeepModel(DeepModelPredicate deepModel) {
this.deepModel = deepModel;
}

/**
 * Run prediction with the backing deep model.
 * Delegates with a false flag -- NOTE(review): flag semantics not visible
 * here (true appears to be used for the "final"/training-model pass); confirm
 * against DeepModelPredicate.predictDeepModel.
 */
public float predictDeepModel() {
return deepModel.predictDeepModel(false);
}
Expand Down

0 comments on commit 5118618

Please sign in to comment.