From b2b6a64e86e5811c40bf3cde0dd07b1722de3da1 Mon Sep 17 00:00:00 2001 From: Charles Dickens Date: Thu, 16 May 2024 16:12:05 -0700 Subject: [PATCH] Ensure deep weight atoms are not present in grounded rule expressions during learning. --- .../learning/weight/gradient/GradientDescent.java | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java index 120b6b528..b2e58f14b 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java @@ -24,6 +24,8 @@ import org.linqs.psl.config.Options; import org.linqs.psl.database.AtomStore; import org.linqs.psl.database.Database; +import org.linqs.psl.model.atom.Atom; +import org.linqs.psl.model.atom.GroundAtom; import org.linqs.psl.model.deep.DeepModelPredicate; import org.linqs.psl.model.predicate.DeepPredicate; import org.linqs.psl.model.rule.AbstractRule; @@ -40,6 +42,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -268,7 +271,19 @@ protected void initializeDeepWeightedRules() { for (WeightedRule rule : deepRules) { Set childHashCodes = ((AbstractRule) rule).getChildHashCodes(); + Set coreAtomsSet = new HashSet(); for (Integer childHashCode : childHashCodes) { + WeightedRule groundedWeightedRule = (WeightedRule) AbstractRule.getRule(childHashCode); + + // Verify that the atom in the deep weight is not in the term expression. + // This pattern is not supported in the current implementation. + GroundAtom deepWeightAtom = (GroundAtom) groundedWeightedRule.getWeight().getAtom(); + groundedWeightedRule.getCoreAtoms(coreAtomsSet); + if (coreAtomsSet.contains(deepWeightAtom)) { + throw new IllegalArgumentException("Grounded Deep weight atoms: " + deepWeightAtom + " cannot be in the expression of the rule: " + rule); + } + coreAtomsSet.clear(); + groundedDeepWeightedRules.add((WeightedRule) AbstractRule.getRule(childHashCode)); groundedDeepWeightedRuleIndexMap.put((WeightedRule) AbstractRule.getRule(childHashCode), groundedDeepWeightRuleCount); groundedDeepWeightRuleCount++;