From 1d57713dafbb85684370a8cef178be68531025f7 Mon Sep 17 00:00:00 2001 From: Suwen Ge Date: Sun, 3 Mar 2024 10:42:45 -0800 Subject: [PATCH] Move torchvision.models visitor to vision dir --- .../codemod/torchvision_models.py | 22 ---------- .../codemod/torchvision_models.py.out | 22 ---------- .../fixtures/vision/checker/models_import.py | 5 +++ .../fixtures/vision/checker/models_import.txt | 1 + torchfix/torchfix.py | 8 ++-- .../visitors/deprecated_symbols/__init__.py | 25 ------------ torchfix/visitors/vision/__init__.py | 1 + torchfix/visitors/vision/models_import.py | 40 +++++++++++++++++++ 8 files changed, 50 insertions(+), 74 deletions(-) delete mode 100644 tests/fixtures/deprecated_symbols/codemod/torchvision_models.py delete mode 100644 tests/fixtures/deprecated_symbols/codemod/torchvision_models.py.out create mode 100644 tests/fixtures/vision/checker/models_import.py create mode 100644 tests/fixtures/vision/checker/models_import.txt create mode 100644 torchfix/visitors/vision/models_import.py diff --git a/tests/fixtures/deprecated_symbols/codemod/torchvision_models.py b/tests/fixtures/deprecated_symbols/codemod/torchvision_models.py deleted file mode 100644 index d14921f..0000000 --- a/tests/fixtures/deprecated_symbols/codemod/torchvision_models.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch -import torchvision.models as models - -import torch.autograd.profiler as profiler - -for with_cuda in [False, True]: - model = models.test1()#resnet18() - inputs = torch.randn(5, 3, 224, 224) - sort_key = "self_cpu_memory_usage" - if with_cuda and torch.cuda.is_available(): - model = model.cuda() - inputs = inputs.cuda() - sort_key = "self_cuda_memory_usage" - print("Profiling CUDA Resnet model") - else: - print("Profiling CPU Resnet model") - - with profiler.profile(profile_memory=True, record_shapes=True) as prof: - with profiler.record_function("root"): - model(inputs) - - print(prof.key_averages(group_by_input_shape=True).table(sort_by=sort_key, row_limit=-1)) diff --git a/tests/fixtures/deprecated_symbols/codemod/torchvision_models.py.out b/tests/fixtures/deprecated_symbols/codemod/torchvision_models.py.out deleted file mode 100644 index 290b4e6..0000000 --- a/tests/fixtures/deprecated_symbols/codemod/torchvision_models.py.out +++ /dev/null @@ -1,22 +0,0 @@ -import torch -from torchvision import models - -import torch.autograd.profiler as profiler - -for with_cuda in [False, True]: - model = models.test1()#resnet18() - inputs = torch.randn(5, 3, 224, 224) - sort_key = "self_cpu_memory_usage" - if with_cuda and torch.cuda.is_available(): - model = model.cuda() - inputs = inputs.cuda() - sort_key = "self_cuda_memory_usage" - print("Profiling CUDA Resnet model") - else: - print("Profiling CPU Resnet model") - - with profiler.profile(profile_memory=True, record_shapes=True) as prof: - with profiler.record_function("root"): - model(inputs) - - print(prof.key_averages(group_by_input_shape=True).table(sort_by=sort_key, row_limit=-1)) diff --git a/tests/fixtures/vision/checker/models_import.py b/tests/fixtures/vision/checker/models_import.py new file mode 100644 index 0000000..8eae98e --- /dev/null +++ b/tests/fixtures/vision/checker/models_import.py @@ -0,0 +1,5 @@ +import torchvision.models as models +import torchvision.models as cnn +from torchvision.models import resnet50, resnet101 +import torchvision.models +from torchvision.models import * diff --git a/tests/fixtures/vision/checker/models_import.txt b/tests/fixtures/vision/checker/models_import.txt new file mode 100644 index 0000000..29bbc7b --- /dev/null +++ b/tests/fixtures/vision/checker/models_import.txt @@ -0,0 +1 @@ +1:1: TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'. diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 20eda7a..490caf4 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -9,7 +9,6 @@ from .visitors.deprecated_symbols import ( TorchDeprecatedSymbolsVisitor, _UpdateFunctorchImports, - _UpdateTorchvisionModelsImports, ) from .visitors.internal import TorchScopedLibraryVisitor @@ -20,6 +19,7 @@ from .visitors.vision import ( TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, + TorchVisionModelsImportVisitor, ) from .visitors.security import TorchUnsafeLoadVisitor @@ -36,6 +36,7 @@ TorchSynchronizedDataLoaderVisitor, TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, + TorchVisionModelsImportVisitor, TorchUnsafeLoadVisitor, TorchReentrantCheckpointVisitor, ] @@ -230,11 +231,8 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module: update_functorch_imports_visitor = _UpdateFunctorchImports() new_module = new_module.visit(update_functorch_imports_visitor) - update_torchvision_models_visitor = _UpdateTorchvisionModelsImports() - new_module = new_module.visit(update_torchvision_models_visitor) - if fixes_count == 0 and not update_functorch_imports_visitor.changed \ - and not update_torchvision_models_visitor.changed: + if fixes_count == 0 and not update_functorch_imports_visitor.changed: raise codemod.SkipFile("No changes") return new_module diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index 532f7f0..93a9082 100644 --- a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -117,28 +117,3 @@ def leave_ImportFrom( self.changed = True return updated_node.with_changes(module=cst.parse_expression("torch.func")) return updated_node - -# TODO: refactor/generalize this. -class _UpdateTorchvisionModelsImports(cst.CSTTransformer): - - def __init__(self): - self.changed = False - - def leave_Import( - self, node: cst.Import, updated_node: cst.Import - ) -> cst.CSTNode: - if len(updated_node.names) == 1: - alias = updated_node.names[0] - if isinstance(alias.name, cst.Attribute) and \ - alias.name.value.value == 'torchvision' and \ - alias.name.attr.value == 'models' and \ - alias.asname and alias.asname.name.value == 'models': - - self.changed = True - new_import = cst.ImportFrom( - module=cst.Name(value='torchvision'), - names=[cst.ImportAlias(name=cst.Name(value='models'))] - ) - return new_import - - return updated_node diff --git a/torchfix/visitors/vision/__init__.py b/torchfix/visitors/vision/__init__.py index 7adcc19..9bc944e 100644 --- a/torchfix/visitors/vision/__init__.py +++ b/torchfix/visitors/vision/__init__.py @@ -1,2 +1,3 @@ from .pretrained import TorchVisionDeprecatedPretrainedVisitor # noqa: F401 from .to_tensor import TorchVisionDeprecatedToTensorVisitor # noqa: F401 +from .models_import import TorchVisionModelsImportVisitor # noqa: F401 diff --git a/torchfix/visitors/vision/models_import.py b/torchfix/visitors/vision/models_import.py new file mode 100644 index 0000000..2f810a3 --- /dev/null +++ b/torchfix/visitors/vision/models_import.py @@ -0,0 +1,40 @@ +import libcst as cst + +from ...common import LintViolation, TorchVisitor + + +class TorchVisionModelsImportVisitor(TorchVisitor): + ERROR_CODE = "TOR203" + + def visit_Import(self, node: cst.Import) -> None: + for imported_item in node.names: + if isinstance(imported_item.name, cst.Attribute): + if ( + isinstance(imported_item.name.value, cst.Name) + and imported_item.name.value.value == "torchvision" + and imported_item.name.attr.value == "models" + and imported_item.asname is not None + and imported_item.asname.name.value == "models" + ): + print(imported_item.asname.name.value) + position = self.get_metadata( + cst.metadata.WhitespaceInclusivePositionProvider, node + ) + # print(position) + replacement = cst.ImportFrom( + module=cst.Name("torchvision"), + names=[cst.ImportAlias(name=cst.Name("models"))], + ) + self.violations.append( + LintViolation( + error_code=self.ERROR_CODE, + message=( + "Consider replacing 'import torchvision.models as" + " models' with 'from torchvision import models'. " + ), + line=position.start.line, + column=position.start.column, + node=node, + replacement=replacement + ) + )