diff --git a/src/lean_dojo/data_extraction/ast/lean4/node.py b/src/lean_dojo/data_extraction/ast/lean4/node.py index 78130c16..4236d1f9 100644 --- a/src/lean_dojo/data_extraction/ast/lean4/node.py +++ b/src/lean_dojo/data_extraction/ast/lean4/node.py @@ -319,7 +319,7 @@ def get_ident(self) -> str: @dataclass(frozen=True) class LeanElabCommandCommandIrreducibleDefNode4(Node4): - name: str + name: Optional[str] full_name: Optional[str] = None @classmethod @@ -330,15 +330,19 @@ def from_data( start, end = None, None children = _parse_children(node_data, lean_file) - assert isinstance(children[0], CommandDeclmodifiersNode4) - assert ( - isinstance(children[1], AtomNode4) and children[1].val == "irreducible_def" - ) - declid_node = children[2] - assert isinstance(declid_node, CommandDeclidNode4) - ident_node = declid_node.children[0] - assert isinstance(ident_node, IdentNode4) - name = ident_node.val + if isinstance(children[0], CommandDeclmodifiersAntiquotNode4): + name = None + else: + assert isinstance(children[0], CommandDeclmodifiersNode4) + assert ( + isinstance(children[1], AtomNode4) + and children[1].val == "irreducible_def" + ) + declid_node = children[2] + assert isinstance(declid_node, CommandDeclidNode4) + ident_node = declid_node.children[0] + assert isinstance(ident_node, IdentNode4) + name = ident_node.val return cls(lean_file, start, end, children, name) @@ -567,7 +571,7 @@ def from_data( @dataclass(frozen=True) class CommandInductiveNode4(Node4): - name: str + name: Optional[str] @classmethod def from_data( @@ -578,11 +582,15 @@ def from_data( children = _parse_children(node_data, lean_file) assert isinstance(children[0], AtomNode4) and children[0].val == "inductive" - assert isinstance(children[1], CommandDeclidNode4) - decl_id_node = children[1] - ident_node = decl_id_node.children[0] - assert isinstance(ident_node, IdentNode4) - name = ident_node.val + + if isinstance(children[1], CommandDeclidAntiquotNode4): + name = None + else: + assert isinstance(children[1], CommandDeclidNode4) + decl_id_node = children[1] + ident_node = decl_id_node.children[0] + assert isinstance(ident_node, IdentNode4) + name = ident_node.val return cls(lean_file, start, end, children, name) @@ -607,11 +615,15 @@ def from_data( isinstance(children[0].children[1], AtomNode4) and children[0].children[1].val == "inductive" ) - assert isinstance(children[1], CommandDeclidNode4) - decl_id_node = children[1] - ident_node = decl_id_node.children[0] - assert isinstance(ident_node, IdentNode4) - name = ident_node.val + + if isinstance(children[1], CommandDeclidAntiquotNode4): + name = None + else: + assert isinstance(children[1], CommandDeclidNode4) + decl_id_node = children[1] + ident_node = decl_id_node.children[0] + assert isinstance(ident_node, IdentNode4) + name = ident_node.val return cls(lean_file, start, end, children, name) @@ -671,9 +683,12 @@ def from_data( assert isinstance(children[0], CommandDeclmodifiersNode4) assert isinstance(children[1], AtomNode4) and children[1].val == "alias" - ident_node = children[2] - assert isinstance(ident_node, IdentNode4) - name = ident_node.val + if isinstance(children[2], IdentAntiquotNode4): + name = None + else: + ident_node = children[2] + assert isinstance(ident_node, IdentNode4) + name = ident_node.val return cls(lean_file, start, end, children, name) @@ -998,26 +1013,27 @@ def from_data( assert isinstance(ident_node, IdentAntiquotNode4) name = ident_node.get_ident() - assert isinstance(children[2], CommandDeclsigNode4) - decl_val_node = children[3] - assert type(decl_val_node) in ( - CommandDeclvalsimpleNode4, - CommandDeclvaleqnsNode4, - CommandWherestructinstNode4, - ) - - if isinstance(decl_val_node, CommandDeclvalsimpleNode4): - assert ( - isinstance(decl_val_node.children[0], AtomNode4) - and decl_val_node.children[0].val == ":=" - ) - assert isinstance(decl_val_node.children[2], NullNode4) - elif isinstance(decl_val_node, CommandWherestructinstNode4): - assert ( - isinstance(decl_val_node.children[0], AtomNode4) - and decl_val_node.children[0].val == "where" + if not isinstance(children[1], CommandDeclidAntiquotNode4): + assert isinstance(children[2], CommandDeclsigNode4) + decl_val_node = children[3] + assert type(decl_val_node) in ( + CommandDeclvalsimpleNode4, + CommandDeclvaleqnsNode4, + CommandWherestructinstNode4, ) - assert isinstance(decl_val_node.children[2], NullNode4) + + if isinstance(decl_val_node, CommandDeclvalsimpleNode4): + assert ( + isinstance(decl_val_node.children[0], AtomNode4) + and decl_val_node.children[0].val == ":=" + ) + assert isinstance(decl_val_node.children[2], NullNode4) + elif isinstance(decl_val_node, CommandWherestructinstNode4): + assert ( + isinstance(decl_val_node.children[0], AtomNode4) + and decl_val_node.children[0].val == "where" + ) + assert isinstance(decl_val_node.children[2], NullNode4) return cls(lean_file, start, end, children, name) @@ -1067,6 +1083,7 @@ def from_data( assert len(children) == 1 and type(children[0]) in ( TacticTacticseq1IndentedNode4, TacticTacticseqbracketedNode4, + TacticTacticSeq1IndentedAntiquotNode4, ) return cls(lean_file, start, end, children) @@ -1101,6 +1118,24 @@ def get_tactic_nodes( ) +@dataclass(frozen=True) +class TacticTacticSeq1IndentedAntiquotNode4(Node4): + @classmethod + def from_data( + cls, node_data: Dict[str, Any], lean_file: LeanFile + ) -> "TacticTacticSeq1IndentedAntiquotNode4": + assert node_data["info"] == "none" + start, end = None, None + children = _parse_children(node_data, lean_file) + assert len(children) == 1 and isinstance(children[0], NullNode4) + return cls(lean_file, start, end, children) + + def get_tactic_nodes( + self, atomic_only: bool = False + ) -> Generator[Node4, None, None]: + return + + @dataclass(frozen=True) class TacticTacticseqbracketedNode4(Node4): state_before: Optional[str] = None diff --git a/src/lean_dojo/data_extraction/build_lean4_repo.py b/src/lean_dojo/data_extraction/build_lean4_repo.py index 1e11a732..b0468cf4 100644 --- a/src/lean_dojo/data_extraction/build_lean4_repo.py +++ b/src/lean_dojo/data_extraction/build_lean4_repo.py @@ -107,6 +107,7 @@ def get_lean_version() -> str: def check_files(packages_path: str, no_deps: bool) -> None: """Check if all *.lean files have been processed to produce *.ast.json and *.dep_paths files.""" cwd = Path.cwd() + packages_path = cwd / packages_path jsons = { p.with_suffix("").with_suffix("") for p in cwd.glob("**/build/ir/**/*.ast.json")