Skip to content

Commit

Permalink
Refactor rule parsing (#4699)
Browse files Browse the repository at this point in the history
* Remove unnecessary cases
* Make pattern matching logic more concise
* Process three cases of simplification rules: `AppRule`, `CeilRule`,
`EqualsRule`
  • Loading branch information
tothtamas28 authored Dec 3, 2024
1 parent 19a16cd commit e6a0d1d
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 116 deletions.
270 changes: 156 additions & 114 deletions pyk/src/pyk/kore/rule.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,44 @@
"""Parse KORE axioms into rewrite rules.
Based on the [LLVM Backend's implementation](https://github.com/runtimeverification/llvm-backend/blob/d5eab4b0f0e610bc60843ebb482f79c043b92702/lib/ast/pattern_matching.cpp).
"""

from __future__ import annotations

import logging
from abc import ABC
from dataclasses import dataclass
from typing import TYPE_CHECKING, final
from typing import TYPE_CHECKING, Generic, TypeVar, final

from .prelude import inj
from .syntax import And, App, Axiom, Ceil, Equals, EVar, Implies, In, Not, Rewrites, SortVar, String, Top
from .syntax import (
DV,
And,
App,
Axiom,
Ceil,
Equals,
EVar,
Implies,
In,
Not,
Pattern,
Rewrites,
SortApp,
SortVar,
String,
Top,
)

if TYPE_CHECKING:
from typing import Final

from .syntax import Definition, Pattern
from .syntax import Definition

Attrs = dict[str, tuple[Pattern, ...]]


P = TypeVar('P', bound=Pattern)


_LOGGER: Final = logging.getLogger(__name__)


# There's a simplification rule with irregular form in the prelude module INJ.
# This rule is skipped in Rule.extract_all.
_S1, _S2, _S3, _R = (SortVar(name) for name in ['S1', 'S2', 'S3', 'R'])
Expand Down Expand Up @@ -56,33 +75,35 @@ def from_axiom(axiom: Axiom) -> Rule:
if isinstance(axiom.pattern, Rewrites):
return RewriteRule.from_axiom(axiom)

if 'simplification' in axiom.attrs_by_key:
return SimpliRule.from_axiom(axiom)
if 'simplification' not in axiom.attrs_by_key:
return FunctionRule.from_axiom(axiom)

return FunctionRule.from_axiom(axiom)
match axiom.pattern:
case Implies(right=Equals(left=App())):
return AppRule.from_axiom(axiom)
case Implies(right=Equals(left=Ceil())):
return CeilRule.from_axiom(axiom)
case Implies(right=Equals(left=Equals())):
return EqualsRule.from_axiom(axiom)
case _:
raise ValueError(f'Cannot parse simplification rule: {axiom.text}')

@staticmethod
def extract_all(defn: Definition) -> list[Rule]:
return [Rule.from_axiom(axiom) for axiom in defn.axioms if Rule._is_rule(axiom)]

@staticmethod
def _is_rule(axiom: Axiom) -> bool:
if axiom == _INJ_AXIOM:
return False

if any(attr in axiom.attrs_by_key for attr in _SKIPPED_ATTRS):
return False
def is_rule(axiom: Axiom) -> bool:
if axiom == _INJ_AXIOM:
return False

match axiom.pattern:
case Implies(right=Equals(left=Ceil())):
# Ceil rule
if any(attr in axiom.attrs_by_key for attr in _SKIPPED_ATTRS):
return False

return True
return True

return [Rule.from_axiom(axiom) for axiom in defn.axioms if is_rule(axiom)]


@final
@dataclass
@dataclass(frozen=True)
class RewriteRule(Rule):
lhs: App
rhs: App
Expand All @@ -95,8 +116,7 @@ class RewriteRule(Rule):

@staticmethod
def from_axiom(axiom: Axiom) -> RewriteRule:
lhs, req, ctx = RewriteRule._extract_lhs(axiom)
rhs, ens = RewriteRule._extract_rhs(axiom)
lhs, rhs, req, ens, ctx = RewriteRule._extract(axiom)
priority = _extract_priority(axiom)
uid = _extract_uid(axiom)
label = _extract_label(axiom)
Expand All @@ -112,51 +132,35 @@ def from_axiom(axiom: Axiom) -> RewriteRule:
)

@staticmethod
def _extract_lhs(axiom: Axiom) -> tuple[App, Pattern | None, EVar | None]:
req: Pattern | None = None
# Cases 0-5 of get_left_hand_side
# Cases 5-10 of get_requires
def _extract(axiom: Axiom) -> tuple[App, App, Pattern | None, Pattern | None, EVar | None]:
match axiom.pattern:
case Rewrites(left=And(ops=(Top(), lhs))):
pass
case Rewrites(left=And(ops=(Equals(left=req), lhs))):
pass
case Rewrites(left=And(ops=(lhs, Top()))):
pass
case Rewrites(left=And(ops=(lhs, Equals(left=req)))):
pass
case Rewrites(left=And(ops=(Not(), And(ops=(Top(), lhs))))):
pass
case Rewrites(left=And(ops=(Not(), And(ops=(Equals(left=req), lhs))))):
case Rewrites(left=And(ops=(_lhs, _req)), right=_rhs):
pass
case _:
raise ValueError(f'Cannot extract LHS from axiom: {axiom.text}')
raise ValueError(f'Cannot extract rewrite rule from axiom: {axiom.text}')

ctx: EVar | None = None
match lhs:
case App("Lbl'-LT-'generatedTop'-GT-'") as app:
match _lhs:
case App("Lbl'-LT-'generatedTop'-GT-'") as lhs:
pass
case And(_, (App("Lbl'-LT-'generatedTop'-GT-'") as app, EVar("Var'Hash'Configuration") as ctx)):
case And(_, (App("Lbl'-LT-'generatedTop'-GT-'") as lhs, EVar("Var'Hash'Configuration") as ctx)):
pass
case _:
raise ValueError(f'Cannot extract LHS configuration from axiom: {axiom.text}')

return app, req, ctx

@staticmethod
def _extract_rhs(axiom: Axiom) -> tuple[App, Pattern | None]:
# Case 2 of get_right_hand_side:
# 2: rhs(\rewrites(_, \and(X, Y))) = get_builtin(\and(X, Y))
# Below is a special case without get_builtin
match axiom.pattern:
case Rewrites(right=And(ops=(App("Lbl'-LT-'generatedTop'-GT-'") as rhs, Top() | Equals() as _ens))):
req = _extract_condition(_req)
rhs, ens = _extract_rhs(_rhs)
match rhs:
case App("Lbl'-LT-'generatedTop'-GT-'"):
pass
case _:
raise ValueError(f'Cannot extract RHS from axiom: {axiom.text}')
ens = _extract_ensures(_ens)
return rhs, ens
raise ValueError(f'Cannot extract RHS configuration from axiom: {axiom.text}')

return lhs, rhs, req, ens, ctx


@final
@dataclass
@dataclass(frozen=True)
class FunctionRule(Rule):
lhs: App
rhs: Pattern
Expand All @@ -166,9 +170,7 @@ class FunctionRule(Rule):

@staticmethod
def from_axiom(axiom: Axiom) -> FunctionRule:
args, req = FunctionRule._extract_args(axiom)
app, rhs, ens = FunctionRule._extract_rhs(axiom)
lhs = app.let(args=args)
lhs, rhs, req, ens = FunctionRule._extract(axiom)
priority = _extract_priority(axiom)
return FunctionRule(
lhs=lhs,
Expand All @@ -179,93 +181,133 @@ def from_axiom(axiom: Axiom) -> FunctionRule:
)

@staticmethod
def _extract_args(axiom: Axiom) -> tuple[tuple[Pattern, ...], Pattern | None]:
req: Pattern | None = None
# Cases 7-10 of get_left_hand_side
# Cases 0-3 of get_requires
def _extract(axiom: Axiom) -> tuple[App, Pattern, Pattern | None, Pattern | None]:
match axiom.pattern:
case Implies(left=And(ops=(Top(), pat))):
return FunctionRule._get_patterns(pat), req
case Implies(left=And(ops=(Equals(left=req), pat))):
return FunctionRule._get_patterns(pat), req
case Implies(left=And(ops=(Not(), And(ops=(Top(), pat))))):
return FunctionRule._get_patterns(pat), req
case Implies(left=And(ops=(Not(), And(ops=(Equals(left=req), pat))))):
return FunctionRule._get_patterns(pat), req
case Implies(
left=And(
ops=(Not(), And(ops=(_req, _args))) | (_req, _args),
),
right=Equals(left=App() as app, right=_rhs),
):
args = FunctionRule._extract_args(_args)
lhs = app.let(args=args)
req = _extract_condition(_req)
rhs, ens = _extract_rhs(_rhs)
return lhs, rhs, req, ens
case _:
raise ValueError(f'Cannot extract LHS from axiom: {axiom.text}')
raise ValueError(f'Cannot extract function rule from axiom: {axiom.text}')

@staticmethod
def _get_patterns(pattern: Pattern) -> tuple[Pattern, ...]:
# get_patterns(\top()) = []
# get_patterns(\and(\in(_, X), Y)) = X : get_patterns(Y)
def _extract_args(pattern: Pattern) -> tuple[Pattern, ...]:
match pattern:
case Top():
return ()
case And(ops=(In(right=x), y)):
return (x,) + FunctionRule._get_patterns(y)
case And(ops=(In(left=EVar(), right=arg), rest)):
return (arg,) + FunctionRule._extract_args(rest)
case _:
raise AssertionError()
raise ValueError(f'Cannot extract argument list from pattern: {pattern.text}')


class SimpliRule(Rule, Generic[P], ABC):
lhs: P

@staticmethod
def _extract_rhs(axiom: Axiom) -> tuple[App, Pattern, Pattern | None]:
# Case 0 of get_right_hand_side
def _extract(axiom: Axiom, lhs_type: type[P]) -> tuple[P, Pattern, Pattern | None, Pattern | None]:
match axiom.pattern:
case Implies(right=Equals(left=App() as app, right=And(ops=(rhs, Top() | Equals() as _ens)))):
pass
case Implies(left=_req, right=Equals(left=lhs, right=_rhs)):
req = _extract_condition(_req)
rhs, ens = _extract_rhs(_rhs)
if not isinstance(lhs, lhs_type):
raise ValueError(f'Invalid LHS type from simplification axiom: {axiom.text}')
return lhs, rhs, req, ens
case _:
raise ValueError(f'Cannot extract RHS from axiom: {axiom.text}')
ens = _extract_ensures(_ens)
return app, rhs, ens
raise ValueError(f'Cannot extract simplification rule from axiom: {axiom.text}')


@final
@dataclass
class SimpliRule(Rule):
lhs: Pattern
@dataclass(frozen=True)
class AppRule(SimpliRule[App]):
lhs: App
rhs: Pattern
req: Pattern | None
ens: Pattern | None
priority: int

@staticmethod
def from_axiom(axiom: Axiom) -> AppRule:
lhs, rhs, req, ens = SimpliRule._extract(axiom, App)
priority = _extract_priority(axiom)
return AppRule(
lhs=lhs,
rhs=rhs,
req=req,
ens=ens,
priority=priority,
)


@final
@dataclass(frozen=True)
class CeilRule(SimpliRule):
lhs: Ceil
rhs: Pattern
req: Pattern | None
ens: Pattern | None
priority: int

@staticmethod
def from_axiom(axiom: Axiom) -> SimpliRule:
lhs, rhs, req, ens = SimpliRule._extract(axiom)
def from_axiom(axiom: Axiom) -> CeilRule:
lhs, rhs, req, ens = SimpliRule._extract(axiom, Ceil)
priority = _extract_priority(axiom)
return SimpliRule(
return CeilRule(
lhs=lhs,
rhs=rhs,
req=req,
ens=ens,
priority=priority,
)


@final
@dataclass(frozen=True)
class EqualsRule(SimpliRule):
lhs: Equals
rhs: Pattern
req: Pattern | None
ens: Pattern | None
priority: int

@staticmethod
def _extract(axiom: Axiom) -> tuple[Pattern, Pattern, Pattern | None, Pattern | None]:
req: Pattern | None = None
# Cases 11-12 of get_left_hand_side
# Case 0 of get_right_hand_side
match axiom.pattern:
case Implies(left=Top(), right=Equals(left=lhs, right=And(ops=(rhs, Top() | Equals() as _ens)))):
pass
case Implies(left=Equals(left=req), right=Equals(left=lhs, right=And(ops=(rhs, Top() | Equals() as _ens)))):
pass
case Implies(right=Equals(left=Ceil())):
raise ValueError(f'Axiom is a ceil rule: {axiom.text}')
case _:
raise ValueError(f'Cannot extract simplification rule from axiom: {axiom.text}')
ens = _extract_ensures(_ens)
return lhs, rhs, req, ens
def from_axiom(axiom: Axiom) -> EqualsRule:
lhs, rhs, req, ens = SimpliRule._extract(axiom, Equals)
if not isinstance(lhs, Equals):
raise ValueError(f'Cannot extract LHS as Equals from axiom: {axiom.text}')
priority = _extract_priority(axiom)
return EqualsRule(
lhs=lhs,
rhs=rhs,
req=req,
ens=ens,
priority=priority,
)


def _extract_rhs(pattern: Pattern) -> tuple[Pattern, Pattern | None]:
match pattern:
case And(ops=(rhs, _ens)):
return rhs, _extract_condition(_ens)
case _:
raise ValueError(f'Cannot extract RHS from pattern: {pattern.text}')


def _extract_ensures(ens: Top | Equals | None) -> Pattern | None:
match ens:
def _extract_condition(pattern: Pattern) -> Pattern | None:
match pattern:
case Top():
return None
case Equals(left=res):
return res
case Equals(left=cond, right=DV(SortApp('SortBool'), String('true'))):
return cond
case _:
raise AssertionError()
raise ValueError(f'Cannot extract condition from pattern: {pattern.text}')


def _extract_uid(axiom: Axiom) -> str:
Expand Down
6 changes: 4 additions & 2 deletions pyk/src/tests/integration/kore/test_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
@pytest.fixture(scope='module')
def definition(kompile: Kompiler) -> Definition:
main_file = K_FILES / 'imp.k'
definition_dir = kompile(main_file=main_file)
definition_dir = kompile(main_file=main_file, backend='haskell')
kore_file = definition_dir / 'definition.kore'
kore_text = kore_file.read_text()
definition = KoreParser(kore_text).definition()
Expand All @@ -33,4 +33,6 @@ def test_extract_all(definition: Definition) -> None:
cnt = Counter(type(rule).__name__ for rule in rules)
assert cnt['RewriteRule']
assert cnt['FunctionRule']
assert cnt['SimpliRule']
assert cnt['AppRule']
assert cnt['CeilRule']
assert cnt['EqualsRule']

0 comments on commit e6a0d1d

Please sign in to comment.