Skip to content

Commit

Permalink
Generate Lean 4 type definitions from a KORE definition (#4717)
Browse files Browse the repository at this point in the history
* Add a prelude for basic primitive sorts
* Generate and `inductive` for each constructed sort
* Generate an `abbrev` for each collection sort
  • Loading branch information
tothtamas28 authored Jan 7, 2025
1 parent 06ffc88 commit 8c40191
Show file tree
Hide file tree
Showing 5 changed files with 453 additions and 4 deletions.
10 changes: 10 additions & 0 deletions pyk/src/pyk/k2lean4/Prelude.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
abbrev SortBool : Type := Int
abbrev SortBytes: Type := ByteArray
abbrev SortId : Type := String
abbrev SortInt : Type := Int
abbrev SortString : Type := String
abbrev SortStringBuffer : Type := String

abbrev ListHook (E : Type) : Type := List E
abbrev MapHook (K : Type) (V : Type) : Type := List (K × V)
abbrev SetHook (E : Type) : Type := List E
Empty file added pyk/src/pyk/k2lean4/__init__.py
Empty file.
68 changes: 68 additions & 0 deletions pyk/src/pyk/k2lean4/k2lean4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from ..kore.internal import CollectionKind
from ..kore.syntax import SortApp
from ..utils import check_type
from .model import Abbrev, Ctor, ExplBinder, Inductive, Module, Signature, Term

if TYPE_CHECKING:
from ..kore.internal import KoreDefn
from .model import Command


@dataclass(frozen=True)
class K2Lean4:
defn: KoreDefn

def sort_module(self) -> Module:
commands = []
commands += self._inductives()
commands += self._collections()
return Module(commands=commands)

def _inductives(self) -> list[Command]:
def is_inductive(sort: str) -> bool:
decl = self.defn.sorts[sort]
return not decl.hooked and 'hasDomainValues' not in decl.attrs_by_key

sorts = sorted(sort for sort in self.defn.sorts if is_inductive(sort))
return [self._inductive(sort) for sort in sorts]

def _inductive(self, sort: str) -> Inductive:
subsorts = sorted(self.defn.subsorts.get(sort, ()))
symbols = sorted(self.defn.constructors.get(sort, ()))
ctors: list[Ctor] = []
ctors.extend(self._inj_ctor(sort, subsort) for subsort in subsorts)
ctors.extend(self._symbol_ctor(sort, symbol) for symbol in symbols)
return Inductive(sort, Signature((), Term('Type')), ctors=ctors)

def _inj_ctor(self, sort: str, subsort: str) -> Ctor:
return Ctor(f'inj_{subsort}', Signature((ExplBinder(('x',), Term(subsort)),), Term(sort)))

def _symbol_ctor(self, sort: str, symbol: str) -> Ctor:
param_sorts = (
check_type(sort, SortApp).name for sort in self.defn.symbols[symbol].param_sorts
) # TODO eliminate check_type
binders = tuple(ExplBinder((f'x{i}',), Term(sort)) for i, sort in enumerate(param_sorts))
symbol = symbol.replace('-', '_')
return Ctor(symbol, Signature(binders, Term(sort)))

def _collections(self) -> list[Command]:
return [self._collection(sort) for sort in sorted(self.defn.collections)]

def _collection(self, sort: str) -> Abbrev:
coll = self.defn.collections[sort]
elem = self.defn.symbols[coll.element]
sorts = ' '.join(check_type(sort, SortApp).name for sort in elem.param_sorts) # TODO eliminate check_type
assert sorts
match coll.kind:
case CollectionKind.LIST:
val = Term(f'ListHook {sorts}')
case CollectionKind.MAP:
val = Term(f'MapHook {sorts}')
case CollectionKind.SET:
val = Term(f'SetHook {sorts}')
return Abbrev(sort, val, Signature((), Term('Type')))
306 changes: 306 additions & 0 deletions pyk/src/pyk/k2lean4/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
from __future__ import annotations

from abc import ABC
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, final

if TYPE_CHECKING:
from collections.abc import Iterable


def indent(text: str, n: int) -> str:
indent = n * ' '
res = []
for line in text.splitlines():
res.append(f'{indent}{line}' if line else '')
return '\n'.join(res)


@final
@dataclass(frozen=True)
class Module:
commands: tuple[Command, ...]

def __init__(self, commands: Iterable[Command] | None = None):
commands = tuple(commands) if commands is not None else ()
object.__setattr__(self, 'commands', commands)

def __str__(self) -> str:
return '\n\n'.join(str(command) for command in self.commands)


class Command(ABC): ...


class Mutual(Command):
commands: tuple[Command, ...]

def __init__(self, commands: Iterable[Command] | None = None):
commands = tuple(commands) if commands is not None else ()
object.__setattr__(self, 'commands', commands)

def __str__(self) -> str:
commands = '\n\n'.join(indent(str(command), 2) for command in self.commands)
return f'mutual\n{commands}\nend'


class Declaration(Command, ABC):
modifiers: Modifiers | None


@final
@dataclass
class Abbrev(Declaration):
ident: DeclId
val: Term # declVal
signature: Signature | None
modifiers: Modifiers | None

def __init__(
self,
ident: str | DeclId,
val: Term,
signature: Signature | None = None,
modifiers: Modifiers | None = None,
):
ident = DeclId(ident) if isinstance(ident, str) else ident
object.__setattr__(self, 'ident', ident)
object.__setattr__(self, 'val', val)
object.__setattr__(self, 'signature', signature)
object.__setattr__(self, 'modifiers', modifiers)

def __str__(self) -> str:
modifiers = f'{self.modifiers} ' if self.modifiers else ''
signature = f' {self.signature}' if self.signature else ''
return f'{modifiers} abbrev {self.ident}{signature} := {self.val}'


@final
@dataclass(frozen=True)
class Inductive(Declaration):
ident: DeclId
signature: Signature | None
ctors: tuple[Ctor, ...]
deriving: tuple[str, ...]
modifiers: Modifiers | None

def __init__(
self,
ident: str | DeclId,
signature: Signature | None = None,
ctors: Iterable[Ctor] | None = None,
deriving: Iterable[str] | None = None,
modifiers: Modifiers | None = None,
):
ident = DeclId(ident) if isinstance(ident, str) else ident
ctors = tuple(ctors) if ctors is not None else ()
deriving = tuple(deriving) if deriving is not None else ()
object.__setattr__(self, 'ident', ident)
object.__setattr__(self, 'signature', signature)
object.__setattr__(self, 'ctors', ctors)
object.__setattr__(self, 'deriving', deriving)
object.__setattr__(self, 'modifiers', modifiers)

def __str__(self) -> str:
modifiers = f'{self.modifiers} ' if self.modifiers else ''
signature = f' {self.signature}' if self.signature else ''
where = ' where' if self.ctors else ''
deriving = ', '.join(self.deriving)

lines = []
lines.append(f'{modifiers}inductive {self.ident}{signature}{where}')
for ctor in self.ctors:
lines.append(f' | {ctor}')
if deriving:
lines.append(f' deriving {deriving}')
return '\n'.join(lines)


@final
@dataclass(frozen=True)
class DeclId:
val: str
uvars: tuple[str, ...]

def __init__(self, val: str, uvars: Iterable[str] | None = None):
uvars = tuple(uvars) if uvars is not None else ()
object.__setattr__(self, 'val', val)
object.__setattr__(self, 'uvars', uvars)

def __str__(self) -> str:
uvars = ', '.join(self.uvars)
uvars = '.{' + uvars + '}' if uvars else ''
return f'{self.val}{uvars}'


@final
@dataclass(frozen=True)
class Ctor:
ident: str
signature: Signature | None = None
modifiers: Modifiers | None = None

def __str__(self) -> str:
modifiers = f'{self.modifiers} ' if self.modifiers else ''
signature = f' {self.signature}' if self.signature else ''
return f'{modifiers}{self.ident}{signature}'


@final
@dataclass(frozen=True)
class Signature:
binders: tuple[Binder, ...]
ty: Term | None

def __init__(self, binders: Iterable[Binder] | None = None, ty: Term | None = None):
binders = tuple(binders) if binders is not None else ()
object.__setattr__(self, 'binders', binders)
object.__setattr__(self, 'ty', ty)

def __str__(self) -> str:
binders = ' '.join(str(binder) for binder in self.binders)
sep = ' ' if self.binders else ''
ty = f'{sep}: {self.ty}' if self.ty else ''
return f'{binders}{ty}'


class Binder(ABC): ...


class BracketBinder(Binder, ABC): ...


@final
@dataclass(frozen=True)
class ExplBinder(BracketBinder):
idents: tuple[str, ...]
ty: Term | None

def __init__(self, idents: Iterable[str], ty: Term | None = None):
object.__setattr__(self, 'idents', tuple(idents))
object.__setattr__(self, 'ty', ty)

def __str__(self) -> str:
idents = ' '.join(self.idents)
ty = '' if self.ty is None else f' : {self.ty}'
return f'({idents}{ty})'


@final
@dataclass(frozen=True)
class ImplBinder(BracketBinder):
idents: tuple[str, ...]
ty: Term | None
strict: bool

def __init__(self, idents: Iterable[str], ty: Term | None = None, *, strict: bool | None = None):
object.__setattr__(self, 'idents', tuple(idents))
object.__setattr__(self, 'ty', ty)
object.__setattr__(self, 'strict', bool(strict))

def __str__(self) -> str:
ldelim, rdelim = ['⦃', '⦄'] if self.strict else ['{', '}']
idents = ' '.join(self.idents)
ty = '' if self.ty is None else f' : {self.ty}'
return f'{ldelim}{idents}{ty}{rdelim}'


@final
@dataclass(frozen=True)
class InstBinder(BracketBinder):
ty: Term
ident: str | None

def __init__(self, ty: Term, ident: str | None = None):
object.__setattr__(self, 'ty', ty)
object.__setattr__(self, 'ident', ident)

def __str__(self) -> str:
ident = f'{self.ident} : ' if self.ident else ''
return f'[{ident}{self.ty}]'


@final
@dataclass(frozen=True)
class Term:
term: str # TODO: refine

def __str__(self) -> str:
return self.term


@final
@dataclass(frozen=True)
class Modifiers:
attrs: tuple[Attr, ...]
visibility: Visibility | None
noncomputable: bool
unsafe: bool
totality: Totality | None

def __init__(
self,
*,
attrs: Iterable[str | Attr] | None = None,
visibility: str | Visibility | None = None,
noncomputable: bool | None = None,
unsafe: bool | None = None,
totality: str | Totality | None = None,
):
attrs = tuple(Attr(attr) if isinstance(attr, str) else attr for attr in attrs) if attrs is not None else ()
visibility = Visibility(visibility) if isinstance(visibility, str) else visibility
noncomputable = bool(noncomputable)
unsafe = bool(unsafe)
totality = Totality(totality) if isinstance(totality, str) else totality
object.__setattr__(self, 'attrs', attrs)
object.__setattr__(self, 'visibility', visibility)
object.__setattr__(self, 'noncomputable', noncomputable)
object.__setattr__(self, 'unsafe', unsafe)
object.__setattr__(self, 'totality', totality)

def __str__(self) -> str:
chunks = []
if self.attrs:
attrs = ', '.join(str(attr) for attr in self.attrs)
chunks.append(f'@[{attrs}]')
if self.visibility:
chunks.append(self.visibility.value)
if self.noncomputable:
chunks.append('noncomputable')
if self.unsafe:
chunks.append('unsafe')
if self.totality:
chunks.append(self.totality.value)
return ' '.join(chunks)


@final
@dataclass(frozen=True)
class Attr:
attr: str
kind: AttrKind | None

def __init__(self, attr: str, kind: AttrKind | None = None):
object.__setattr__(self, 'attr', attr)
object.__setattr__(self, 'kind', kind)

def __str__(self) -> str:
if self.kind:
return f'{self.kind.value} {self.attr}'
return self.attr


class AttrKind(Enum):
SCOPED = 'scoped'
LOCAL = 'local'


class Visibility(Enum):
PRIVATE = 'private'
PROTECTED = 'protected'


class Totality(Enum):
PARTIAL = 'partial'
NONREC = 'nonrec'
Loading

0 comments on commit 8c40191

Please sign in to comment.