Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor to use add_violation #34

Merged
merged 1 commit into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 4 additions & 29 deletions torchfix/visitors/deprecated_symbols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from ...common import (
TorchVisitor,
call_with_name_changes,
LintViolation,
)

from .range import call_replacement_range
Expand Down Expand Up @@ -65,9 +64,6 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
if isinstance(node.names, Sequence):
for name in node.names:
qualified_name = f"{module}.{name.name.value}"
position_metadata = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)
if qualified_name in self.deprecated_config:
if self.deprecated_config[qualified_name]["remove_pr"] is None:
error_code = self.ERROR_CODE[3]
Expand All @@ -76,32 +72,18 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
error_code = self.ERROR_CODE[2]
message = f"Import of removed function {qualified_name}"

reference = self.deprecated_config[qualified_name].get(
"reference"
)
reference = self.deprecated_config[qualified_name].get("reference")
if reference is not None:
message = f"{message}: {reference}"

self.violations.append(
LintViolation(
error_code=error_code,
message=message,
line=position_metadata.start.line,
column=position_metadata.start.column,
node=node,
replacement=None,
)
)
self.add_violation(node, error_code=error_code, message=message)

def visit_Call(self, node) -> None:
qualified_name = self.get_qualified_name_for_call(node)
if qualified_name is None:
return

if qualified_name in self.deprecated_config:
position_metadata = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)
if self.deprecated_config[qualified_name]["remove_pr"] is None:
error_code = self.ERROR_CODE[1]
message = f"Use of deprecated function {qualified_name}"
Expand All @@ -114,15 +96,8 @@ def visit_Call(self, node) -> None:
if reference is not None:
message = f"{message}: {reference}"

self.violations.append(
LintViolation(
error_code=error_code,
message=message,
line=position_metadata.start.line,
column=position_metadata.start.column,
node=node,
replacement=replacement,
)
self.add_violation(
node, error_code=error_code, message=message, replacement=replacement
)


Expand Down
2 changes: 1 addition & 1 deletion torchfix/visitors/deprecated_symbols/chain_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def call_replacement_chain_matmul(node: cst.Call) -> cst.CSTNode:
replacement_args = [matrices_arg]
else:
replacement_args = [matrices_arg, out_arg]
module_name = get_module_name(node, 'torch')
module_name = get_module_name(node, "torch")
replacement = cst.parse_expression(f"{module_name}.linalg.multi_dot(args)")
replacement = replacement.with_changes(args=replacement_args)

Expand Down
2 changes: 1 addition & 1 deletion torchfix/visitors/deprecated_symbols/cholesky.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import libcst as cst
from ...common import (TorchVisitor, get_module_name)
from ...common import TorchVisitor, get_module_name


def call_replacement_cholesky(node: cst.Call) -> cst.CSTNode:
Expand Down
2 changes: 1 addition & 1 deletion torchfix/visitors/deprecated_symbols/qr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import libcst as cst
from typing import Optional
from ...common import (TorchVisitor, get_module_name)
from ...common import TorchVisitor, get_module_name


def call_replacement_qr(node: cst.Call) -> Optional[cst.CSTNode]:
Expand Down
1 change: 1 addition & 0 deletions torchfix/visitors/deprecated_symbols/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def call_replacement_range(node: cst.Call) -> Optional[cst.Call]:
"""Replace `range` with `arange`.
Add `step` to the `end` argument as `arange` has the interval `[start, end)`.
"""

# `torch.range` documented signature is not a valid Python signature,
# so it's hard to generalize this.
def _get_range_args(node: cst.Call) -> Tuple[cst.Arg, Optional[cst.Arg]]:
Expand Down
18 changes: 2 additions & 16 deletions torchfix/visitors/internal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import libcst as cst
from ...common import TorchVisitor, LintViolation
from ...common import TorchVisitor


class TorchScopedLibraryVisitor(TorchVisitor):
Expand All @@ -17,17 +16,4 @@ class TorchScopedLibraryVisitor(TorchVisitor):
def visit_Call(self, node):
qualified_name = self.get_qualified_name_for_call(node)
if qualified_name == "torch.library.Library":
position_metadata = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)

self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=self.MESSAGE,
line=position_metadata.start.line,
column=position_metadata.start.column,
node=node,
replacement=None,
)
)
self.add_violation(node, error_code=self.ERROR_CODE, message=self.MESSAGE)
40 changes: 11 additions & 29 deletions torchfix/visitors/misc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import libcst.matchers as m


from ...common import TorchVisitor, LintViolation
from ...common import TorchVisitor


class TorchRequireGradVisitor(TorchVisitor):
Expand Down Expand Up @@ -31,20 +31,11 @@ def visit_Assign(self, node):
replacement = node.with_deep_changes(
old_node=node.targets[0].target.attr, value="requires_grad"
)

position_metadata = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)

self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=self.MESSAGE,
line=position_metadata.start.line,
column=position_metadata.start.column,
node=node,
replacement=replacement,
)
self.add_violation(
node,
error_code=self.ERROR_CODE,
message=self.MESSAGE,
replacement=replacement,
)


Expand All @@ -65,25 +56,16 @@ def visit_Call(self, node):
if qualified_name == "torch.utils.checkpoint.checkpoint":
use_reentrant_arg = self.get_specific_arg(node, "use_reentrant", -1)
if use_reentrant_arg is None:
position_metadata = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)

# This codemod maybe unsafe correctness-wise
# if reentrant behavior is actually needed,
# so the changes need to be verified/tested.
use_reentrant_arg = cst.ensure_type(
cst.parse_expression("f(use_reentrant=False)"), cst.Call
).args[0]
replacement = node.with_changes(args=node.args + (use_reentrant_arg,))

self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=self.MESSAGE,
line=position_metadata.start.line,
column=position_metadata.start.column,
node=node,
replacement=replacement,
)
self.add_violation(
node,
error_code=self.ERROR_CODE,
message=self.MESSAGE,
replacement=replacement,
)
18 changes: 3 additions & 15 deletions torchfix/visitors/performance/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import libcst as cst
import libcst.matchers as m


from ...common import TorchVisitor, LintViolation
from ...common import TorchVisitor


class TorchSynchronizedDataLoaderVisitor(TorchVisitor):
Expand All @@ -25,17 +24,6 @@ def visit_Call(self, node):
if num_workers_arg is None or m.matches(
num_workers_arg.value, m.Integer(value="0")
):
position_metadata = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)

self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=self.MESSAGE,
line=position_metadata.start.line,
column=position_metadata.start.column,
node=node,
replacement=None,
)
self.add_violation(
node, error_code=self.ERROR_CODE, message=self.MESSAGE
)
21 changes: 6 additions & 15 deletions torchfix/visitors/security/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import libcst as cst
from ...common import TorchVisitor, LintViolation
from ...common import TorchVisitor


class TorchUnsafeLoadVisitor(TorchVisitor):
Expand All @@ -21,10 +21,6 @@ def visit_Call(self, node):
if qualified_name == "torch.load":
weights_only_arg = self.get_specific_arg(node, "weights_only", -1)
if weights_only_arg is None:
position_metadata = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)

# Add `weights_only=True` if there is no `pickle_module`.
# (do not add `weights_only=False` with `pickle_module`, as it
# needs to be an explicit choice).
Expand All @@ -42,14 +38,9 @@ def visit_Call(self, node):
replacement = node.with_changes(
args=node.args + (weights_only_arg,)
)

self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=self.MESSAGE,
line=position_metadata.start.line,
column=position_metadata.start.column,
node=node,
replacement=replacement,
)
self.add_violation(
node,
error_code=self.ERROR_CODE,
message=self.MESSAGE,
replacement=replacement,
)
23 changes: 8 additions & 15 deletions torchfix/visitors/vision/models_import.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import libcst as cst

from ...common import LintViolation, TorchVisitor
from ...common import TorchVisitor


class TorchVisionModelsImportVisitor(TorchVisitor):
Expand All @@ -24,23 +24,16 @@ def visit_Import(self, node: cst.Import) -> None:
and isinstance(imported_item.asname.name, cst.Name)
and imported_item.asname.name.value == "models"
):
position = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)
# Replace only if the import statement has no other names
if len(node.names) == 1:
replacement = cst.ImportFrom(
module=cst.Name("torchvision"),
names=[cst.ImportAlias(name=cst.Name("models"))],
)
self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=self.MESSAGE,
line=position.start.line,
column=position.start.column,
node=node,
replacement=replacement
module=cst.Name("torchvision"),
names=[cst.ImportAlias(name=cst.Name("models"))],
)
self.add_violation(
node,
error_code=self.ERROR_CODE,
message=self.MESSAGE,
replacement=replacement,
)
break
19 changes: 6 additions & 13 deletions torchfix/visitors/vision/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import libcst as cst
from libcst.codemod.visitors import ImportItem

from ...common import LintViolation, TorchVisitor
from ...common import TorchVisitor


class TorchVisionDeprecatedPretrainedVisitor(TorchVisitor):
Expand Down Expand Up @@ -248,16 +248,9 @@ def _new_arg_and_import(
node.with_changes(args=replacement_args) if has_replacement else None
)
if message is not None:
position_metadata = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)
self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=message,
line=position_metadata.start.line,
column=position_metadata.start.column,
node=node,
replacement=replacement,
)
self.add_violation(
node,
error_code=self.ERROR_CODE,
message=message,
replacement=replacement,
)
25 changes: 7 additions & 18 deletions torchfix/visitors/vision/to_tensor.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,21 @@
from collections.abc import Sequence
import libcst as cst

from ...common import LintViolation, TorchVisitor
from ...common import TorchVisitor


class TorchVisionDeprecatedToTensorVisitor(TorchVisitor):
ERROR_CODE = "TOR202"
MESSAGE = (
"The transform `v2.ToTensor()` is deprecated and will be removed "
"in a future release. Instead, please use "
"`v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`." # noqa: E501
)

def _maybe_add_violation(self, qualified_name, node):
if qualified_name != "torchvision.transforms.v2.ToTensor":
return
position = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)
self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=(
"The transform `v2.ToTensor()` is deprecated and will be removed "
"in a future release. Instead, please use "
"`v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`." # noqa: E501
),
line=position.start.line,
column=position.start.column,
node=node,
replacement=None,
)
)
self.add_violation(node, error_code=self.ERROR_CODE, message=self.MESSAGE)

def visit_ImportFrom(self, node):
module_path = cst.helpers.get_absolute_module_from_package_for_import(
Expand Down
Loading