Skip to content

Commit

Permalink
Refactor to use add_violation (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
kit1980 authored Mar 15, 2024
1 parent e782670 commit 9b5adef
Show file tree
Hide file tree
Showing 12 changed files with 51 additions and 153 deletions.
33 changes: 4 additions & 29 deletions torchfix/visitors/deprecated_symbols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from ...common import (
TorchVisitor,
call_with_name_changes,
LintViolation,
)

from .range import call_replacement_range
Expand Down Expand Up @@ -65,9 +64,6 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
if isinstance(node.names, Sequence):
for name in node.names:
qualified_name = f"{module}.{name.name.value}"
position_metadata = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)
if qualified_name in self.deprecated_config:
if self.deprecated_config[qualified_name]["remove_pr"] is None:
error_code = self.ERROR_CODE[3]
Expand All @@ -76,32 +72,18 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
error_code = self.ERROR_CODE[2]
message = f"Import of removed function {qualified_name}"

reference = self.deprecated_config[qualified_name].get(
"reference"
)
reference = self.deprecated_config[qualified_name].get("reference")
if reference is not None:
message = f"{message}: {reference}"

self.violations.append(
LintViolation(
error_code=error_code,
message=message,
line=position_metadata.start.line,
column=position_metadata.start.column,
node=node,
replacement=None,
)
)
self.add_violation(node, error_code=error_code, message=message)

def visit_Call(self, node) -> None:
qualified_name = self.get_qualified_name_for_call(node)
if qualified_name is None:
return

if qualified_name in self.deprecated_config:
position_metadata = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)
if self.deprecated_config[qualified_name]["remove_pr"] is None:
error_code = self.ERROR_CODE[1]
message = f"Use of deprecated function {qualified_name}"
Expand All @@ -114,15 +96,8 @@ def visit_Call(self, node) -> None:
if reference is not None:
message = f"{message}: {reference}"

self.violations.append(
LintViolation(
error_code=error_code,
message=message,
line=position_metadata.start.line,
column=position_metadata.start.column,
node=node,
replacement=replacement,
)
self.add_violation(
node, error_code=error_code, message=message, replacement=replacement
)


Expand Down
2 changes: 1 addition & 1 deletion torchfix/visitors/deprecated_symbols/chain_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def call_replacement_chain_matmul(node: cst.Call) -> cst.CSTNode:
replacement_args = [matrices_arg]
else:
replacement_args = [matrices_arg, out_arg]
module_name = get_module_name(node, 'torch')
module_name = get_module_name(node, "torch")
replacement = cst.parse_expression(f"{module_name}.linalg.multi_dot(args)")
replacement = replacement.with_changes(args=replacement_args)

Expand Down
2 changes: 1 addition & 1 deletion torchfix/visitors/deprecated_symbols/cholesky.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import libcst as cst
from ...common import (TorchVisitor, get_module_name)
from ...common import TorchVisitor, get_module_name


def call_replacement_cholesky(node: cst.Call) -> cst.CSTNode:
Expand Down
2 changes: 1 addition & 1 deletion torchfix/visitors/deprecated_symbols/qr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import libcst as cst
from typing import Optional
from ...common import (TorchVisitor, get_module_name)
from ...common import TorchVisitor, get_module_name


def call_replacement_qr(node: cst.Call) -> Optional[cst.CSTNode]:
Expand Down
1 change: 1 addition & 0 deletions torchfix/visitors/deprecated_symbols/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def call_replacement_range(node: cst.Call) -> Optional[cst.Call]:
"""Replace `range` with `arange`.
Add `step` to the `end` argument as `arange` has the interval `[start, end)`.
"""

# `torch.range` documented signature is not a valid Python signature,
# so it's hard to generalize this.
def _get_range_args(node: cst.Call) -> Tuple[cst.Arg, Optional[cst.Arg]]:
Expand Down
18 changes: 2 additions & 16 deletions torchfix/visitors/internal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import libcst as cst
from ...common import TorchVisitor, LintViolation
from ...common import TorchVisitor


class TorchScopedLibraryVisitor(TorchVisitor):
Expand All @@ -17,17 +16,4 @@ class TorchScopedLibraryVisitor(TorchVisitor):
def visit_Call(self, node):
qualified_name = self.get_qualified_name_for_call(node)
if qualified_name == "torch.library.Library":
position_metadata = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)

self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=self.MESSAGE,
line=position_metadata.start.line,
column=position_metadata.start.column,
node=node,
replacement=None,
)
)
self.add_violation(node, error_code=self.ERROR_CODE, message=self.MESSAGE)
40 changes: 11 additions & 29 deletions torchfix/visitors/misc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import libcst.matchers as m


from ...common import TorchVisitor, LintViolation
from ...common import TorchVisitor


class TorchRequireGradVisitor(TorchVisitor):
Expand Down Expand Up @@ -31,20 +31,11 @@ def visit_Assign(self, node):
replacement = node.with_deep_changes(
old_node=node.targets[0].target.attr, value="requires_grad"
)

position_metadata = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)

self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=self.MESSAGE,
line=position_metadata.start.line,
column=position_metadata.start.column,
node=node,
replacement=replacement,
)
self.add_violation(
node,
error_code=self.ERROR_CODE,
message=self.MESSAGE,
replacement=replacement,
)


Expand All @@ -65,25 +56,16 @@ def visit_Call(self, node):
if qualified_name == "torch.utils.checkpoint.checkpoint":
use_reentrant_arg = self.get_specific_arg(node, "use_reentrant", -1)
if use_reentrant_arg is None:
position_metadata = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)

# This codemod maybe unsafe correctness-wise
# if reentrant behavior is actually needed,
# so the changes need to be verified/tested.
use_reentrant_arg = cst.ensure_type(
cst.parse_expression("f(use_reentrant=False)"), cst.Call
).args[0]
replacement = node.with_changes(args=node.args + (use_reentrant_arg,))

self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=self.MESSAGE,
line=position_metadata.start.line,
column=position_metadata.start.column,
node=node,
replacement=replacement,
)
self.add_violation(
node,
error_code=self.ERROR_CODE,
message=self.MESSAGE,
replacement=replacement,
)
18 changes: 3 additions & 15 deletions torchfix/visitors/performance/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import libcst as cst
import libcst.matchers as m


from ...common import TorchVisitor, LintViolation
from ...common import TorchVisitor


class TorchSynchronizedDataLoaderVisitor(TorchVisitor):
Expand All @@ -25,17 +24,6 @@ def visit_Call(self, node):
if num_workers_arg is None or m.matches(
num_workers_arg.value, m.Integer(value="0")
):
position_metadata = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)

self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=self.MESSAGE,
line=position_metadata.start.line,
column=position_metadata.start.column,
node=node,
replacement=None,
)
self.add_violation(
node, error_code=self.ERROR_CODE, message=self.MESSAGE
)
21 changes: 6 additions & 15 deletions torchfix/visitors/security/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import libcst as cst
from ...common import TorchVisitor, LintViolation
from ...common import TorchVisitor


class TorchUnsafeLoadVisitor(TorchVisitor):
Expand All @@ -21,10 +21,6 @@ def visit_Call(self, node):
if qualified_name == "torch.load":
weights_only_arg = self.get_specific_arg(node, "weights_only", -1)
if weights_only_arg is None:
position_metadata = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)

# Add `weights_only=True` if there is no `pickle_module`.
# (do not add `weights_only=False` with `pickle_module`, as it
# needs to be an explicit choice).
Expand All @@ -42,14 +38,9 @@ def visit_Call(self, node):
replacement = node.with_changes(
args=node.args + (weights_only_arg,)
)

self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=self.MESSAGE,
line=position_metadata.start.line,
column=position_metadata.start.column,
node=node,
replacement=replacement,
)
self.add_violation(
node,
error_code=self.ERROR_CODE,
message=self.MESSAGE,
replacement=replacement,
)
23 changes: 8 additions & 15 deletions torchfix/visitors/vision/models_import.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import libcst as cst

from ...common import LintViolation, TorchVisitor
from ...common import TorchVisitor


class TorchVisionModelsImportVisitor(TorchVisitor):
Expand All @@ -24,23 +24,16 @@ def visit_Import(self, node: cst.Import) -> None:
and isinstance(imported_item.asname.name, cst.Name)
and imported_item.asname.name.value == "models"
):
position = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)
# Replace only if the import statement has no other names
if len(node.names) == 1:
replacement = cst.ImportFrom(
module=cst.Name("torchvision"),
names=[cst.ImportAlias(name=cst.Name("models"))],
)
self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=self.MESSAGE,
line=position.start.line,
column=position.start.column,
node=node,
replacement=replacement
module=cst.Name("torchvision"),
names=[cst.ImportAlias(name=cst.Name("models"))],
)
self.add_violation(
node,
error_code=self.ERROR_CODE,
message=self.MESSAGE,
replacement=replacement,
)
break
19 changes: 6 additions & 13 deletions torchfix/visitors/vision/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import libcst as cst
from libcst.codemod.visitors import ImportItem

from ...common import LintViolation, TorchVisitor
from ...common import TorchVisitor


class TorchVisionDeprecatedPretrainedVisitor(TorchVisitor):
Expand Down Expand Up @@ -248,16 +248,9 @@ def _new_arg_and_import(
node.with_changes(args=replacement_args) if has_replacement else None
)
if message is not None:
position_metadata = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)
self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=message,
line=position_metadata.start.line,
column=position_metadata.start.column,
node=node,
replacement=replacement,
)
self.add_violation(
node,
error_code=self.ERROR_CODE,
message=message,
replacement=replacement,
)
25 changes: 7 additions & 18 deletions torchfix/visitors/vision/to_tensor.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,21 @@
from collections.abc import Sequence
import libcst as cst

from ...common import LintViolation, TorchVisitor
from ...common import TorchVisitor


class TorchVisionDeprecatedToTensorVisitor(TorchVisitor):
ERROR_CODE = "TOR202"
MESSAGE = (
"The transform `v2.ToTensor()` is deprecated and will be removed "
"in a future release. Instead, please use "
"`v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`." # noqa: E501
)

def _maybe_add_violation(self, qualified_name, node):
if qualified_name != "torchvision.transforms.v2.ToTensor":
return
position = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)
self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=(
"The transform `v2.ToTensor()` is deprecated and will be removed "
"in a future release. Instead, please use "
"`v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`." # noqa: E501
),
line=position.start.line,
column=position.start.column,
node=node,
replacement=None,
)
)
self.add_violation(node, error_code=self.ERROR_CODE, message=self.MESSAGE)

def visit_ImportFrom(self, node):
module_path = cst.helpers.get_absolute_module_from_package_for_import(
Expand Down

0 comments on commit 9b5adef

Please sign in to comment.