diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java index 120b6b528..b2e58f14b 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java @@ -24,6 +24,8 @@ import org.linqs.psl.config.Options; import org.linqs.psl.database.AtomStore; import org.linqs.psl.database.Database; +import org.linqs.psl.model.atom.Atom; +import org.linqs.psl.model.atom.GroundAtom; import org.linqs.psl.model.deep.DeepModelPredicate; import org.linqs.psl.model.predicate.DeepPredicate; import org.linqs.psl.model.rule.AbstractRule; @@ -40,6 +42,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -268,7 +271,19 @@ protected void initializeDeepWeightedRules() { for (WeightedRule rule : deepRules) { Set childHashCodes = ((AbstractRule) rule).getChildHashCodes(); + Set coreAtomsSet = new HashSet(); for (Integer childHashCode : childHashCodes) { + WeightedRule groundedWeightedRule = (WeightedRule) AbstractRule.getRule(childHashCode); + + // Verify that the atom in the deep weight is not in the term expression. + // This pattern is not supported in the current implementation. + GroundAtom deepWeightAtom = (GroundAtom) groundedWeightedRule.getWeight().getAtom(); + groundedWeightedRule.getCoreAtoms(coreAtomsSet); + if (coreAtomsSet.contains(deepWeightAtom)) { + throw new IllegalArgumentException("Grounded Deep weight atoms: " + deepWeightAtom + " cannot be in the expression of the rule: " + rule); + } + coreAtomsSet.clear(); + groundedDeepWeightedRules.add((WeightedRule) AbstractRule.getRule(childHashCode)); groundedDeepWeightedRuleIndexMap.put((WeightedRule) AbstractRule.getRule(childHashCode), groundedDeepWeightRuleCount); groundedDeepWeightRuleCount++;