Skip to content

Commit

Permalink
Merge pull request #126 from lean-dojo/dev
Browse files Browse the repository at this point in the history
fix lean 4 ast
  • Loading branch information
Kaiyu Yang authored Jan 10, 2024
2 parents 6e6a4e7 + 9a7d31d commit 4718a10
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 43 deletions.
121 changes: 78 additions & 43 deletions src/lean_dojo/data_extraction/ast/lean4/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def get_ident(self) -> str:

@dataclass(frozen=True)
class LeanElabCommandCommandIrreducibleDefNode4(Node4):
name: str
name: Optional[str]
full_name: Optional[str] = None

@classmethod
Expand All @@ -330,15 +330,19 @@ def from_data(
start, end = None, None
children = _parse_children(node_data, lean_file)

assert isinstance(children[0], CommandDeclmodifiersNode4)
assert (
isinstance(children[1], AtomNode4) and children[1].val == "irreducible_def"
)
declid_node = children[2]
assert isinstance(declid_node, CommandDeclidNode4)
ident_node = declid_node.children[0]
assert isinstance(ident_node, IdentNode4)
name = ident_node.val
if isinstance(children[0], CommandDeclmodifiersAntiquotNode4):
name = None
else:
assert isinstance(children[0], CommandDeclmodifiersNode4)
assert (
isinstance(children[1], AtomNode4)
and children[1].val == "irreducible_def"
)
declid_node = children[2]
assert isinstance(declid_node, CommandDeclidNode4)
ident_node = declid_node.children[0]
assert isinstance(ident_node, IdentNode4)
name = ident_node.val

return cls(lean_file, start, end, children, name)

Expand Down Expand Up @@ -567,7 +571,7 @@ def from_data(

@dataclass(frozen=True)
class CommandInductiveNode4(Node4):
name: str
name: Optional[str]

@classmethod
def from_data(
Expand All @@ -578,11 +582,15 @@ def from_data(
children = _parse_children(node_data, lean_file)

assert isinstance(children[0], AtomNode4) and children[0].val == "inductive"
assert isinstance(children[1], CommandDeclidNode4)
decl_id_node = children[1]
ident_node = decl_id_node.children[0]
assert isinstance(ident_node, IdentNode4)
name = ident_node.val

if isinstance(children[1], CommandDeclidAntiquotNode4):
name = None
else:
assert isinstance(children[1], CommandDeclidNode4)
decl_id_node = children[1]
ident_node = decl_id_node.children[0]
assert isinstance(ident_node, IdentNode4)
name = ident_node.val

return cls(lean_file, start, end, children, name)

Expand All @@ -607,11 +615,15 @@ def from_data(
isinstance(children[0].children[1], AtomNode4)
and children[0].children[1].val == "inductive"
)
assert isinstance(children[1], CommandDeclidNode4)
decl_id_node = children[1]
ident_node = decl_id_node.children[0]
assert isinstance(ident_node, IdentNode4)
name = ident_node.val

if isinstance(children[1], CommandDeclidAntiquotNode4):
name = None
else:
assert isinstance(children[1], CommandDeclidNode4)
decl_id_node = children[1]
ident_node = decl_id_node.children[0]
assert isinstance(ident_node, IdentNode4)
name = ident_node.val

return cls(lean_file, start, end, children, name)

Expand Down Expand Up @@ -671,9 +683,12 @@ def from_data(

assert isinstance(children[0], CommandDeclmodifiersNode4)
assert isinstance(children[1], AtomNode4) and children[1].val == "alias"
ident_node = children[2]
assert isinstance(ident_node, IdentNode4)
name = ident_node.val
if isinstance(children[2], IdentAntiquotNode4):
name = None
else:
ident_node = children[2]
assert isinstance(ident_node, IdentNode4)
name = ident_node.val
return cls(lean_file, start, end, children, name)


Expand Down Expand Up @@ -998,26 +1013,27 @@ def from_data(
assert isinstance(ident_node, IdentAntiquotNode4)
name = ident_node.get_ident()

assert isinstance(children[2], CommandDeclsigNode4)
decl_val_node = children[3]
assert type(decl_val_node) in (
CommandDeclvalsimpleNode4,
CommandDeclvaleqnsNode4,
CommandWherestructinstNode4,
)

if isinstance(decl_val_node, CommandDeclvalsimpleNode4):
assert (
isinstance(decl_val_node.children[0], AtomNode4)
and decl_val_node.children[0].val == ":="
)
assert isinstance(decl_val_node.children[2], NullNode4)
elif isinstance(decl_val_node, CommandWherestructinstNode4):
assert (
isinstance(decl_val_node.children[0], AtomNode4)
and decl_val_node.children[0].val == "where"
if not isinstance(children[1], CommandDeclidAntiquotNode4):
assert isinstance(children[2], CommandDeclsigNode4)
decl_val_node = children[3]
assert type(decl_val_node) in (
CommandDeclvalsimpleNode4,
CommandDeclvaleqnsNode4,
CommandWherestructinstNode4,
)
assert isinstance(decl_val_node.children[2], NullNode4)

if isinstance(decl_val_node, CommandDeclvalsimpleNode4):
assert (
isinstance(decl_val_node.children[0], AtomNode4)
and decl_val_node.children[0].val == ":="
)
assert isinstance(decl_val_node.children[2], NullNode4)
elif isinstance(decl_val_node, CommandWherestructinstNode4):
assert (
isinstance(decl_val_node.children[0], AtomNode4)
and decl_val_node.children[0].val == "where"
)
assert isinstance(decl_val_node.children[2], NullNode4)

return cls(lean_file, start, end, children, name)

Expand Down Expand Up @@ -1067,6 +1083,7 @@ def from_data(
assert len(children) == 1 and type(children[0]) in (
TacticTacticseq1IndentedNode4,
TacticTacticseqbracketedNode4,
TacticTacticSeq1IndentedAntiquotNode4,
)
return cls(lean_file, start, end, children)

Expand Down Expand Up @@ -1101,6 +1118,24 @@ def get_tactic_nodes(
)


@dataclass(frozen=True)
class TacticTacticSeq1IndentedAntiquotNode4(Node4):
    """AST node accepted as the sole child of a tactic-sequence node.

    Judging by its name, this presumably represents an antiquotation that
    stands in for an indented tactic sequence (TODO confirm against the Lean 4
    AST dump). It carries no source position and contributes no tactic nodes.
    """

    @classmethod
    def from_data(
        cls, node_data: Dict[str, Any], lean_file: LeanFile
    ) -> "TacticTacticSeq1IndentedAntiquotNode4":
        """Build the node from its raw AST dictionary.

        Args:
            node_data: Raw node data; its ``"info"`` entry must be ``"none"``.
            lean_file: The file the node belongs to, forwarded to the child
                parser and stored on the node.

        Returns:
            A node whose ``start`` and ``end`` are both ``None`` and whose only
            child is a ``NullNode4``.

        Raises:
            AssertionError: If ``"info"`` is not ``"none"`` or the parsed
                children are not exactly one ``NullNode4``.
        """
        assert node_data["info"] == "none"
        # No position information is recorded for this node kind.
        start, end = None, None
        children = _parse_children(node_data, lean_file)
        assert len(children) == 1 and isinstance(children[0], NullNode4)
        return cls(lean_file, start, end, children)

    def get_tactic_nodes(
        self, atomic_only: bool = False
    ) -> Generator[Node4, None, None]:
        """Yield no tactic nodes.

        ``atomic_only`` is accepted for signature parity with the other
        tactic-sequence node classes and is ignored here.
        """
        # A bare ``return`` without any ``yield`` makes this a plain function
        # returning ``None``, which breaks callers that iterate the result.
        # ``yield from ()`` turns it into a real (empty) generator.
        yield from ()


@dataclass(frozen=True)
class TacticTacticseqbracketedNode4(Node4):
state_before: Optional[str] = None
Expand Down
1 change: 1 addition & 0 deletions src/lean_dojo/data_extraction/build_lean4_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def get_lean_version() -> str:
def check_files(packages_path: str, no_deps: bool) -> None:
"""Check if all *.lean files have been processed to produce *.ast.json and *.dep_paths files."""
cwd = Path.cwd()
packages_path = cwd / packages_path
jsons = {
p.with_suffix("").with_suffix("")
for p in cwd.glob("**/build/ir/**/*.ast.json")
Expand Down

0 comments on commit 4718a10

Please sign in to comment.