diff --git a/tests/fixtures/vision/checker/models_import.py b/tests/fixtures/vision/checker/models_import.py index 8eae98e..3a16490 100644 --- a/tests/fixtures/vision/checker/models_import.py +++ b/tests/fixtures/vision/checker/models_import.py @@ -3,3 +3,4 @@ from torchvision.models import resnet50, resnet101 import torchvision.models from torchvision.models import * +import torchvision.models as models, torch diff --git a/tests/fixtures/vision/checker/models_import.txt b/tests/fixtures/vision/checker/models_import.txt index 864cf35..7a517da 100644 --- a/tests/fixtures/vision/checker/models_import.txt +++ b/tests/fixtures/vision/checker/models_import.txt @@ -1 +1,2 @@ 1:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'. +6:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'. diff --git a/tests/fixtures/vision/codemod/models_import.py b/tests/fixtures/vision/codemod/models_import.py new file mode 100644 index 0000000..6b75141 --- /dev/null +++ b/tests/fixtures/vision/codemod/models_import.py @@ -0,0 +1,5 @@ +import torchvision.models as models +import torchvision.models as cnn + +# don't touch if more than one name imported +import torchvision.models as models, torch diff --git a/tests/fixtures/vision/codemod/models_import.py.out b/tests/fixtures/vision/codemod/models_import.py.out new file mode 100644 index 0000000..53269c1 --- /dev/null +++ b/tests/fixtures/vision/codemod/models_import.py.out @@ -0,0 +1,5 @@ +from torchvision import models +import torchvision.models as cnn + +# don't touch if more than one name imported +import torchvision.models as models, torch diff --git a/torchfix/visitors/vision/models_import.py b/torchfix/visitors/vision/models_import.py index ba5a325..7ccbebb 100644 --- a/torchfix/visitors/vision/models_import.py +++ b/torchfix/visitors/vision/models_import.py @@ -5,10 +5,16 @@ class TorchVisionModelsImportVisitor(TorchVisitor): ERROR_CODE = "TOR203" + MESSAGE = ( + "Consider replacing 'import torchvision.models as models' " + "with 'from torchvision import models'." + ) def visit_Import(self, node: cst.Import) -> None: + replacement = None for imported_item in node.names: if isinstance(imported_item.name, cst.Attribute): + # TODO refactor using libcst.matchers.matches if ( isinstance(imported_item.name.value, cst.Name) and imported_item.name.value.value == "torchvision" @@ -21,20 +27,20 @@ def visit_Import(self, node: cst.Import) -> None: position = self.get_metadata( cst.metadata.WhitespaceInclusivePositionProvider, node ) - replacement = cst.ImportFrom( - module=cst.Name("torchvision"), - names=[cst.ImportAlias(name=cst.Name("models"))], - ) + # Replace only if the import statement has no other names + if len(node.names) == 1: + replacement = cst.ImportFrom( + module=cst.Name("torchvision"), + names=[cst.ImportAlias(name=cst.Name("models"))], + ) self.violations.append( LintViolation( error_code=self.ERROR_CODE, - message=( - "Consider replacing 'import torchvision.models as" - " models' with 'from torchvision import models'." - ), + message=self.MESSAGE, line=position.start.line, column=position.start.column, node=node, replacement=replacement ) ) + break