diff --git a/tests/fixtures/misc/checker/logsumexp.py b/tests/fixtures/misc/checker/logsumexp.py new file mode 100644 index 0000000..c24374f --- /dev/null +++ b/tests/fixtures/misc/checker/logsumexp.py @@ -0,0 +1,11 @@ +import torch +a = torch.randn(5) +b = torch.randn(5) + +# logsumexp +y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True)) +y = torch.log(torch.sum(torch.exp(2.5 + x), 1)) + +# not logsumexp +y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True) + 2.5) +y = torch.log(torch.sum(torch.exp(x) + 2.5, 1)) diff --git a/tests/fixtures/misc/checker/logsumexp.txt b/tests/fixtures/misc/checker/logsumexp.txt new file mode 100644 index 0000000..4a4f5ec --- /dev/null +++ b/tests/fixtures/misc/checker/logsumexp.txt @@ -0,0 +1,2 @@ +6:5 TOR108 Use numerically stabilized `torch.logsumexp`. +7:5 TOR108 Use numerically stabilized `torch.logsumexp`. diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index a31bc1b..5baa12a 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -47,6 +47,7 @@ def pytest_generate_tests(metafunc): "TOR105", "TOR106", "TOR107", + "TOR108", }, ), (None, set(GET_ALL_ERROR_CODES()) - exclude_set), diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 1cb3e69..dae1a24 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -11,6 +11,7 @@ TorchDeprecatedSymbolsVisitor, TorchExpm1Visitor, TorchLog1pVisitor, + TorchLogsumexpVisitor, TorchNonPublicAliasVisitor, TorchReentrantCheckpointVisitor, TorchRequireGradVisitor, @@ -32,6 +33,7 @@ TorchDeprecatedSymbolsVisitor, TorchExpm1Visitor, TorchLog1pVisitor, + TorchLogsumexpVisitor, TorchNonPublicAliasVisitor, TorchRequireGradVisitor, TorchReentrantCheckpointVisitor, diff --git a/torchfix/visitors/__init__.py b/torchfix/visitors/__init__.py index 8e56b4a..5317d1b 100644 --- a/torchfix/visitors/__init__.py +++ b/torchfix/visitors/__init__.py @@ -3,6 +3,7 @@ from .misc import ( TorchExpm1Visitor, TorchLog1pVisitor, + TorchLogsumexpVisitor, TorchReentrantCheckpointVisitor, TorchRequireGradVisitor, ) @@ -19,6 +20,7 @@ "TorchDeprecatedSymbolsVisitor", "TorchExpm1Visitor", "TorchLog1pVisitor", + "TorchLogsumexpVisitor", "TorchNonPublicAliasVisitor", "TorchReentrantCheckpointVisitor", "TorchRequireGradVisitor", diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index 348612c..e77de4f 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -96,7 +96,6 @@ class TorchLog1pVisitor(TorchVisitor): def visit_Call(self, node): if self.get_qualified_name_for_call(node) == "torch.log": - if m.matches( node, m.Call( @@ -114,7 +113,6 @@ def visit_Call(self, node): ], ), ): - self.add_violation( node, error_code=self.ERRORS[0].error_code, @@ -154,3 +152,41 @@ def visit_BinaryOperation(self, node): message=self.ERRORS[0].message(), replacement=None, ) + + +class TorchLogsumexpVisitor(TorchVisitor): + """ + Suggest using `torch.logsumexp(x)` instead of `torch.log(torch.sum(torch.exp(x))`. + """ + + ERRORS = [ + TorchError( + "TOR108", + ("Use numerically stabilized `torch.logsumexp`."), + ) + ] + + def visit_Call(self, node): + if self.get_qualified_name_for_call(node) == "torch.log": + if m.matches( + node, + m.Call( + args=[ + m.Arg(m.Call(args=[m.Arg(m.Call()), m.ZeroOrMore()])), + m.ZeroOrMore(), + ] + ), + ): + if self.get_qualified_name_for_call(node.args[0].value) == "torch.sum": + if ( + self.get_qualified_name_for_call( + node.args[0].value.args[0].value + ) + == "torch.exp" + ): + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + )