From 9b5adefdf46069d3da617502bc17c006f3eb5880 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Fri, 15 Mar 2024 14:28:46 -0700 Subject: [PATCH] Refactor to use add_violation (#34) --- .../visitors/deprecated_symbols/__init__.py | 33 ++------------- .../deprecated_symbols/chain_matmul.py | 2 +- .../visitors/deprecated_symbols/cholesky.py | 2 +- torchfix/visitors/deprecated_symbols/qr.py | 2 +- torchfix/visitors/deprecated_symbols/range.py | 1 + torchfix/visitors/internal/__init__.py | 18 +-------- torchfix/visitors/misc/__init__.py | 40 +++++-------------- torchfix/visitors/performance/__init__.py | 18 ++------- torchfix/visitors/security/__init__.py | 21 +++------- torchfix/visitors/vision/models_import.py | 23 ++++------- torchfix/visitors/vision/pretrained.py | 19 +++------ torchfix/visitors/vision/to_tensor.py | 25 ++++-------- 12 files changed, 51 insertions(+), 153 deletions(-) diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index beb0657..3450777 100644 --- a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -6,7 +6,6 @@ from ...common import ( TorchVisitor, call_with_name_changes, - LintViolation, ) from .range import call_replacement_range @@ -65,9 +64,6 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: if isinstance(node.names, Sequence): for name in node.names: qualified_name = f"{module}.{name.name.value}" - position_metadata = self.get_metadata( - cst.metadata.WhitespaceInclusivePositionProvider, node - ) if qualified_name in self.deprecated_config: if self.deprecated_config[qualified_name]["remove_pr"] is None: error_code = self.ERROR_CODE[3] @@ -76,22 +72,11 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: error_code = self.ERROR_CODE[2] message = f"Import of removed function {qualified_name}" - reference = self.deprecated_config[qualified_name].get( - "reference" - ) + reference = self.deprecated_config[qualified_name].get("reference") if reference is not None: message = f"{message}: {reference}" - self.violations.append( - LintViolation( - error_code=error_code, - message=message, - line=position_metadata.start.line, - column=position_metadata.start.column, - node=node, - replacement=None, - ) - ) + self.add_violation(node, error_code=error_code, message=message) def visit_Call(self, node) -> None: qualified_name = self.get_qualified_name_for_call(node) @@ -99,9 +84,6 @@ def visit_Call(self, node) -> None: return if qualified_name in self.deprecated_config: - position_metadata = self.get_metadata( - cst.metadata.WhitespaceInclusivePositionProvider, node - ) if self.deprecated_config[qualified_name]["remove_pr"] is None: error_code = self.ERROR_CODE[1] message = f"Use of deprecated function {qualified_name}" @@ -114,15 +96,8 @@ def visit_Call(self, node) -> None: if reference is not None: message = f"{message}: {reference}" - self.violations.append( - LintViolation( - error_code=error_code, - message=message, - line=position_metadata.start.line, - column=position_metadata.start.column, - node=node, - replacement=replacement, - ) + self.add_violation( + node, error_code=error_code, message=message, replacement=replacement ) diff --git a/torchfix/visitors/deprecated_symbols/chain_matmul.py b/torchfix/visitors/deprecated_symbols/chain_matmul.py index 3eab730..ca546c3 100644 --- a/torchfix/visitors/deprecated_symbols/chain_matmul.py +++ b/torchfix/visitors/deprecated_symbols/chain_matmul.py @@ -20,7 +20,7 @@ def call_replacement_chain_matmul(node: cst.Call) -> cst.CSTNode: replacement_args = [matrices_arg] else: replacement_args = [matrices_arg, out_arg] - module_name = get_module_name(node, 'torch') + module_name = get_module_name(node, "torch") replacement = cst.parse_expression(f"{module_name}.linalg.multi_dot(args)") replacement = replacement.with_changes(args=replacement_args) diff --git a/torchfix/visitors/deprecated_symbols/cholesky.py b/torchfix/visitors/deprecated_symbols/cholesky.py index cec5e71..c44c831 100644 --- a/torchfix/visitors/deprecated_symbols/cholesky.py +++ b/torchfix/visitors/deprecated_symbols/cholesky.py @@ -1,5 +1,5 @@ import libcst as cst -from ...common import (TorchVisitor, get_module_name) +from ...common import TorchVisitor, get_module_name def call_replacement_cholesky(node: cst.Call) -> cst.CSTNode: diff --git a/torchfix/visitors/deprecated_symbols/qr.py b/torchfix/visitors/deprecated_symbols/qr.py index f1d96df..9fc4874 100644 --- a/torchfix/visitors/deprecated_symbols/qr.py +++ b/torchfix/visitors/deprecated_symbols/qr.py @@ -1,6 +1,6 @@ import libcst as cst from typing import Optional -from ...common import (TorchVisitor, get_module_name) +from ...common import TorchVisitor, get_module_name def call_replacement_qr(node: cst.Call) -> Optional[cst.CSTNode]: diff --git a/torchfix/visitors/deprecated_symbols/range.py b/torchfix/visitors/deprecated_symbols/range.py index 97fec69..26f0a4f 100644 --- a/torchfix/visitors/deprecated_symbols/range.py +++ b/torchfix/visitors/deprecated_symbols/range.py @@ -7,6 +7,7 @@ def call_replacement_range(node: cst.Call) -> Optional[cst.Call]: """Replace `range` with `arange`. Add `step` to the `end` argument as `arange` has the interval `[start, end)`. """ + # `torch.range` documented signature is not a valid Python signature, # so it's hard to generalize this. def _get_range_args(node: cst.Call) -> Tuple[cst.Arg, Optional[cst.Arg]]: diff --git a/torchfix/visitors/internal/__init__.py b/torchfix/visitors/internal/__init__.py index 424e1f2..14389b3 100644 --- a/torchfix/visitors/internal/__init__.py +++ b/torchfix/visitors/internal/__init__.py @@ -1,5 +1,4 @@ -import libcst as cst -from ...common import TorchVisitor, LintViolation +from ...common import TorchVisitor class TorchScopedLibraryVisitor(TorchVisitor): @@ -17,17 +16,4 @@ class TorchScopedLibraryVisitor(TorchVisitor): def visit_Call(self, node): qualified_name = self.get_qualified_name_for_call(node) if qualified_name == "torch.library.Library": - position_metadata = self.get_metadata( - cst.metadata.WhitespaceInclusivePositionProvider, node - ) - - self.violations.append( - LintViolation( - error_code=self.ERROR_CODE, - message=self.MESSAGE, - line=position_metadata.start.line, - column=position_metadata.start.column, - node=node, - replacement=None, - ) - ) + self.add_violation(node, error_code=self.ERROR_CODE, message=self.MESSAGE) diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index 6ce7c84..ef60b3e 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -2,7 +2,7 @@ import libcst.matchers as m -from ...common import TorchVisitor, LintViolation +from ...common import TorchVisitor class TorchRequireGradVisitor(TorchVisitor): @@ -31,20 +31,11 @@ def visit_Assign(self, node): replacement = node.with_deep_changes( old_node=node.targets[0].target.attr, value="requires_grad" ) - - position_metadata = self.get_metadata( - cst.metadata.WhitespaceInclusivePositionProvider, node - ) - - self.violations.append( - LintViolation( - error_code=self.ERROR_CODE, - message=self.MESSAGE, - line=position_metadata.start.line, - column=position_metadata.start.column, - node=node, - replacement=replacement, - ) + self.add_violation( + node, + error_code=self.ERROR_CODE, + message=self.MESSAGE, + replacement=replacement, ) @@ -65,10 +56,6 @@ def visit_Call(self, node): if qualified_name == "torch.utils.checkpoint.checkpoint": use_reentrant_arg = self.get_specific_arg(node, "use_reentrant", -1) if use_reentrant_arg is None: - position_metadata = self.get_metadata( - cst.metadata.WhitespaceInclusivePositionProvider, node - ) - # This codemod maybe unsafe correctness-wise # if reentrant behavior is actually needed, # so the changes need to be verified/tested. @@ -76,14 +63,9 @@ def visit_Call(self, node): cst.parse_expression("f(use_reentrant=False)"), cst.Call ).args[0] replacement = node.with_changes(args=node.args + (use_reentrant_arg,)) - - self.violations.append( - LintViolation( - error_code=self.ERROR_CODE, - message=self.MESSAGE, - line=position_metadata.start.line, - column=position_metadata.start.column, - node=node, - replacement=replacement, - ) + self.add_violation( + node, + error_code=self.ERROR_CODE, + message=self.MESSAGE, + replacement=replacement, ) diff --git a/torchfix/visitors/performance/__init__.py b/torchfix/visitors/performance/__init__.py index f838fbe..427eb78 100644 --- a/torchfix/visitors/performance/__init__.py +++ b/torchfix/visitors/performance/__init__.py @@ -1,8 +1,7 @@ -import libcst as cst import libcst.matchers as m -from ...common import TorchVisitor, LintViolation +from ...common import TorchVisitor class TorchSynchronizedDataLoaderVisitor(TorchVisitor): @@ -25,17 +24,6 @@ def visit_Call(self, node): if num_workers_arg is None or m.matches( num_workers_arg.value, m.Integer(value="0") ): - position_metadata = self.get_metadata( - cst.metadata.WhitespaceInclusivePositionProvider, node - ) - - self.violations.append( - LintViolation( - error_code=self.ERROR_CODE, - message=self.MESSAGE, - line=position_metadata.start.line, - column=position_metadata.start.column, - node=node, - replacement=None, - ) + self.add_violation( + node, error_code=self.ERROR_CODE, message=self.MESSAGE ) diff --git a/torchfix/visitors/security/__init__.py b/torchfix/visitors/security/__init__.py index 010c5f4..5dfdf6e 100644 --- a/torchfix/visitors/security/__init__.py +++ b/torchfix/visitors/security/__init__.py @@ -1,5 +1,5 @@ import libcst as cst -from ...common import TorchVisitor, LintViolation +from ...common import TorchVisitor class TorchUnsafeLoadVisitor(TorchVisitor): @@ -21,10 +21,6 @@ def visit_Call(self, node): if qualified_name == "torch.load": weights_only_arg = self.get_specific_arg(node, "weights_only", -1) if weights_only_arg is None: - position_metadata = self.get_metadata( - cst.metadata.WhitespaceInclusivePositionProvider, node - ) - # Add `weights_only=True` if there is no `pickle_module`. # (do not add `weights_only=False` with `pickle_module`, as it # needs to be an explicit choice). @@ -42,14 +38,9 @@ def visit_Call(self, node): replacement = node.with_changes( args=node.args + (weights_only_arg,) ) - - self.violations.append( - LintViolation( - error_code=self.ERROR_CODE, - message=self.MESSAGE, - line=position_metadata.start.line, - column=position_metadata.start.column, - node=node, - replacement=replacement, - ) + self.add_violation( + node, + error_code=self.ERROR_CODE, + message=self.MESSAGE, + replacement=replacement, ) diff --git a/torchfix/visitors/vision/models_import.py b/torchfix/visitors/vision/models_import.py index 7ccbebb..928f2c1 100644 --- a/torchfix/visitors/vision/models_import.py +++ b/torchfix/visitors/vision/models_import.py @@ -1,6 +1,6 @@ import libcst as cst -from ...common import LintViolation, TorchVisitor +from ...common import TorchVisitor class TorchVisionModelsImportVisitor(TorchVisitor): @@ -24,23 +24,16 @@ def visit_Import(self, node: cst.Import) -> None: and isinstance(imported_item.asname.name, cst.Name) and imported_item.asname.name.value == "models" ): - position = self.get_metadata( - cst.metadata.WhitespaceInclusivePositionProvider, node - ) # Replace only if the import statement has no other names if len(node.names) == 1: replacement = cst.ImportFrom( - module=cst.Name("torchvision"), - names=[cst.ImportAlias(name=cst.Name("models"))], - ) - self.violations.append( - LintViolation( - error_code=self.ERROR_CODE, - message=self.MESSAGE, - line=position.start.line, - column=position.start.column, - node=node, - replacement=replacement + module=cst.Name("torchvision"), + names=[cst.ImportAlias(name=cst.Name("models"))], ) + self.add_violation( + node, + error_code=self.ERROR_CODE, + message=self.MESSAGE, + replacement=replacement, ) break diff --git a/torchfix/visitors/vision/pretrained.py b/torchfix/visitors/vision/pretrained.py index 6e17048..99dd845 100644 --- a/torchfix/visitors/vision/pretrained.py +++ b/torchfix/visitors/vision/pretrained.py @@ -3,7 +3,7 @@ import libcst as cst from libcst.codemod.visitors import ImportItem -from ...common import LintViolation, TorchVisitor +from ...common import TorchVisitor class TorchVisionDeprecatedPretrainedVisitor(TorchVisitor): @@ -248,16 +248,9 @@ def _new_arg_and_import( node.with_changes(args=replacement_args) if has_replacement else None ) if message is not None: - position_metadata = self.get_metadata( - cst.metadata.WhitespaceInclusivePositionProvider, node - ) - self.violations.append( - LintViolation( - error_code=self.ERROR_CODE, - message=message, - line=position_metadata.start.line, - column=position_metadata.start.column, - node=node, - replacement=replacement, - ) + self.add_violation( + node, + error_code=self.ERROR_CODE, + message=message, + replacement=replacement, ) diff --git a/torchfix/visitors/vision/to_tensor.py b/torchfix/visitors/vision/to_tensor.py index ab15827..3395dd9 100644 --- a/torchfix/visitors/vision/to_tensor.py +++ b/torchfix/visitors/vision/to_tensor.py @@ -1,32 +1,21 @@ from collections.abc import Sequence import libcst as cst -from ...common import LintViolation, TorchVisitor +from ...common import TorchVisitor class TorchVisionDeprecatedToTensorVisitor(TorchVisitor): ERROR_CODE = "TOR202" + MESSAGE = ( + "The transform `v2.ToTensor()` is deprecated and will be removed " + "in a future release. Instead, please use " + "`v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`." # noqa: E501 + ) def _maybe_add_violation(self, qualified_name, node): if qualified_name != "torchvision.transforms.v2.ToTensor": return - position = self.get_metadata( - cst.metadata.WhitespaceInclusivePositionProvider, node - ) - self.violations.append( - LintViolation( - error_code=self.ERROR_CODE, - message=( - "The transform `v2.ToTensor()` is deprecated and will be removed " - "in a future release. Instead, please use " - "`v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`." # noqa: E501 - ), - line=position.start.line, - column=position.start.column, - node=node, - replacement=None, - ) - ) + self.add_violation(node, error_code=self.ERROR_CODE, message=self.MESSAGE) def visit_ImportFrom(self, node): module_path = cst.helpers.get_absolute_module_from_package_for_import(