diff --git a/src/lean_dojo/data_extraction/lean.py b/src/lean_dojo/data_extraction/lean.py index 92a21e1..0d4001f 100644 --- a/src/lean_dojo/data_extraction/lean.py +++ b/src/lean_dojo/data_extraction/lean.py @@ -5,13 +5,12 @@ import re import os import json -import uuid import toml import time -import shutil import urllib import webbrowser import shutil +from enum import Enum from pathlib import Path from loguru import logger from functools import cache @@ -57,37 +56,44 @@ REPO_CACHE_PREFIX = "repos" -def normalize_url(url: str, repo_type: str = "github") -> str: - if repo_type == "local": - return os.path.abspath(url) # Convert to absolute path if local - return _URL_REGEX.fullmatch(url)["url"] # Remove trailing `/`. +class RepoType(Enum): + GITHUB = 0 + REMOTE = 1 # Remote but not GitHub. + LOCAL = 2 + +def normalize_url(url: str, repo_type: RepoType = RepoType.GITHUB) -> str: + if repo_type == RepoType.LOCAL: # Convert to absolute path if local. + return os.path.abspath(url) + # Remove trailing `/`. + return _URL_REGEX.fullmatch(url)["url"] -def repo_type_of_url(url: str) -> Union[str, None]: + +def get_repo_type(url: str) -> Optional[RepoType]: """Get the type of the repository. Args: url (str): The URL of the repository. Returns: - str: The type of the repository. + Optional[str]: The type of the repository (None if the repo cannot be found). """ m = _SSH_TO_HTTPS_REGEX.match(url) url = f"https://github.com/{m.group(1)}/{m.group(2)}" if m else url parsed_url = urllib.parse.urlparse(url) if parsed_url.scheme in ["http", "https"]: - # case 1 - GitHub URL + # Case 1 - GitHub URL. if "github.com" in url: if not url.startswith("https://"): logger.warning(f"{url} should start with https://") - return + return None else: - return "github" - # case 2 - remote URL - elif url_exists(url): # not check whether it is a git URL - return "remote" - # case 3 - local path + return RepoType.GITHUB + # Case 2 - remote URL. + elif url_exists(url): # Not check whether it is a git URL + return RepoType.REMOTE + # Case 3 - local path elif is_git_repo(Path(parsed_url.path)): - return "local" + return RepoType.LOCAL logger.warning(f"{url} is not a valid URL") return None @@ -103,11 +109,11 @@ def _split_git_url(url: str) -> Tuple[str, str]: return user_name, repo_name -def _format_dirname(url: str, commit: str) -> str: +def _format_cache_dirname(url: str, commit: str) -> str: user_name, repo_name = _split_git_url(url) - repo_type = repo_type_of_url(url) + repo_type = get_repo_type(url) assert repo_type is not None, f"Invalid url {url}" - if repo_type == "github": + if repo_type == RepoType.GITHUB: return f"{user_name}-{repo_name}-{commit}" else: # git repo return f"gitpython-{repo_name}-{commit}" @@ -117,7 +123,7 @@ def _format_dirname(url: str, commit: str) -> str: def url_to_repo( url: str, num_retries: int = 2, - repo_type: Union[str, None] = None, + repo_type: Optional[RepoType] = None, tmp_dir: Union[Path] = None, ) -> Union[Repo, Repository]: """Convert a URL to a Repo object. @@ -125,7 +131,7 @@ def url_to_repo( Args: url (str): The URL of the repository. num_retries (int): Number of retries in case of failure. - repo_type (Optional[str]): The type of the repository. Defaults to None. + repo_type (Optional[RepoType]): The type of the repository. Defaults to None. tmp_dir (Optional[Path]): The temporary directory to clone the repo to. Defaults to None. Returns: @@ -133,16 +139,18 @@ def url_to_repo( """ url = normalize_url(url) backoff = 1 - tmp_dir = tmp_dir or os.path.join(TMP_DIR or "/tmp", str(uuid.uuid4())[:8]) - repo_type = repo_type or repo_type_of_url(url) + tmp_dir = tmp_dir or os.path.join( + TMP_DIR or "/tmp", next(tempfile._get_candidate_names()) + ) + repo_type = repo_type or get_repo_type(url) assert repo_type is not None, f"Invalid url {url}" while True: try: - if repo_type == "github": + if repo_type == RepoType.GITHUB: return GITHUB.get_repo("/".join(url.split("/")[-2:])) with working_directory(tmp_dir): repo_name = os.path.basename(url) - if repo_type == "local": + if repo_type == RepoType.LOCAL: assert is_git_repo(url), f"Local path {url} is not a git repo" shutil.copytree(url, repo_name) return Repo(repo_name) @@ -174,26 +182,22 @@ def cleanse_string(s: Union[str, Path]) -> str: def _to_commit_hash(repo: Union[Repository, Repo], label: str) -> str: """Convert a tag or branch to a commit hash.""" - # GitHub repository - if isinstance(repo, Repository): + if isinstance(repo, Repository): # GitHub repository logger.debug(f"Querying the commit hash for {repo.name} {label}") try: - commit = repo.get_commit(label).sha - except GithubException as e: + return repo.get_commit(label).sha + except GithubException as ex: raise ValueError(f"Invalid tag or branch: `{label}` for {repo.name}") - # Local or remote Git repository - elif isinstance(repo, Repo): + else: # Local or remote Git repository + assert isinstance(repo, Repo) logger.debug( f"Querying the commit hash for {repo.working_dir} repository {label}" ) try: # Resolve the label to a commit hash - commit = repo.commit(label).hexsha - except Exception as e: + return repo.commit(label).hexsha + except Exception as ex: raise ValueError(f"Error converting ref to commit hash: {e}") - else: - raise TypeError("Unsupported repository type") - return commit @dataclass(eq=True, unsafe_hash=True) @@ -493,8 +497,7 @@ class LeanGitRepo: url: str """The repo's URL. - It can be a GitHub URL that starts with https:// or git@github.com, - a local path, or any other valid Git URL. + It can be a GitHub URL that starts with https:// or git@github.com, a local path, or any other valid Git URL. """ commit: str @@ -512,23 +515,23 @@ class LeanGitRepo: """Required Lean version. """ - repo_type: str = field(init=False, repr=False) - """Type of the repo. It can be ``github``, ``local`` or ``remote``. + repo_type: RepoType = field(init=False, repr=False) + """Type of the repo. It can be ``GITHUB``, ``LOCAL`` or ``REMOTE``. """ def __post_init__(self) -> None: - repo_type = repo_type_of_url(self.url) + repo_type = get_repo_type(self.url) if repo_type is None: raise ValueError(f"{self.url} is not a valid URL") object.__setattr__(self, "repo_type", repo_type) object.__setattr__(self, "url", normalize_url(self.url, repo_type=repo_type)) # set repo and commit - if repo_type == "github": + if repo_type == RepoType.GITHUB: repo = url_to_repo(self.url, repo_type=repo_type) else: # get repo from cache rel_cache_dir = lambda url, commit: Path( - f"{REPO_CACHE_PREFIX}/{_format_dirname(url, commit)}/{self.name}" + f"{REPO_CACHE_PREFIX}/{_format_cache_dirname(url, commit)}/{self.name}" ) cache_repo_dir = repo_cache.get(rel_cache_dir(self.url, self.commit)) if cache_repo_dir is None: @@ -583,18 +586,17 @@ def is_lean4(self) -> bool: def commit_url(self) -> str: return f"{self.url}/tree/{self.commit}" - @property def format_dirname(self) -> Path: """Return the formatted cache directory name""" assert is_commit_hash(self.commit), f"Invalid commit hash: {self.commit}" - return Path(_format_dirname(self.url, self.commit)) + return Path(_format_cache_dirname(self.url, self.commit)) def show(self) -> None: """Show the repo in the default browser.""" webbrowser.open(self.commit_url) def exists(self) -> bool: - if self.repo_type != "github": + if self.repo_type != RepoType.GITHUB: repo = self.repo # git repo try: repo.commit(self.commit) @@ -746,7 +748,7 @@ def _get_config_url(self, filename: str) -> str: def get_config(self, filename: str, num_retries: int = 2) -> Dict[str, Any]: """Return the repo's files.""" - if self.repo_type == "github": + if self.repo_type == RepoType.GITHUB: config_url = self._get_config_url(filename) content = read_url(config_url, num_retries) else: diff --git a/src/lean_dojo/data_extraction/trace.py b/src/lean_dojo/data_extraction/trace.py index c0f3c4b..5a2d90d 100644 --- a/src/lean_dojo/data_extraction/trace.py +++ b/src/lean_dojo/data_extraction/trace.py @@ -204,7 +204,7 @@ def get_traced_repo_path(repo: LeanGitRepo, build_deps: bool = True) -> Path: Returns: Path: The path of the traced repo in the cache, e.g. :file:`/home/kaiyu/.cache/lean_dojo/leanprover-community-mathlib-2196ab363eb097c008d4497125e0dde23fb36db2` """ - rel_cache_dir = repo.format_dirname / repo.name + rel_cache_dir = repo.get_cache_dirname() / repo.name path = cache.get(rel_cache_dir) if path is None: logger.info(f"Tracing {repo}") diff --git a/tests/data_extraction/test_cache.py b/tests/data_extraction/test_cache.py index 4741239..e80d0f8 100644 --- a/tests/data_extraction/test_cache.py +++ b/tests/data_extraction/test_cache.py @@ -1,17 +1,15 @@ # test for cache manager from git import Repo -from lean_dojo.utils import working_directory from pathlib import Path +from lean_dojo.utils import working_directory from lean_dojo.data_extraction.cache import cache -def test_repo_cache(lean4_example_url, remote_example_url, example_commit_hash): +def test_local_repo_cache(lean4_example_url, example_commit_hash): # Note: The `git.Repo` requires the local repo to be cloned in a directory # all cached repos are stored in CACHE_DIR/repos prefix = "repos" repo_name = "lean4-example" - - # test local repo cache with working_directory() as tmp_dir: repo = Repo.clone_from(lean4_example_url, repo_name) repo.git.checkout(example_commit_hash) @@ -24,7 +22,10 @@ def test_repo_cache(lean4_example_url, remote_example_url, example_commit_hash): repo_cache_dir = cache.get(rel_cache_dir) assert repo_cache_dir is not None - # test remote repo cache + +def test_remote_repo_cache(remote_example_url): + prefix = "repos" + repo_name = "lean4-example" with working_directory() as tmp_dir: repo = Repo.clone_from(remote_example_url, repo_name) tmp_remote_dir = tmp_dir / repo_name diff --git a/tests/data_extraction/test_lean_repo.py b/tests/data_extraction/test_lean_repo.py index 04abb7f..818182c 100644 --- a/tests/data_extraction/test_lean_repo.py +++ b/tests/data_extraction/test_lean_repo.py @@ -1,17 +1,17 @@ -# test for the class `LeanGitRepo` -from lean_dojo import LeanGitRepo +# Tests for the class `LeanGitRepo` from git import Repo +from lean_dojo import LeanGitRepo from github.Repository import Repository -from lean_dojo.utils import working_directory from lean_dojo.data_extraction.lean import ( _to_commit_hash, - repo_type_of_url, + get_repo_type, url_to_repo, get_latest_commit, is_commit_hash, GITHUB, - LEAN4_REPO, + RepoType, ) +from lean_dojo.utils import working_directory def test_github_type(lean4_example_url, example_commit_hash): @@ -21,11 +21,11 @@ def test_github_type(lean4_example_url, example_commit_hash): gh_cm_hash = get_latest_commit(lean4_example_url) assert is_commit_hash(gh_cm_hash) - ## url_to_repo & repo_type_of_url + ## url_to_repo & get_repo_type github_repo = url_to_repo(lean4_example_url) - assert repo_type_of_url(lean4_example_url) == "github" - assert repo_type_of_url("git@github.com:yangky11/lean4-example.git") == "github" - assert repo_type_of_url("git@github.com:yangky11/lean4-example") == "github" + assert get_repo_type(lean4_example_url) == RepoType.GITHUB + assert get_repo_type("git@github.com:yangky11/lean4-example.git") == RepoType.GITHUB + assert get_repo_type("git@github.com:yangky11/lean4-example") == RepoType.GITHUB assert isinstance(github_repo, Repository) assert github_repo.name == repo_name @@ -46,7 +46,7 @@ def test_github_type(lean4_example_url, example_commit_hash): LeanGitRepo(lean4_example_url, "main") # init with branch repo = LeanGitRepo(lean4_example_url, example_commit_hash) assert repo.url == lean4_example_url - assert repo.repo_type == "github" + assert repo.repo_type == RepoType.GITHUB assert repo.commit == example_commit_hash assert repo.exists() assert repo.name == repo_name @@ -54,14 +54,14 @@ def test_github_type(lean4_example_url, example_commit_hash): assert repo.commit_url == f"{lean4_example_url}/tree/{example_commit_hash}" # cache name assert isinstance(repo.repo, Repository) - assert str(repo.format_dirname) == f"yangky11-{repo_name}-{example_commit_hash}" + assert str(repo.format_dirname()) == f"yangky11-{repo_name}-{example_commit_hash}" def test_remote_type(remote_example_url, example_commit_hash): repo_name = "lean4-example" remote_repo = url_to_repo(remote_example_url) - assert repo_type_of_url(remote_example_url) == "remote" + assert get_repo_type(remote_example_url) == RepoType.REMOTE assert isinstance(remote_repo, Repo) re_cm_hash = get_latest_commit(remote_example_url) assert re_cm_hash == get_latest_commit(str(remote_repo.working_dir)) @@ -75,7 +75,7 @@ def test_remote_type(remote_example_url, example_commit_hash): LeanGitRepo(remote_example_url, "main") repo = LeanGitRepo(remote_example_url, example_commit_hash) assert repo.url == remote_example_url - assert repo.repo_type == "remote" + assert repo.repo_type == RepoType.REMOTE assert repo.commit == example_commit_hash assert repo.exists() assert repo.name == repo_name @@ -83,7 +83,7 @@ def test_remote_type(remote_example_url, example_commit_hash): assert repo.commit_url == f"{remote_example_url}/tree/{example_commit_hash}" # cache name assert isinstance(repo.repo, Repo) - assert str(repo.format_dirname) == f"gitpython-{repo_name}-{example_commit_hash}" + assert str(repo.format_dirname()) == f"gitpython-{repo_name}-{example_commit_hash}" def test_local_type(lean4_example_url, example_commit_hash): @@ -98,9 +98,9 @@ def test_local_type(lean4_example_url, example_commit_hash): local_url = str((tmp_dir / repo_name).absolute()) assert get_latest_commit(local_url) == gh_cm_hash - ## url_to_repo & repo_type_of_url - local_repo = url_to_repo(local_url, repo_type="local") - assert repo_type_of_url(local_url) == "local" + ## url_to_repo & get_repo_type + local_repo = url_to_repo(local_url, repo_type=RepoType.LOCAL) + assert get_repo_type(local_url) == RepoType.LOCAL assert isinstance(local_repo, Repo) assert ( local_repo.working_dir != local_url @@ -123,7 +123,7 @@ def test_local_type(lean4_example_url, example_commit_hash): repo = LeanGitRepo(local_url, example_commit_hash) repo2 = LeanGitRepo.from_path(local_url) # test from_path assert repo.url == local_url == repo2.url - assert repo.repo_type == "local" == repo2.repo_type + assert repo.repo_type == RepoType.LOCAL == repo2.repo_type assert repo.commit == example_commit_hash and repo2.commit == gh_cm_hash assert repo.exists() and repo2.exists() assert repo.name == repo_name == repo2.name @@ -131,6 +131,6 @@ def test_local_type(lean4_example_url, example_commit_hash): # cache name assert isinstance(repo.repo, Repo) and isinstance(repo2.repo, Repo) assert ( - str(repo.format_dirname) == f"gitpython-{repo_name}-{example_commit_hash}" + str(repo.format_dirname()) == f"gitpython-{repo_name}-{example_commit_hash}" ) - assert str(repo2.format_dirname) == f"gitpython-{repo_name}-{gh_cm_hash}" + assert str(repo2.format_dirname()) == f"gitpython-{repo_name}-{gh_cm_hash}" diff --git a/tests/data_extraction/test_trace.py b/tests/data_extraction/test_trace.py index 24df62f..04bef5f 100644 --- a/tests/data_extraction/test_trace.py +++ b/tests/data_extraction/test_trace.py @@ -2,25 +2,25 @@ from lean_dojo import * from lean_dojo.data_extraction.cache import cache from lean_dojo.utils import working_directory -from lean_dojo.data_extraction.lean import url_to_repo +from lean_dojo.data_extraction.lean import RepoType from git import Repo def test_github_trace(lean4_example_url): # github github_repo = LeanGitRepo(lean4_example_url, "main") - assert github_repo.repo_type == "github" + assert github_repo.repo_type == RepoType.GITHUB trace_repo = trace(github_repo) - path = cache.get(github_repo.format_dirname / github_repo.name) + path = cache.get(github_repo.format_dirname() / github_repo.name) assert path is not None def test_remote_trace(remote_example_url): # remote remote_repo = LeanGitRepo(remote_example_url, "main") - assert remote_repo.repo_type == "remote" + assert remote_repo.repo_type == RepoType.REMOTE trace_repo = trace(remote_repo) - path = cache.get(remote_repo.format_dirname / remote_repo.name) + path = cache.get(remote_repo.format_dirname() / remote_repo.name) assert path is not None @@ -33,9 +33,9 @@ def test_local_trace(lean4_example_url): local_url = str((tmp_dir / "lean4-example").absolute()) local_repo = LeanGitRepo(local_dir, "main") assert local_repo.url == local_url - assert local_repo.repo_type == "local" + assert local_repo.repo_type == RepoType.LOCAL trace_repo = trace(local_repo) - path = cache.get(local_repo.format_dirname / local_repo.name) + path = cache.get(local_repo.format_dirname() / local_repo.name) assert path is not None diff --git a/tests/interaction/test_interaction.py b/tests/interaction/test_interaction.py index 3eff4d9..83b1c45 100644 --- a/tests/interaction/test_interaction.py +++ b/tests/interaction/test_interaction.py @@ -1,15 +1,17 @@ -from lean_dojo import LeanGitRepo, Dojo, ProofFinished, ProofGivenUp, Theorem -from lean_dojo.utils import working_directory -from git import Repo import os +from git import Repo +from lean_dojo.utils import working_directory +from lean_dojo.data_extraction.lean import RepoType +from lean_dojo import LeanGitRepo, Dojo, ProofFinished, ProofGivenUp, Theorem + # Avoid using remote cache -os.environ["DISABLE_REMOTE_CACHE"] = "true" +os.environ["DISABLE_REMOTE_CACHE"] = "1" def test_github_interact(lean4_example_url): repo = LeanGitRepo(url=lean4_example_url, commit="main") - assert repo.repo_type == "github" + assert repo.repo_type == RepoType.GITHUB theorem = Theorem(repo, "Lean4Example.lean", "hello_world") # initial state dojo, state_0 = Dojo(theorem).__enter__() @@ -26,7 +28,7 @@ def test_github_interact(lean4_example_url): def test_remote_interact(remote_example_url): repo = LeanGitRepo(url=remote_example_url, commit="main") - assert repo.repo_type == "remote" + assert repo.repo_type == RepoType.REMOTE theorem = Theorem(repo, "Lean4Example.lean", "hello_world") # initial state dojo, state_0 = Dojo(theorem).__enter__() @@ -49,7 +51,7 @@ def test_local_interact(lean4_example_url): local_dir = str((tmp_dir / "lean4-example")) repo = LeanGitRepo(local_dir, commit="main") - assert repo.repo_type == "local" + assert repo.repo_type == RepoType.LOCAL theorem = Theorem(repo, "Lean4Example.lean", "hello_world") # initial state dojo, state_0 = Dojo(theorem).__enter__()