Skip to content

Commit

Permalink
Add rules for importing deprecated and removed symbols (#32)
Browse files Browse the repository at this point in the history
* Add rules for importing deprecated symbols

* Add test for functorch

* Appease mypy
  • Loading branch information
kit1980 authored Mar 15, 2024
1 parent 42bb4a8 commit e782670
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 3 deletions.
2 changes: 2 additions & 0 deletions tests/fixtures/deprecated_symbols/checker/deprecated_qr.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
2:7 TOR101 Use of deprecated function torch.qr
6:7 TOR101 Use of deprecated function torch.qr
9:1 TOR103 Import of deprecated function torch.qr
10:7 TOR101 Use of deprecated function torch.qr
13:1 TOR103 Import of deprecated function torch.qr
16:7 TOR101 Use of deprecated function torch.qr
2 changes: 2 additions & 0 deletions tests/fixtures/deprecated_symbols/checker/functorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@

# Check that we get only one warning for the line
functorch.vmap(tdmodule, (None, 0))(td, params)

from functorch import vmap, jacrev
2 changes: 2 additions & 0 deletions tests/fixtures/deprecated_symbols/checker/functorch.txt
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
4:1 TOR101 Use of deprecated function functorch.vmap
6:1 TOR103 Import of deprecated function functorch.vmap
6:1 TOR103 Import of deprecated function functorch.jacrev
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
2:1 TOR004 Import of removed function torch.symeig
4:8 TOR001 Use of removed function torch.symeig
5:8 TOR001 Use of removed function torch.symeig
2 changes: 1 addition & 1 deletion tests/test_torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_parse_error_code_str():
("ALL,TOR102", GET_ALL_ERROR_CODES()),
("TOR102", {"TOR102"}),
("TOR102,TOR101", {"TOR102", "TOR101"}),
("TOR1,TOR102", {"TOR102", "TOR101", "TOR104", "TOR105"}),
("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}),
(None, GET_ALL_ERROR_CODES() - exclude_set),
]
for case, expected in cases:
Expand Down
40 changes: 38 additions & 2 deletions torchfix/visitors/deprecated_symbols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


class TorchDeprecatedSymbolsVisitor(TorchVisitor):
ERROR_CODE = ["TOR001", "TOR101"]
ERROR_CODE = ["TOR001", "TOR101", "TOR004", "TOR103"]

def __init__(self, deprecated_config_path=None):
def read_deprecated_config(path=None):
Expand Down Expand Up @@ -57,7 +57,43 @@ def _call_replacement(
self.needed_imports.update(imports)
return replacement

def visit_Call(self, node):
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
if node.module is None:
return

module = cst.helpers.get_full_name_for_node(node.module)
if isinstance(node.names, Sequence):
for name in node.names:
qualified_name = f"{module}.{name.name.value}"
position_metadata = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)
if qualified_name in self.deprecated_config:
if self.deprecated_config[qualified_name]["remove_pr"] is None:
error_code = self.ERROR_CODE[3]
message = f"Import of deprecated function {qualified_name}"
else:
error_code = self.ERROR_CODE[2]
message = f"Import of removed function {qualified_name}"

reference = self.deprecated_config[qualified_name].get(
"reference"
)
if reference is not None:
message = f"{message}: {reference}"

self.violations.append(
LintViolation(
error_code=error_code,
message=message,
line=position_metadata.start.line,
column=position_metadata.start.column,
node=node,
replacement=None,
)
)

def visit_Call(self, node) -> None:
qualified_name = self.get_qualified_name_for_call(node)
if qualified_name is None:
return
Expand Down

0 comments on commit e782670

Please sign in to comment.