diff --git a/src/pyk/cterm.py b/src/pyk/cterm.py index f02e9795e..26c34c1c3 100644 --- a/src/pyk/cterm.py +++ b/src/pyk/cterm.py @@ -142,16 +142,41 @@ def _ml_impl(antecedents: Iterable[KInner], consequents: Iterable[KInner]) -> KI def add_constraint(self, new_constraint: KInner) -> CTerm: return CTerm(self.config, [new_constraint] + list(self.constraints)) - def anti_unify(self, other_term: CTerm, kdef: KDefinition | None = None) -> KInner: + def anti_unify( + self, other: CTerm, keep_values: bool = False, kdef: KDefinition | None = None + ) -> tuple[CTerm, CSubst, CSubst]: def disjunction_from_substs(subst1: Subst, subst2: Subst) -> KInner: if KToken('true', 'Bool') in [subst1.pred, subst2.pred]: return mlTop() return mlEqualsTrue(orBool([subst1.pred, subst2.pred])) - new_config, self_subst, other_subst = anti_unify(self.config, other_term.config, kdef) - constraints = [c for c in self.constraints if c in other_term.constraints] - constraints.append(disjunction_from_substs(self_subst, other_subst)) - return mlAnd([new_config] + constraints) + new_config, self_subst, other_subst = anti_unify(self.config, other.config, kdef=kdef) + common_constraints = [constraint for constraint in self.constraints if constraint in other.constraints] + + if keep_values: + new_constraints = common_constraints + new_constraints.append(disjunction_from_substs(self_subst, other_subst)) + else: + new_constraints = [] + fvs = free_vars(new_config) + len_fvs = 0 + while len_fvs < len(fvs): + len_fvs = len(fvs) + for constraint in common_constraints: + if constraint not in new_constraints: + constraint_fvs = free_vars(constraint) + if any(fv in fvs for fv in constraint_fvs): + new_constraints.append(constraint) + fvs.extend(constraint_fvs) + + new_cterm = CTerm(config=new_config, constraints=new_constraints) + self_csubst = new_cterm.match_with_constraint(self) + other_csubst = new_cterm.match_with_constraint(other) + if self_csubst is None or other_csubst is None: + raise ValueError( + f'Anti-unification failed to produce a more general state: {(new_cterm, (self, self_csubst), (other, other_csubst))}' + ) + return (new_cterm, self_csubst, other_csubst) def anti_unify(state1: KInner, state2: KInner, kdef: KDefinition | None = None) -> tuple[KInner, Subst, Subst]: diff --git a/src/tests/integration/kcfg/test_imp.py b/src/tests/integration/kcfg/test_imp.py index bbdb7a631..79093239c 100644 --- a/src/tests/integration/kcfg/test_imp.py +++ b/src/tests/integration/kcfg/test_imp.py @@ -1148,26 +1148,86 @@ def test_fail_fast( assert len(proof.terminal) == 1 assert len(proof.failing) == 1 - def test_anti_unify( + def test_anti_unify_forget_values( self, + kcfg_explore: KCFGExplore, + kprint: KPrint, + ) -> None: + cterm1 = self.config( + kprint=kprint, + k='int $n ; { }', + state='N |-> X:Int', + constraint=mlAnd( + [ + mlEqualsTrue(KApply('_>Int_', [KVariable('N', 'Int'), KToken('1', 'Int')])), + mlEqualsTrue(KApply('_>Int_', [KVariable('X', 'Int'), KToken('1', 'Int')])), + mlEqualsTrue(KApply('_>Int_', [KVariable('Y', 'Int'), KToken('1', 'Int')])), + ] + ), + ) + cterm2 = self.config( + kprint=kprint, + k='int $n ; { }', + state='N |-> Y:Int', + constraint=mlAnd( + [ + mlEqualsTrue(KApply('_>Int_', [KVariable('N', 'Int'), KToken('1', 'Int')])), + mlEqualsTrue(KApply('_>Int_', [KVariable('X', 'Int'), KToken('1', 'Int')])), + mlEqualsTrue(KApply('_>Int_', [KVariable('Y', 'Int'), KToken('1', 'Int')])), + ] + ), + ) + + anti_unifier, subst1, subst2 = cterm1.anti_unify(cterm2, keep_values=False, kdef=kprint.definition) + + k_cell = get_cell(anti_unifier.kast, 'STATE_CELL') + assert type(k_cell) is KApply + assert k_cell.label.name == '_|->_' + assert type(k_cell.args[1]) is KVariable + abstracted_var: KVariable = k_cell.args[1] + + expected_anti_unifier = self.config( + kprint=kprint, + k='int $n ; { }', + state=f'N |-> {abstracted_var.name}:Int', + constraint=mlEqualsTrue(KApply('_>Int_', [KVariable('N', 'Int'), KToken('1', 'Int')])), + ) + + assert anti_unifier.kast == expected_anti_unifier.kast + + def test_anti_unify_keep_values( + self, + kcfg_explore: KCFGExplore, kprint: KPrint, ) -> None: cterm1 = self.config( kprint=kprint, k='int $n ; { }', - state='$s |-> 0', - constraint=mlEqualsTrue(KApply('_==K_', [KToken('1', 'Int'), KToken('1', 'Int')])), + state='N |-> X:Int', + constraint=mlAnd( + [ + mlEqualsTrue(KApply('_>Int_', [KVariable('N', 'Int'), KToken('1', 'Int')])), + mlEqualsTrue(KApply('_>Int_', [KVariable('X', 'Int'), KToken('1', 'Int')])), + mlEqualsTrue(KApply('_>Int_', [KVariable('Y', 'Int'), KToken('1', 'Int')])), + ] + ), ) cterm2 = self.config( kprint=kprint, k='int $n ; { }', - state='$s |-> 1', - constraint=mlEqualsTrue(KApply('_==K_', [KToken('1', 'Int'), KToken('1', 'Int')])), + state='N |-> Y:Int', + constraint=mlAnd( + [ + mlEqualsTrue(KApply('_>Int_', [KVariable('N', 'Int'), KToken('1', 'Int')])), + mlEqualsTrue(KApply('_>Int_', [KVariable('X', 'Int'), KToken('1', 'Int')])), + mlEqualsTrue(KApply('_>Int_', [KVariable('Y', 'Int'), KToken('1', 'Int')])), + ] + ), ) - anti_unifier = cterm1.anti_unify(cterm2, kprint.definition) + anti_unifier, subst1, subst2 = cterm1.anti_unify(cterm2, keep_values=True, kdef=kprint.definition) - k_cell = get_cell(anti_unifier, 'STATE_CELL') + k_cell = get_cell(anti_unifier.kast, 'STATE_CELL') assert type(k_cell) is KApply assert k_cell.label.name == '_|->_' assert type(k_cell.args[1]) is KVariable @@ -1176,41 +1236,44 @@ def test_anti_unify( expected_anti_unifier = self.config( kprint=kprint, k='int $n ; { }', - state=f'$s |-> {abstracted_var.name}:Int', + state=f'N |-> {abstracted_var.name}:Int', constraint=mlAnd( [ - mlEqualsTrue(KApply('_==K_', [KToken('1', 'Int'), KToken('1', 'Int')])), + mlEqualsTrue(KApply('_>Int_', [KVariable('N', 'Int'), KToken('1', 'Int')])), + mlEqualsTrue(KApply('_>Int_', [KVariable('X', 'Int'), KToken('1', 'Int')])), + mlEqualsTrue(KApply('_>Int_', [KVariable('Y', 'Int'), KToken('1', 'Int')])), mlEqualsTrue( orBool( [ - KApply('_==K_', [KVariable(name=abstracted_var.name), intToken(0)]), - KApply('_==K_', [KVariable(name=abstracted_var.name), intToken(1)]), + KApply('_==K_', [KVariable(name=abstracted_var.name), KVariable('X', 'Int')]), + KApply('_==K_', [KVariable(name=abstracted_var.name), KVariable('Y', 'Int')]), ] ) ), ] ), - ).kast + ) - assert anti_unifier == expected_anti_unifier + assert anti_unifier.kast == expected_anti_unifier.kast def test_anti_unify_subst_true( self, + kcfg_explore: KCFGExplore, kprint: KPrint, ) -> None: cterm1 = self.config( kprint=kprint, k='int $n ; { }', - state='$s |-> 0', - constraint=mlEqualsTrue(KApply('_==K_', [KToken('1', 'Int'), KToken('1', 'Int')])), + state='N |-> 0', + constraint=mlEqualsTrue(KApply('_==K_', [KVariable('N', 'Int'), KToken('1', 'Int')])), ) cterm2 = self.config( kprint=kprint, k='int $n ; { }', - state='$s |-> 0', - constraint=mlEqualsTrue(KApply('_==K_', [KToken('1', 'Int'), KToken('1', 'Int')])), + state='N |-> 0', + constraint=mlEqualsTrue(KApply('_==K_', [KVariable('N', 'Int'), KToken('1', 'Int')])), ) - anti_unifier = cterm1.anti_unify(cterm2, kprint.definition) + anti_unifier, _, _ = cterm1.anti_unify(cterm2, keep_values=True, kdef=kprint.definition) - assert anti_unifier == cterm1.kast + assert anti_unifier.kast == cterm1.kast