diff --git a/tests/fixtures/deprecated_symbols/checker/deprecated_qr.txt b/tests/fixtures/deprecated_symbols/checker/deprecated_qr.txt index 9768b1f..3da6ee5 100644 --- a/tests/fixtures/deprecated_symbols/checker/deprecated_qr.txt +++ b/tests/fixtures/deprecated_symbols/checker/deprecated_qr.txt @@ -1,4 +1,6 @@ 2:7 TOR101 Use of deprecated function torch.qr 6:7 TOR101 Use of deprecated function torch.qr +9:1 TOR103 Import of deprecated function torch.qr 10:7 TOR101 Use of deprecated function torch.qr +13:1 TOR103 Import of deprecated function torch.qr 16:7 TOR101 Use of deprecated function torch.qr diff --git a/tests/fixtures/deprecated_symbols/checker/functorch.py b/tests/fixtures/deprecated_symbols/checker/functorch.py index f072cbb..044240f 100644 --- a/tests/fixtures/deprecated_symbols/checker/functorch.py +++ b/tests/fixtures/deprecated_symbols/checker/functorch.py @@ -2,3 +2,5 @@ # Check that we get only one warning for the line functorch.vmap(tdmodule, (None, 0))(td, params) + +from functorch import vmap, jacrev diff --git a/tests/fixtures/deprecated_symbols/checker/functorch.txt b/tests/fixtures/deprecated_symbols/checker/functorch.txt index 336c7ae..e4f802c 100644 --- a/tests/fixtures/deprecated_symbols/checker/functorch.txt +++ b/tests/fixtures/deprecated_symbols/checker/functorch.txt @@ -1 +1,3 @@ 4:1 TOR101 Use of deprecated function functorch.vmap +6:1 TOR103 Import of deprecated function functorch.vmap +6:1 TOR103 Import of deprecated function functorch.jacrev diff --git a/tests/fixtures/deprecated_symbols/checker/removed_symeig.txt b/tests/fixtures/deprecated_symbols/checker/removed_symeig.txt index 06b13e7..7610c28 100644 --- a/tests/fixtures/deprecated_symbols/checker/removed_symeig.txt +++ b/tests/fixtures/deprecated_symbols/checker/removed_symeig.txt @@ -1,2 +1,3 @@ +2:1 TOR004 Import of removed function torch.symeig 4:8 TOR001 Use of removed function torch.symeig 5:8 TOR001 Use of removed function torch.symeig diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 7b9a051..d699ea7 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -75,7 +75,7 @@ def test_parse_error_code_str(): ("ALL,TOR102", GET_ALL_ERROR_CODES()), ("TOR102", {"TOR102"}), ("TOR102,TOR101", {"TOR102", "TOR101"}), - ("TOR1,TOR102", {"TOR102", "TOR101", "TOR104", "TOR105"}), + ("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}), (None, GET_ALL_ERROR_CODES() - exclude_set), ] for case, expected in cases: diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index 93a9082..beb0657 100644 --- a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -16,7 +16,7 @@ class TorchDeprecatedSymbolsVisitor(TorchVisitor): - ERROR_CODE = ["TOR001", "TOR101"] + ERROR_CODE = ["TOR001", "TOR101", "TOR004", "TOR103"] def __init__(self, deprecated_config_path=None): def read_deprecated_config(path=None): @@ -57,7 +57,43 @@ def _call_replacement( self.needed_imports.update(imports) return replacement - def visit_Call(self, node): + def visit_ImportFrom(self, node: cst.ImportFrom) -> None: + if node.module is None: + return + + module = cst.helpers.get_full_name_for_node(node.module) + if isinstance(node.names, Sequence): + for name in node.names: + qualified_name = f"{module}.{name.name.value}" + position_metadata = self.get_metadata( + cst.metadata.WhitespaceInclusivePositionProvider, node + ) + if qualified_name in self.deprecated_config: + if self.deprecated_config[qualified_name]["remove_pr"] is None: + error_code = self.ERROR_CODE[3] + message = f"Import of deprecated function {qualified_name}" + else: + error_code = self.ERROR_CODE[2] + message = f"Import of removed function {qualified_name}" + + reference = self.deprecated_config[qualified_name].get( + "reference" + ) + if reference is not None: + message = f"{message}: {reference}" + + self.violations.append( + LintViolation( + error_code=error_code, + message=message, + line=position_metadata.start.line, + column=position_metadata.start.column, + node=node, + replacement=None, + ) + ) + + def visit_Call(self, node) -> None: qualified_name = self.get_qualified_name_for_call(node) if qualified_name is None: return