Skip to content

Commit

Permalink
Enable mypy check_untyped_defs (pytorch-labs#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
kit1980 authored May 9, 2024
1 parent 615e365 commit fdad986
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 5 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ exclude = "tests/fixtures/*"

[tool.mypy]
exclude = ["tests/fixtures", "build"]
check_untyped_defs = true

[tool.setuptools.dynamic]
version = {attr = "torchfix.torchfix.__version__"}
5 changes: 3 additions & 2 deletions tests/test_torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ def _codemod_results(source_path):
config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES()))
context = TorchCodemod(codemod.CodemodContext(filename=source_path), config)
new_module = codemod.transform_module(context, code)
if isinstance(new_module, codemod.TransformFailure):
if isinstance(new_module, codemod.TransformSuccess):
return new_module.code
elif isinstance(new_module, codemod.TransformFailure):
raise new_module.error
return new_module.code


def test_empty():
Expand Down
4 changes: 3 additions & 1 deletion torchfix/torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import libcst as cst
import libcst.codemod as codemod

from .common import deep_multi_replace
from .common import deep_multi_replace, TorchVisitor
from .visitors.deprecated_symbols import TorchDeprecatedSymbolsVisitor
from .visitors.internal import TorchScopedLibraryVisitor

Expand Down Expand Up @@ -44,6 +44,7 @@
def GET_ALL_ERROR_CODES():
codes = set()
for cls in ALL_VISITOR_CLS:
assert issubclass(cls, TorchVisitor)
codes |= {error.error_code for error in cls.ERRORS}
return codes

Expand Down Expand Up @@ -79,6 +80,7 @@ def get_visitors_with_error_codes(error_codes):
# only correspond to one visitor.
found = False
for visitor_cls in ALL_VISITOR_CLS:
assert issubclass(visitor_cls, TorchVisitor)
if any(error_code == err.error_code for err in visitor_cls.ERRORS):
visitor_classes.add(visitor_cls)
found = True
Expand Down
1 change: 1 addition & 0 deletions torchfix/visitors/deprecated_symbols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def read_deprecated_config(path=None):
deprecated_config = {}
if path is not None:
data = pkgutil.get_data("torchfix", path)
assert data is not None
for item in yaml.load(data, yaml.SafeLoader):
deprecated_config[item["name"]] = item
return deprecated_config
Expand Down
2 changes: 1 addition & 1 deletion torchfix/visitors/vision/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ class TorchVisionDeprecatedPretrainedVisitor(TorchVisitor):

def visit_Call(self, node):
def _new_arg_and_import(
old_arg: cst.Arg, is_backbone: bool
old_arg: Optional[cst.Arg], is_backbone: bool
) -> Optional[cst.Arg]:
old_arg_name = "pretrained_backbone" if is_backbone else "pretrained"
if old_arg is None or (model_name, old_arg_name) not in self.MODEL_WEIGHTS:
Expand Down
2 changes: 1 addition & 1 deletion torchfix/visitors/vision/to_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ def visit_Attribute(self, node):
if len(qualified_names) != 1:
return

self._maybe_add_violation(qualified_names.pop().name, node)
self._maybe_add_violation(list(qualified_names)[0].name, node)

0 comments on commit fdad986

Please sign in to comment.