-
Notifications
You must be signed in to change notification settings - Fork 152
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Generate Lean 4 type definitions from a KORE definition (#4717)
* Add a prelude for basic primitive sorts * Generate and `inductive` for each constructed sort * Generate an `abbrev` for each collection sort
- Loading branch information
1 parent
06ffc88
commit 8c40191
Showing
5 changed files
with
453 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
abbrev SortBool : Type := Int | ||
abbrev SortBytes: Type := ByteArray | ||
abbrev SortId : Type := String | ||
abbrev SortInt : Type := Int | ||
abbrev SortString : Type := String | ||
abbrev SortStringBuffer : Type := String | ||
|
||
abbrev ListHook (E : Type) : Type := List E | ||
abbrev MapHook (K : Type) (V : Type) : Type := List (K × V) | ||
abbrev SetHook (E : Type) : Type := List E |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from typing import TYPE_CHECKING | ||
|
||
from ..kore.internal import CollectionKind | ||
from ..kore.syntax import SortApp | ||
from ..utils import check_type | ||
from .model import Abbrev, Ctor, ExplBinder, Inductive, Module, Signature, Term | ||
|
||
if TYPE_CHECKING: | ||
from ..kore.internal import KoreDefn | ||
from .model import Command | ||
|
||
|
||
@dataclass(frozen=True) | ||
class K2Lean4: | ||
defn: KoreDefn | ||
|
||
def sort_module(self) -> Module: | ||
commands = [] | ||
commands += self._inductives() | ||
commands += self._collections() | ||
return Module(commands=commands) | ||
|
||
def _inductives(self) -> list[Command]: | ||
def is_inductive(sort: str) -> bool: | ||
decl = self.defn.sorts[sort] | ||
return not decl.hooked and 'hasDomainValues' not in decl.attrs_by_key | ||
|
||
sorts = sorted(sort for sort in self.defn.sorts if is_inductive(sort)) | ||
return [self._inductive(sort) for sort in sorts] | ||
|
||
def _inductive(self, sort: str) -> Inductive: | ||
subsorts = sorted(self.defn.subsorts.get(sort, ())) | ||
symbols = sorted(self.defn.constructors.get(sort, ())) | ||
ctors: list[Ctor] = [] | ||
ctors.extend(self._inj_ctor(sort, subsort) for subsort in subsorts) | ||
ctors.extend(self._symbol_ctor(sort, symbol) for symbol in symbols) | ||
return Inductive(sort, Signature((), Term('Type')), ctors=ctors) | ||
|
||
def _inj_ctor(self, sort: str, subsort: str) -> Ctor: | ||
return Ctor(f'inj_{subsort}', Signature((ExplBinder(('x',), Term(subsort)),), Term(sort))) | ||
|
||
def _symbol_ctor(self, sort: str, symbol: str) -> Ctor: | ||
param_sorts = ( | ||
check_type(sort, SortApp).name for sort in self.defn.symbols[symbol].param_sorts | ||
) # TODO eliminate check_type | ||
binders = tuple(ExplBinder((f'x{i}',), Term(sort)) for i, sort in enumerate(param_sorts)) | ||
symbol = symbol.replace('-', '_') | ||
return Ctor(symbol, Signature(binders, Term(sort))) | ||
|
||
def _collections(self) -> list[Command]: | ||
return [self._collection(sort) for sort in sorted(self.defn.collections)] | ||
|
||
def _collection(self, sort: str) -> Abbrev: | ||
coll = self.defn.collections[sort] | ||
elem = self.defn.symbols[coll.element] | ||
sorts = ' '.join(check_type(sort, SortApp).name for sort in elem.param_sorts) # TODO eliminate check_type | ||
assert sorts | ||
match coll.kind: | ||
case CollectionKind.LIST: | ||
val = Term(f'ListHook {sorts}') | ||
case CollectionKind.MAP: | ||
val = Term(f'MapHook {sorts}') | ||
case CollectionKind.SET: | ||
val = Term(f'SetHook {sorts}') | ||
return Abbrev(sort, val, Signature((), Term('Type'))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,306 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC | ||
from dataclasses import dataclass | ||
from enum import Enum | ||
from typing import TYPE_CHECKING, final | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Iterable | ||
|
||
|
||
def indent(text: str, n: int) -> str: | ||
indent = n * ' ' | ||
res = [] | ||
for line in text.splitlines(): | ||
res.append(f'{indent}{line}' if line else '') | ||
return '\n'.join(res) | ||
|
||
|
||
@final | ||
@dataclass(frozen=True) | ||
class Module: | ||
commands: tuple[Command, ...] | ||
|
||
def __init__(self, commands: Iterable[Command] | None = None): | ||
commands = tuple(commands) if commands is not None else () | ||
object.__setattr__(self, 'commands', commands) | ||
|
||
def __str__(self) -> str: | ||
return '\n\n'.join(str(command) for command in self.commands) | ||
|
||
|
||
class Command(ABC): ... | ||
|
||
|
||
class Mutual(Command): | ||
commands: tuple[Command, ...] | ||
|
||
def __init__(self, commands: Iterable[Command] | None = None): | ||
commands = tuple(commands) if commands is not None else () | ||
object.__setattr__(self, 'commands', commands) | ||
|
||
def __str__(self) -> str: | ||
commands = '\n\n'.join(indent(str(command), 2) for command in self.commands) | ||
return f'mutual\n{commands}\nend' | ||
|
||
|
||
class Declaration(Command, ABC): | ||
modifiers: Modifiers | None | ||
|
||
|
||
@final | ||
@dataclass | ||
class Abbrev(Declaration): | ||
ident: DeclId | ||
val: Term # declVal | ||
signature: Signature | None | ||
modifiers: Modifiers | None | ||
|
||
def __init__( | ||
self, | ||
ident: str | DeclId, | ||
val: Term, | ||
signature: Signature | None = None, | ||
modifiers: Modifiers | None = None, | ||
): | ||
ident = DeclId(ident) if isinstance(ident, str) else ident | ||
object.__setattr__(self, 'ident', ident) | ||
object.__setattr__(self, 'val', val) | ||
object.__setattr__(self, 'signature', signature) | ||
object.__setattr__(self, 'modifiers', modifiers) | ||
|
||
def __str__(self) -> str: | ||
modifiers = f'{self.modifiers} ' if self.modifiers else '' | ||
signature = f' {self.signature}' if self.signature else '' | ||
return f'{modifiers} abbrev {self.ident}{signature} := {self.val}' | ||
|
||
|
||
@final | ||
@dataclass(frozen=True) | ||
class Inductive(Declaration): | ||
ident: DeclId | ||
signature: Signature | None | ||
ctors: tuple[Ctor, ...] | ||
deriving: tuple[str, ...] | ||
modifiers: Modifiers | None | ||
|
||
def __init__( | ||
self, | ||
ident: str | DeclId, | ||
signature: Signature | None = None, | ||
ctors: Iterable[Ctor] | None = None, | ||
deriving: Iterable[str] | None = None, | ||
modifiers: Modifiers | None = None, | ||
): | ||
ident = DeclId(ident) if isinstance(ident, str) else ident | ||
ctors = tuple(ctors) if ctors is not None else () | ||
deriving = tuple(deriving) if deriving is not None else () | ||
object.__setattr__(self, 'ident', ident) | ||
object.__setattr__(self, 'signature', signature) | ||
object.__setattr__(self, 'ctors', ctors) | ||
object.__setattr__(self, 'deriving', deriving) | ||
object.__setattr__(self, 'modifiers', modifiers) | ||
|
||
def __str__(self) -> str: | ||
modifiers = f'{self.modifiers} ' if self.modifiers else '' | ||
signature = f' {self.signature}' if self.signature else '' | ||
where = ' where' if self.ctors else '' | ||
deriving = ', '.join(self.deriving) | ||
|
||
lines = [] | ||
lines.append(f'{modifiers}inductive {self.ident}{signature}{where}') | ||
for ctor in self.ctors: | ||
lines.append(f' | {ctor}') | ||
if deriving: | ||
lines.append(f' deriving {deriving}') | ||
return '\n'.join(lines) | ||
|
||
|
||
@final | ||
@dataclass(frozen=True) | ||
class DeclId: | ||
val: str | ||
uvars: tuple[str, ...] | ||
|
||
def __init__(self, val: str, uvars: Iterable[str] | None = None): | ||
uvars = tuple(uvars) if uvars is not None else () | ||
object.__setattr__(self, 'val', val) | ||
object.__setattr__(self, 'uvars', uvars) | ||
|
||
def __str__(self) -> str: | ||
uvars = ', '.join(self.uvars) | ||
uvars = '.{' + uvars + '}' if uvars else '' | ||
return f'{self.val}{uvars}' | ||
|
||
|
||
@final | ||
@dataclass(frozen=True) | ||
class Ctor: | ||
ident: str | ||
signature: Signature | None = None | ||
modifiers: Modifiers | None = None | ||
|
||
def __str__(self) -> str: | ||
modifiers = f'{self.modifiers} ' if self.modifiers else '' | ||
signature = f' {self.signature}' if self.signature else '' | ||
return f'{modifiers}{self.ident}{signature}' | ||
|
||
|
||
@final | ||
@dataclass(frozen=True) | ||
class Signature: | ||
binders: tuple[Binder, ...] | ||
ty: Term | None | ||
|
||
def __init__(self, binders: Iterable[Binder] | None = None, ty: Term | None = None): | ||
binders = tuple(binders) if binders is not None else () | ||
object.__setattr__(self, 'binders', binders) | ||
object.__setattr__(self, 'ty', ty) | ||
|
||
def __str__(self) -> str: | ||
binders = ' '.join(str(binder) for binder in self.binders) | ||
sep = ' ' if self.binders else '' | ||
ty = f'{sep}: {self.ty}' if self.ty else '' | ||
return f'{binders}{ty}' | ||
|
||
|
||
class Binder(ABC): ... | ||
|
||
|
||
class BracketBinder(Binder, ABC): ... | ||
|
||
|
||
@final | ||
@dataclass(frozen=True) | ||
class ExplBinder(BracketBinder): | ||
idents: tuple[str, ...] | ||
ty: Term | None | ||
|
||
def __init__(self, idents: Iterable[str], ty: Term | None = None): | ||
object.__setattr__(self, 'idents', tuple(idents)) | ||
object.__setattr__(self, 'ty', ty) | ||
|
||
def __str__(self) -> str: | ||
idents = ' '.join(self.idents) | ||
ty = '' if self.ty is None else f' : {self.ty}' | ||
return f'({idents}{ty})' | ||
|
||
|
||
@final | ||
@dataclass(frozen=True) | ||
class ImplBinder(BracketBinder): | ||
idents: tuple[str, ...] | ||
ty: Term | None | ||
strict: bool | ||
|
||
def __init__(self, idents: Iterable[str], ty: Term | None = None, *, strict: bool | None = None): | ||
object.__setattr__(self, 'idents', tuple(idents)) | ||
object.__setattr__(self, 'ty', ty) | ||
object.__setattr__(self, 'strict', bool(strict)) | ||
|
||
def __str__(self) -> str: | ||
ldelim, rdelim = ['⦃', '⦄'] if self.strict else ['{', '}'] | ||
idents = ' '.join(self.idents) | ||
ty = '' if self.ty is None else f' : {self.ty}' | ||
return f'{ldelim}{idents}{ty}{rdelim}' | ||
|
||
|
||
@final | ||
@dataclass(frozen=True) | ||
class InstBinder(BracketBinder): | ||
ty: Term | ||
ident: str | None | ||
|
||
def __init__(self, ty: Term, ident: str | None = None): | ||
object.__setattr__(self, 'ty', ty) | ||
object.__setattr__(self, 'ident', ident) | ||
|
||
def __str__(self) -> str: | ||
ident = f'{self.ident} : ' if self.ident else '' | ||
return f'[{ident}{self.ty}]' | ||
|
||
|
||
@final | ||
@dataclass(frozen=True) | ||
class Term: | ||
term: str # TODO: refine | ||
|
||
def __str__(self) -> str: | ||
return self.term | ||
|
||
|
||
@final | ||
@dataclass(frozen=True) | ||
class Modifiers: | ||
attrs: tuple[Attr, ...] | ||
visibility: Visibility | None | ||
noncomputable: bool | ||
unsafe: bool | ||
totality: Totality | None | ||
|
||
def __init__( | ||
self, | ||
*, | ||
attrs: Iterable[str | Attr] | None = None, | ||
visibility: str | Visibility | None = None, | ||
noncomputable: bool | None = None, | ||
unsafe: bool | None = None, | ||
totality: str | Totality | None = None, | ||
): | ||
attrs = tuple(Attr(attr) if isinstance(attr, str) else attr for attr in attrs) if attrs is not None else () | ||
visibility = Visibility(visibility) if isinstance(visibility, str) else visibility | ||
noncomputable = bool(noncomputable) | ||
unsafe = bool(unsafe) | ||
totality = Totality(totality) if isinstance(totality, str) else totality | ||
object.__setattr__(self, 'attrs', attrs) | ||
object.__setattr__(self, 'visibility', visibility) | ||
object.__setattr__(self, 'noncomputable', noncomputable) | ||
object.__setattr__(self, 'unsafe', unsafe) | ||
object.__setattr__(self, 'totality', totality) | ||
|
||
def __str__(self) -> str: | ||
chunks = [] | ||
if self.attrs: | ||
attrs = ', '.join(str(attr) for attr in self.attrs) | ||
chunks.append(f'@[{attrs}]') | ||
if self.visibility: | ||
chunks.append(self.visibility.value) | ||
if self.noncomputable: | ||
chunks.append('noncomputable') | ||
if self.unsafe: | ||
chunks.append('unsafe') | ||
if self.totality: | ||
chunks.append(self.totality.value) | ||
return ' '.join(chunks) | ||
|
||
|
||
@final | ||
@dataclass(frozen=True) | ||
class Attr: | ||
attr: str | ||
kind: AttrKind | None | ||
|
||
def __init__(self, attr: str, kind: AttrKind | None = None): | ||
object.__setattr__(self, 'attr', attr) | ||
object.__setattr__(self, 'kind', kind) | ||
|
||
def __str__(self) -> str: | ||
if self.kind: | ||
return f'{self.kind.value} {self.attr}' | ||
return self.attr | ||
|
||
|
||
class AttrKind(Enum): | ||
SCOPED = 'scoped' | ||
LOCAL = 'local' | ||
|
||
|
||
class Visibility(Enum): | ||
PRIVATE = 'private' | ||
PROTECTED = 'protected' | ||
|
||
|
||
class Totality(Enum): | ||
PARTIAL = 'partial' | ||
NONREC = 'nonrec' |
Oops, something went wrong.