diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/WeightLearningApplication.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/WeightLearningApplication.java index 480064e8b..ec07ef2e9 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/WeightLearningApplication.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/WeightLearningApplication.java @@ -22,6 +22,9 @@ import org.linqs.psl.config.Options; import org.linqs.psl.database.Database; import org.linqs.psl.evaluation.EvaluationInstance; +import org.linqs.psl.model.deep.DeepModelPredicate; +import org.linqs.psl.model.predicate.DeepPredicate; +import org.linqs.psl.model.predicate.Predicate; import org.linqs.psl.model.rule.Rule; import org.linqs.psl.model.rule.WeightedRule; import org.linqs.psl.util.Logger; @@ -54,6 +57,10 @@ public abstract class WeightLearningApplication implements ModelApplication { protected Database validationTargetDatabase; protected Database validationTruthDatabase; + protected List deepPredicates; + protected List deepModelPredicates; + protected List validationDeepModelPredicates; + protected boolean runValidation; protected List allRules; @@ -87,6 +94,10 @@ public WeightLearningApplication(List rules, Database trainTargetDatabase, this.runValidation = runValidation; + deepPredicates = new ArrayList(); + deepModelPredicates = new ArrayList(); + validationDeepModelPredicates = new ArrayList(); + allRules = new ArrayList(); mutableRules = new ArrayList(); @@ -154,11 +165,8 @@ protected void initGroundModel() { InferenceApplication trainInferenceApplication = InferenceApplication.getInferenceApplication(Options.WLA_INFERENCE.getString(), allRules, trainTargetDatabase); trainInferenceApplication.loadDeepPredicates("learning"); - InferenceApplication validationInferenceApplication = InferenceApplication.getInferenceApplication(Options.WLA_INFERENCE.getString(), allRules, validationTargetDatabase); - if (runValidation) { - validationInferenceApplication.loadDeepPredicates("inference"); - } + InferenceApplication validationInferenceApplication = InferenceApplication.getInferenceApplication(Options.WLA_INFERENCE.getString(), allRules, validationTargetDatabase); initGroundModel(trainInferenceApplication, validationInferenceApplication); } @@ -195,6 +203,17 @@ public void initGroundModel(InferenceApplication trainInferenceApplication, Trai initRandomWeights(); } + for (Predicate predicate : Predicate.getAll()) { + if (predicate instanceof DeepPredicate) { + deepPredicates.add((DeepPredicate)predicate); + deepModelPredicates.add(((DeepPredicate)predicate).getDeepModel()); + + DeepModelPredicate validationDeepModelPredicate = ((DeepPredicate)predicate).getDeepModel().copy(); + validationDeepModelPredicate.setAtomStore(validationInferenceApplication.getDatabase().getAtomStore(), true); + validationDeepModelPredicates.add(((DeepPredicate)predicate).getDeepModel().copy()); + } + } + postInitGroundModel(); groundModelInit = true; diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java index 7e08df9ab..a1f09652c 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java @@ -22,6 +22,7 @@ import org.linqs.psl.config.Options; import org.linqs.psl.database.AtomStore; import org.linqs.psl.database.Database; +import org.linqs.psl.model.deep.DeepModelPredicate; import org.linqs.psl.model.predicate.DeepPredicate; import org.linqs.psl.model.predicate.Predicate; import org.linqs.psl.model.rule.GroundRule; @@ -67,8 +68,6 @@ public static enum GDExtension { protected float[] MAPRVAtomGradient; protected float[] MAPDeepAtomGradient; - protected List deepPredicates; - protected TermState[] trainMAPTermState; protected float[] trainMAPAtomValueState; @@ -111,8 +110,6 @@ public GradientDescent(List rules, Database trainTargetDatabase, Database MAPRVAtomGradient = null; MAPDeepAtomGradient = null; - deepPredicates = new ArrayList(); - trainMAPTermState = null; trainMAPAtomValueState = null; @@ -176,12 +173,6 @@ protected void postInitGroundModel() { MAPRVAtomGradient = new float[trainAtomValues.length]; MAPDeepAtomGradient = new float[trainAtomValues.length]; - - for (Predicate predicate : Predicate.getAll()) { - if (predicate instanceof DeepPredicate) { - deepPredicates.add((DeepPredicate)predicate); - } - } } protected void initForLearning() { @@ -238,6 +229,7 @@ protected void doLearn() { gradientStep(iteration); if (log.isTraceEnabled() && (evaluation != null)) { + log.trace("Running Inference."); // Compute the MAP state before evaluating so variables have assigned values. computeMAPStateWithWarmStart(trainInferenceApplication, trainMAPTermState, trainMAPAtomValueState); inTrainingMAPState = true; @@ -257,17 +249,16 @@ protected void doLearn() { } if (runValidation) { - for (DeepPredicate deepPredicate : deepPredicates) { + for (int i = 0; i < deepPredicates.size(); i++) { + DeepPredicate deepPredicate = deepPredicates.get(i); + deepPredicate.setDeepModel(validationDeepModelPredicates.get(i)); deepPredicate.predictDeepModel(false); } + log.trace("Running Validation Inference."); computeMAPStateWithWarmStart(validationInferenceApplication, validationMAPTermState, validationMAPAtomValueState); inValidationMAPState = true; - for (DeepPredicate deepPredicate : deepPredicates) { - deepPredicate.predictDeepModel(true); - } - evaluation.compute(validationMap); currentValidationEvaluationMetric = evaluation.getNormalizedRepMetric(); log.debug("MAP State Validation Evaluation Metric: {}", currentValidationEvaluationMetric); @@ -284,6 +275,12 @@ protected void doLearn() { } log.debug("MAP State Best Validation Evaluation Metric: {}", bestValidationEvaluationMetric); + + for (int i = 0; i < deepPredicates.size(); i++) { + DeepPredicate deepPredicate = deepPredicates.get(i); + deepPredicate.setDeepModel(deepModelPredicates.get(i)); + deepPredicate.predictDeepModel(true); + } } computeIterationStatistics(); diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java index 0518b8ab8..f15d3e447 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java @@ -362,7 +362,7 @@ protected void computeProxRuleObservedAtomValueGradient() { * Compute the incompatibility of the mpe state and the gradient of the energy function at the mpe state. */ private void computeFullInferenceStatistics() { - log.trace("Running Full Inference."); + log.trace("Running Inference."); computeMAPStateWithWarmStart(trainInferenceApplication, trainMAPTermState, trainMAPAtomValueState); inTrainingMAPState = true; diff --git a/psl-core/src/main/java/org/linqs/psl/model/deep/DeepModel.java b/psl-core/src/main/java/org/linqs/psl/model/deep/DeepModel.java index baab4cdcb..f0cc7c16d 100644 --- a/psl-core/src/main/java/org/linqs/psl/model/deep/DeepModel.java +++ b/psl-core/src/main/java/org/linqs/psl/model/deep/DeepModel.java @@ -46,20 +46,20 @@ public abstract class DeepModel { private static int startingPort = -1; private static Map usedPorts = null; - private String deepModel; + protected String deepModel; protected Map pythonOptions; protected String application; - private int port; - private String pythonModule; - private String sharedMemoryPath; - private Process pythonServerProcess; - private RandomAccessFile sharedFile; + protected int port; + protected String pythonModule; + protected String sharedMemoryPath; + protected Process pythonServerProcess; + protected RandomAccessFile sharedFile; protected MappedByteBuffer sharedBuffer; - private Socket socket; - private BufferedReader socketInput; - private PrintWriter socketOutput; - private boolean serverOpen; + protected Socket socket; + protected BufferedReader socketInput; + protected PrintWriter socketOutput; + protected boolean serverOpen; protected DeepModel(String deepModel) { this.deepModel = deepModel; @@ -404,7 +404,7 @@ private static synchronized int getOpenPort(DeepModel model) { return port; } - private static synchronized void freePort(int port) { + protected static synchronized void freePort(int port) { usedPorts.remove(Integer.valueOf(port)); } } \ No newline at end of file diff --git a/psl-core/src/main/java/org/linqs/psl/model/deep/DeepModelPredicate.java b/psl-core/src/main/java/org/linqs/psl/model/deep/DeepModelPredicate.java index c6cdb72b7..ac529a860 100644 --- a/psl-core/src/main/java/org/linqs/psl/model/deep/DeepModelPredicate.java +++ b/psl-core/src/main/java/org/linqs/psl/model/deep/DeepModelPredicate.java @@ -68,6 +68,41 @@ public DeepModelPredicate(Predicate predicate) { this.validDataIndexes = new ArrayList(); } + public DeepModelPredicate copy() { + DeepModelPredicate copy = new DeepModelPredicate(predicate); + + copy.pythonOptions = pythonOptions; + + copy.application = application; + + freePort(copy.port); + copy.port = (this.port); + + copy.pythonModule = pythonModule; + copy.sharedMemoryPath = sharedMemoryPath; + copy.pythonServerProcess = pythonServerProcess; + copy.sharedFile = sharedFile; + copy.sharedBuffer = sharedBuffer; + copy.socket = socket; + copy.socketInput = socketInput; + copy.socketOutput = socketOutput; + copy.serverOpen = serverOpen; + + copy.atomStore = atomStore; + + copy.classSize = classSize; + copy.atomIndexes = atomIndexes; + copy.dataIndexes = dataIndexes; + + copy.validAtomIndexes = validAtomIndexes; + copy.validDataIndexes = validDataIndexes; + + copy.gradients = gradients; + copy.symbolicGradients = symbolicGradients; + + return copy; + } + public synchronized int init() { log.debug("Initializing deep model predicate: {}", predicate.getName()); @@ -160,7 +195,15 @@ public synchronized void close() { } public void setAtomStore(AtomStore atomStore) { + setAtomStore(atomStore, false); + } + + public void setAtomStore(AtomStore atomStore, boolean init) { this.atomStore = atomStore; + + if (init) { + init(); + } } public void setSymbolicGradients(float[] symbolicGradients) { diff --git a/psl-core/src/main/java/org/linqs/psl/model/predicate/DeepPredicate.java b/psl-core/src/main/java/org/linqs/psl/model/predicate/DeepPredicate.java index fd1da3dd4..a229546bf 100644 --- a/psl-core/src/main/java/org/linqs/psl/model/predicate/DeepPredicate.java +++ b/psl-core/src/main/java/org/linqs/psl/model/predicate/DeepPredicate.java @@ -51,6 +51,14 @@ public void fitDeepPredicate(float[] symbolicGradients) { deepModel.fitDeepModel(); } + public DeepModelPredicate getDeepModel() { + return deepModel; + } + + public void setDeepModel(DeepModelPredicate deepModel) { + this.deepModel = deepModel; + } + public float predictDeepModel() { return deepModel.predictDeepModel(false); }