Skip to content

Commit

Permalink
Improvements/refactors to substitution extraction routine (#4631)
Browse files Browse the repository at this point in the history
Pulled out of: #4625

This PR improves the substitution extraction machinery in `kast.manip`,
and adds tests. This isn't used anywhere at the moment, but #4625 will
start using it heavily.

- The code for `extract_substs` is simplified.
- The cases of circular substitutions are handled slightly more
gracefully.

---------

Co-authored-by: rv-jenkins <[email protected]>
  • Loading branch information
ehildenb and rv-jenkins authored Sep 6, 2024
1 parent 4ec14b9 commit c54f424
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 39 deletions.
65 changes: 28 additions & 37 deletions pyk/src/pyk/kast/manip.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,48 +215,39 @@ def extract_rhs(term: KInner) -> KInner:


def extract_subst(term: KInner) -> tuple[Subst, KInner]:
def _subst_for_terms(term1: KInner, term2: KInner) -> Subst | None:
if type(term1) is KVariable and type(term2) not in {KToken, KVariable}:
return Subst({term1.name: term2})
if type(term2) is KVariable and type(term1) not in {KToken, KVariable}:
return Subst({term2.name: term1})
return None

def _extract_subst(conjunct: KInner) -> Subst | None:
if type(conjunct) is KApply:
if conjunct.label.name == '#Equals':
subst = _subst_for_terms(conjunct.args[0], conjunct.args[1])

if subst is not None:
return subst

if (
conjunct.args[0] == TRUE
and type(conjunct.args[1]) is KApply
and conjunct.args[1].label.name in {'_==K_', '_==Int_'}
):
subst = _subst_for_terms(conjunct.args[1].args[0], conjunct.args[1].args[1])

if subst is not None:
return subst
_subst = {}
rem_conjuncts: list[KInner] = []

def _extract_subst(_term1: KInner, _term2: KInner) -> tuple[str, KInner] | None:
if (
(type(_term1) is KVariable and _term1.name not in _subst)
and not (type(_term2) is KVariable and _term2.name in _subst)
and _term1.name not in free_vars(_term2)
):
return (_term1.name, _term2)
if (
(type(_term2) is KVariable and _term2.name not in _subst)
and not (type(_term1) is KVariable and _term1.name in _subst)
and _term2.name not in free_vars(_term1)
):
return (_term2.name, _term1)
if _term1 == TRUE and type(_term2) is KApply and _term2.label.name in {'_==K_', '_==Int_'}:
return _extract_subst(_term2.args[0], _term2.args[1])
if _term2 == TRUE and type(_term1) is KApply and _term1.label.name in {'_==K_', '_==Int_'}:
return _extract_subst(_term1.args[0], _term1.args[1])
return None

conjuncts = flatten_label('#And', term)
subst = Subst()
rem_conjuncts: list[KInner] = []

for conjunct in conjuncts:
new_subst = _extract_subst(conjunct)
if new_subst is None:
rem_conjuncts.append(conjunct)
for conjunct in flatten_label('#And', term):
if type(conjunct) is KApply and conjunct.label.name == '#Equals':
if _conjunct_subst := _extract_subst(conjunct.args[0], conjunct.args[1]):
name, value = _conjunct_subst
_subst[name] = value
else:
rem_conjuncts.append(conjunct)
else:
new_subst = subst.union(new_subst)
if new_subst is None:
raise ValueError('Conflicting substitutions') # TODO handle this case
subst = new_subst
rem_conjuncts.append(conjunct)

return subst, mlAnd(rem_conjuncts)
return Subst(_subst), mlAnd(rem_conjuncts)


def count_vars(term: KInner) -> Counter[str]:
Expand Down
7 changes: 5 additions & 2 deletions pyk/src/tests/unit/kast/test_subst.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,11 @@ def test_ml_pred(test_id: str, subst: Subst, pred: KInner) -> None:
(a, {}, a),
(mlEquals(a, b), {}, mlEquals(a, b)),
(mlEquals(x, a), {'x': a}, mlTop()),
(mlEquals(x, _0), {}, mlEquals(x, _0)),
(mlEquals(x, y), {}, mlEquals(x, y)),
(mlEquals(x, _0), {'x': _0}, mlTop()),
(mlEquals(x, y), {'x': y}, mlTop()),
(mlEquals(x, f(x)), {}, mlEquals(x, f(x))),
(mlAnd([mlEquals(x, y), mlEquals(x, b)]), {'x': y}, mlEquals(x, b)),
(mlAnd([mlEquals(x, b), mlEquals(x, y)]), {'x': b}, mlEquals(x, y)),
(mlAnd([mlEquals(a, b), mlEquals(x, a)]), {'x': a}, mlEquals(a, b)),
(mlEqualsTrue(_EQ(a, b)), {}, mlEqualsTrue(_EQ(a, b))),
(mlEqualsTrue(_EQ(x, a)), {'x': a}, mlTop()),
Expand Down

0 comments on commit c54f424

Please sign in to comment.