Skip to content

Commit

Permalink
minor edits
Browse files Browse the repository at this point in the history
  • Loading branch information
yangky11 committed Aug 5, 2024
1 parent 897c6f1 commit 592f9b5
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 87 deletions.
96 changes: 49 additions & 47 deletions src/lean_dojo/data_extraction/lean.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
import re
import os
import json
import uuid
import toml
import time
import shutil
import urllib
import webbrowser
import shutil
from enum import Enum
from pathlib import Path
from loguru import logger
from functools import cache
Expand Down Expand Up @@ -57,37 +56,44 @@
REPO_CACHE_PREFIX = "repos"


def normalize_url(url: str, repo_type: str = "github") -> str:
if repo_type == "local":
return os.path.abspath(url) # Convert to absolute path if local
return _URL_REGEX.fullmatch(url)["url"] # Remove trailing `/`.
class RepoType(Enum):
GITHUB = 0
REMOTE = 1 # Remote but not GitHub.
LOCAL = 2


def normalize_url(url: str, repo_type: RepoType = RepoType.GITHUB) -> str:
if repo_type == RepoType.LOCAL: # Convert to absolute path if local.
return os.path.abspath(url)
# Remove trailing `/`.
return _URL_REGEX.fullmatch(url)["url"]

def repo_type_of_url(url: str) -> Union[str, None]:

def get_repo_type(url: str) -> Optional[RepoType]:
"""Get the type of the repository.
Args:
url (str): The URL of the repository.
Returns:
str: The type of the repository.
Optional[str]: The type of the repository (None if the repo cannot be found).
"""
m = _SSH_TO_HTTPS_REGEX.match(url)
url = f"https://github.com/{m.group(1)}/{m.group(2)}" if m else url
parsed_url = urllib.parse.urlparse(url)
if parsed_url.scheme in ["http", "https"]:
# case 1 - GitHub URL
# Case 1 - GitHub URL.
if "github.com" in url:
if not url.startswith("https://"):
logger.warning(f"{url} should start with https://")
return
return None
else:
return "github"
# case 2 - remote URL
elif url_exists(url): # not check whether it is a git URL
return "remote"
# case 3 - local path
return RepoType.GITHUB
# Case 2 - remote URL.
elif url_exists(url): # Not check whether it is a git URL
return RepoType.REMOTE
# Case 3 - local path
elif is_git_repo(Path(parsed_url.path)):
return "local"
return RepoType.LOCAL
logger.warning(f"{url} is not a valid URL")
return None

Expand All @@ -103,11 +109,11 @@ def _split_git_url(url: str) -> Tuple[str, str]:
return user_name, repo_name


def _format_dirname(url: str, commit: str) -> str:
def _format_cache_dirname(url: str, commit: str) -> str:
user_name, repo_name = _split_git_url(url)
repo_type = repo_type_of_url(url)
repo_type = get_repo_type(url)
assert repo_type is not None, f"Invalid url {url}"
if repo_type == "github":
if repo_type == RepoType.GITHUB:
return f"{user_name}-{repo_name}-{commit}"
else: # git repo
return f"gitpython-{repo_name}-{commit}"
Expand All @@ -117,32 +123,34 @@ def _format_dirname(url: str, commit: str) -> str:
def url_to_repo(
url: str,
num_retries: int = 2,
repo_type: Union[str, None] = None,
repo_type: Optional[RepoType] = None,
tmp_dir: Union[Path] = None,
) -> Union[Repo, Repository]:
"""Convert a URL to a Repo object.
Args:
url (str): The URL of the repository.
num_retries (int): Number of retries in case of failure.
repo_type (Optional[str]): The type of the repository. Defaults to None.
repo_type (Optional[RepoType]): The type of the repository. Defaults to None.
tmp_dir (Optional[Path]): The temporary directory to clone the repo to. Defaults to None.
Returns:
Repo: A Git Repo object.
"""
url = normalize_url(url)
backoff = 1
tmp_dir = tmp_dir or os.path.join(TMP_DIR or "/tmp", str(uuid.uuid4())[:8])
repo_type = repo_type or repo_type_of_url(url)
tmp_dir = tmp_dir or os.path.join(
TMP_DIR or "/tmp", next(tempfile._get_candidate_names())
)
repo_type = repo_type or get_repo_type(url)
assert repo_type is not None, f"Invalid url {url}"
while True:
try:
if repo_type == "github":
if repo_type == RepoType.GITHUB:
return GITHUB.get_repo("/".join(url.split("/")[-2:]))
with working_directory(tmp_dir):
repo_name = os.path.basename(url)
if repo_type == "local":
if repo_type == RepoType.LOCAL:
assert is_git_repo(url), f"Local path {url} is not a git repo"
shutil.copytree(url, repo_name)
return Repo(repo_name)
Expand Down Expand Up @@ -174,26 +182,22 @@ def cleanse_string(s: Union[str, Path]) -> str:

def _to_commit_hash(repo: Union[Repository, Repo], label: str) -> str:
"""Convert a tag or branch to a commit hash."""
# GitHub repository
if isinstance(repo, Repository):
if isinstance(repo, Repository): # GitHub repository
logger.debug(f"Querying the commit hash for {repo.name} {label}")
try:
commit = repo.get_commit(label).sha
except GithubException as e:
return repo.get_commit(label).sha
except GithubException as ex:
raise ValueError(f"Invalid tag or branch: `{label}` for {repo.name}")
# Local or remote Git repository
elif isinstance(repo, Repo):
else: # Local or remote Git repository
assert isinstance(repo, Repo)
logger.debug(
f"Querying the commit hash for {repo.working_dir} repository {label}"
)
try:
# Resolve the label to a commit hash
commit = repo.commit(label).hexsha
except Exception as e:
return repo.commit(label).hexsha
except Exception as ex:
raise ValueError(f"Error converting ref to commit hash: {e}")
else:
raise TypeError("Unsupported repository type")
return commit


@dataclass(eq=True, unsafe_hash=True)
Expand Down Expand Up @@ -493,8 +497,7 @@ class LeanGitRepo:
url: str
"""The repo's URL.
It can be a GitHub URL that starts with https:// or git@github.com,
a local path, or any other valid Git URL.
It can be a GitHub URL that starts with https:// or git@github.com, a local path, or any other valid Git URL.
"""

commit: str
Expand All @@ -512,23 +515,23 @@ class LeanGitRepo:
"""Required Lean version.
"""

repo_type: str = field(init=False, repr=False)
"""Type of the repo. It can be ``github``, ``local`` or ``remote``.
repo_type: RepoType = field(init=False, repr=False)
"""Type of the repo. It can be ``GITHUB``, ``LOCAL`` or ``REMOTE``.
"""

def __post_init__(self) -> None:
repo_type = repo_type_of_url(self.url)
repo_type = get_repo_type(self.url)
if repo_type is None:
raise ValueError(f"{self.url} is not a valid URL")
object.__setattr__(self, "repo_type", repo_type)
object.__setattr__(self, "url", normalize_url(self.url, repo_type=repo_type))
# set repo and commit
if repo_type == "github":
if repo_type == RepoType.GITHUB:
repo = url_to_repo(self.url, repo_type=repo_type)
else:
# get repo from cache
rel_cache_dir = lambda url, commit: Path(
f"{REPO_CACHE_PREFIX}/{_format_dirname(url, commit)}/{self.name}"
f"{REPO_CACHE_PREFIX}/{_format_cache_dirname(url, commit)}/{self.name}"
)
cache_repo_dir = repo_cache.get(rel_cache_dir(self.url, self.commit))
if cache_repo_dir is None:
Expand Down Expand Up @@ -583,18 +586,17 @@ def is_lean4(self) -> bool:
def commit_url(self) -> str:
return f"{self.url}/tree/{self.commit}"

@property
def format_dirname(self) -> Path:
"""Return the formatted cache directory name"""
assert is_commit_hash(self.commit), f"Invalid commit hash: {self.commit}"
return Path(_format_dirname(self.url, self.commit))
return Path(_format_cache_dirname(self.url, self.commit))

def show(self) -> None:
"""Show the repo in the default browser."""
webbrowser.open(self.commit_url)

def exists(self) -> bool:
if self.repo_type != "github":
if self.repo_type != RepoType.GITHUB:
repo = self.repo # git repo
try:
repo.commit(self.commit)
Expand Down Expand Up @@ -746,7 +748,7 @@ def _get_config_url(self, filename: str) -> str:

def get_config(self, filename: str, num_retries: int = 2) -> Dict[str, Any]:
"""Return the repo's files."""
if self.repo_type == "github":
if self.repo_type == RepoType.GITHUB:
config_url = self._get_config_url(filename)
content = read_url(config_url, num_retries)
else:
Expand Down
2 changes: 1 addition & 1 deletion src/lean_dojo/data_extraction/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def get_traced_repo_path(repo: LeanGitRepo, build_deps: bool = True) -> Path:
Returns:
Path: The path of the traced repo in the cache, e.g. :file:`/home/kaiyu/.cache/lean_dojo/leanprover-community-mathlib-2196ab363eb097c008d4497125e0dde23fb36db2`
"""
rel_cache_dir = repo.format_dirname / repo.name
rel_cache_dir = repo.get_cache_dirname() / repo.name
path = cache.get(rel_cache_dir)
if path is None:
logger.info(f"Tracing {repo}")
Expand Down
11 changes: 6 additions & 5 deletions tests/data_extraction/test_cache.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
# test for cache manager
from git import Repo
from lean_dojo.utils import working_directory
from pathlib import Path
from lean_dojo.utils import working_directory
from lean_dojo.data_extraction.cache import cache


def test_repo_cache(lean4_example_url, remote_example_url, example_commit_hash):
def test_local_repo_cache(lean4_example_url, example_commit_hash):
# Note: The `git.Repo` requires the local repo to be cloned in a directory
# all cached repos are stored in CACHE_DIR/repos
prefix = "repos"
repo_name = "lean4-example"

# test local repo cache
with working_directory() as tmp_dir:
repo = Repo.clone_from(lean4_example_url, repo_name)
repo.git.checkout(example_commit_hash)
Expand All @@ -24,7 +22,10 @@ def test_repo_cache(lean4_example_url, remote_example_url, example_commit_hash):
repo_cache_dir = cache.get(rel_cache_dir)
assert repo_cache_dir is not None

# test remote repo cache

def test_remote_repo_cache(remote_example_url):
prefix = "repos"
repo_name = "lean4-example"
with working_directory() as tmp_dir:
repo = Repo.clone_from(remote_example_url, repo_name)
tmp_remote_dir = tmp_dir / repo_name
Expand Down
40 changes: 20 additions & 20 deletions tests/data_extraction/test_lean_repo.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
# test for the class `LeanGitRepo`
from lean_dojo import LeanGitRepo
# Tests for the class `LeanGitRepo`
from git import Repo
from lean_dojo import LeanGitRepo
from github.Repository import Repository
from lean_dojo.utils import working_directory
from lean_dojo.data_extraction.lean import (
_to_commit_hash,
repo_type_of_url,
get_repo_type,
url_to_repo,
get_latest_commit,
is_commit_hash,
GITHUB,
LEAN4_REPO,
RepoType,
)
from lean_dojo.utils import working_directory


def test_github_type(lean4_example_url, example_commit_hash):
Expand All @@ -21,11 +21,11 @@ def test_github_type(lean4_example_url, example_commit_hash):
gh_cm_hash = get_latest_commit(lean4_example_url)
assert is_commit_hash(gh_cm_hash)

## url_to_repo & repo_type_of_url
## url_to_repo & get_repo_type
github_repo = url_to_repo(lean4_example_url)
assert repo_type_of_url(lean4_example_url) == "github"
assert repo_type_of_url("git@github.com:yangky11/lean4-example.git") == "github"
assert repo_type_of_url("git@github.com:yangky11/lean4-example") == "github"
assert get_repo_type(lean4_example_url) == RepoType.GITHUB
assert get_repo_type("git@github.com:yangky11/lean4-example.git") == RepoType.GITHUB
assert get_repo_type("git@github.com:yangky11/lean4-example") == RepoType.GITHUB
assert isinstance(github_repo, Repository)
assert github_repo.name == repo_name

Expand All @@ -46,22 +46,22 @@ def test_github_type(lean4_example_url, example_commit_hash):
LeanGitRepo(lean4_example_url, "main") # init with branch
repo = LeanGitRepo(lean4_example_url, example_commit_hash)
assert repo.url == lean4_example_url
assert repo.repo_type == "github"
assert repo.repo_type == RepoType.GITHUB
assert repo.commit == example_commit_hash
assert repo.exists()
assert repo.name == repo_name
assert repo.lean_version == "v4.7.0"
assert repo.commit_url == f"{lean4_example_url}/tree/{example_commit_hash}"
# cache name
assert isinstance(repo.repo, Repository)
assert str(repo.format_dirname) == f"yangky11-{repo_name}-{example_commit_hash}"
assert str(repo.format_dirname()) == f"yangky11-{repo_name}-{example_commit_hash}"


def test_remote_type(remote_example_url, example_commit_hash):
repo_name = "lean4-example"

remote_repo = url_to_repo(remote_example_url)
assert repo_type_of_url(remote_example_url) == "remote"
assert get_repo_type(remote_example_url) == RepoType.REMOTE
assert isinstance(remote_repo, Repo)
re_cm_hash = get_latest_commit(remote_example_url)
assert re_cm_hash == get_latest_commit(str(remote_repo.working_dir))
Expand All @@ -75,15 +75,15 @@ def test_remote_type(remote_example_url, example_commit_hash):
LeanGitRepo(remote_example_url, "main")
repo = LeanGitRepo(remote_example_url, example_commit_hash)
assert repo.url == remote_example_url
assert repo.repo_type == "remote"
assert repo.repo_type == RepoType.REMOTE
assert repo.commit == example_commit_hash
assert repo.exists()
assert repo.name == repo_name
assert repo.lean_version == "v4.7.0"
assert repo.commit_url == f"{remote_example_url}/tree/{example_commit_hash}"
# cache name
assert isinstance(repo.repo, Repo)
assert str(repo.format_dirname) == f"gitpython-{repo_name}-{example_commit_hash}"
assert str(repo.format_dirname()) == f"gitpython-{repo_name}-{example_commit_hash}"


def test_local_type(lean4_example_url, example_commit_hash):
Expand All @@ -98,9 +98,9 @@ def test_local_type(lean4_example_url, example_commit_hash):
local_url = str((tmp_dir / repo_name).absolute())
assert get_latest_commit(local_url) == gh_cm_hash

## url_to_repo & repo_type_of_url
local_repo = url_to_repo(local_url, repo_type="local")
assert repo_type_of_url(local_url) == "local"
## url_to_repo & get_repo_type
local_repo = url_to_repo(local_url, repo_type=RepoType.LOCAL)
assert get_repo_type(local_url) == RepoType.LOCAL
assert isinstance(local_repo, Repo)
assert (
local_repo.working_dir != local_url
Expand All @@ -123,14 +123,14 @@ def test_local_type(lean4_example_url, example_commit_hash):
repo = LeanGitRepo(local_url, example_commit_hash)
repo2 = LeanGitRepo.from_path(local_url) # test from_path
assert repo.url == local_url == repo2.url
assert repo.repo_type == "local" == repo2.repo_type
assert repo.repo_type == RepoType.LOCAL == repo2.repo_type
assert repo.commit == example_commit_hash and repo2.commit == gh_cm_hash
assert repo.exists() and repo2.exists()
assert repo.name == repo_name == repo2.name
assert repo.lean_version == "v4.7.0"
# cache name
assert isinstance(repo.repo, Repo) and isinstance(repo2.repo, Repo)
assert (
str(repo.format_dirname) == f"gitpython-{repo_name}-{example_commit_hash}"
str(repo.format_dirname()) == f"gitpython-{repo_name}-{example_commit_hash}"
)
assert str(repo2.format_dirname) == f"gitpython-{repo_name}-{gh_cm_hash}"
assert str(repo2.format_dirname()) == f"gitpython-{repo_name}-{gh_cm_hash}"
Loading

0 comments on commit 592f9b5

Please sign in to comment.