Skip to content

Commit

Permalink
Ensure deep weight atoms are not present in grounded rule expressions…
Browse files Browse the repository at this point in the history
… during learning.
  • Loading branch information
dickensc committed May 16, 2024
1 parent 71a0ca6 commit b2b6a64
Showing 1 changed file with 15 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import org.linqs.psl.config.Options;
import org.linqs.psl.database.AtomStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.atom.Atom;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.deep.DeepModelPredicate;
import org.linqs.psl.model.predicate.DeepPredicate;
import org.linqs.psl.model.rule.AbstractRule;
Expand All @@ -40,6 +42,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -268,7 +271,19 @@ protected void initializeDeepWeightedRules() {
for (WeightedRule rule : deepRules) {
Set<Integer> childHashCodes = ((AbstractRule) rule).getChildHashCodes();

Set<Atom> coreAtomsSet = new HashSet<Atom>();
for (Integer childHashCode : childHashCodes) {
WeightedRule groundedWeightedRule = (WeightedRule) AbstractRule.getRule(childHashCode);

// Verify that the atom in the deep weight is not in the term expression.
// This pattern is not supported in the current implementation.
GroundAtom deepWeightAtom = (GroundAtom) groundedWeightedRule.getWeight().getAtom();
groundedWeightedRule.getCoreAtoms(coreAtomsSet);
if (coreAtomsSet.contains(deepWeightAtom)) {
throw new IllegalArgumentException("Grounded Deep weight atoms: " + deepWeightAtom + " cannot be in the expression of the rule: " + rule);
}
coreAtomsSet.clear();

groundedDeepWeightedRules.add((WeightedRule) AbstractRule.getRule(childHashCode));
groundedDeepWeightedRuleIndexMap.put((WeightedRule) AbstractRule.getRule(childHashCode), groundedDeepWeightRuleCount);
groundedDeepWeightRuleCount++;
Expand Down

0 comments on commit b2b6a64

Please sign in to comment.