diff --git a/src/lean_dojo/data_extraction/traced_data.py b/src/lean_dojo/data_extraction/traced_data.py index 7d3fbd4..5e56050 100644 --- a/src/lean_dojo/data_extraction/traced_data.py +++ b/src/lean_dojo/data_extraction/traced_data.py @@ -539,7 +539,7 @@ def from_traced_file( def _from_lean4_traced_file( cls, root_dir: Path, json_path: Path, repo: LeanGitRepo ) -> "TracedFile": - lean_path = to_lean_path(root_dir, json_path, repo) + lean_path = to_lean_path(root_dir, json_path) lean_file = LeanFile(root_dir, lean_path) data = json.load(json_path.open()) @@ -922,7 +922,7 @@ def from_xml( root_dir = Path(root_dir) path = Path(path) assert path.suffixes == [".trace", ".xml"] - lean_path = to_lean_path(root_dir, path, repo) + lean_path = to_lean_path(root_dir, path) lean_file = LeanFile(root_dir, lean_path) tree = etree.parse(path).getroot() diff --git a/src/lean_dojo/utils.py b/src/lean_dojo/utils.py index 3efad4c..57f6240 100644 --- a/src/lean_dojo/utils.py +++ b/src/lean_dojo/utils.py @@ -71,7 +71,7 @@ def ray_actor_pool( """ assert not ray.is_initialized() ray.init() - pool = ActorPool([actor_cls.remote(*args, **kwargs) for _ in range(NUM_WORKERS)]) + pool = ActorPool([actor_cls.remote(*args, **kwargs) for _ in range(NUM_WORKERS)]) # type: ignore try: yield pool finally: @@ -154,8 +154,7 @@ def is_optional_type(tp: type) -> bool: def remove_optional_type(tp: type) -> type: """Given Optional[X], return X.""" - if typing.get_origin(tp) != Union: - return False + assert typing.get_origin(tp) == Union args = typing.get_args(tp) if len(args) == 2 and args[1] == type(None): return args[0] @@ -169,11 +168,11 @@ def read_url(url: str, num_retries: int = 2) -> str: backoff = 1 while True: try: - request = urllib.request.Request(url) + request = urllib.request.Request(url) # type: ignore gh_token = os.getenv("GITHUB_ACCESS_TOKEN") if gh_token is not None: request.add_header("Authorization", f"token {gh_token}") - with urllib.request.urlopen(request) as f: + with urllib.request.urlopen(request) as f: # type: ignore return f.read().decode() except Exception as ex: if num_retries <= 0: @@ -188,13 +187,13 @@ def read_url(url: str, num_retries: int = 2) -> str: def url_exists(url: str) -> bool: """Return True if the URL ``url`` exists, using the GITHUB_ACCESS_TOKEN for authentication if provided.""" try: - request = urllib.request.Request(url) + request = urllib.request.Request(url) # type: ignore gh_token = os.getenv("GITHUB_ACCESS_TOKEN") if gh_token is not None: request.add_header("Authorization", f"token {gh_token}") - with urllib.request.urlopen(request) as _: + with urllib.request.urlopen(request) as _: # type: ignore return True - except urllib.error.HTTPError: + except urllib.error.HTTPError: # type: ignore return False @@ -260,7 +259,7 @@ def to_json_path(root_dir: Path, path: Path, repo) -> Path: return _from_lean_path(root_dir, path, repo, ext=".ast.json") -def to_lean_path(root_dir: Path, path: Path, repo) -> bool: +def to_lean_path(root_dir: Path, path: Path) -> Path: if path.is_absolute(): path = path.relative_to(root_dir)