Skip to content

Commit

Permalink
Add TorchLogsumexpVisitor
Browse files Browse the repository at this point in the history
  • Loading branch information
kit1980 committed Jan 6, 2025
1 parent 4ff3caf commit d70fc51
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 2 deletions.
11 changes: 11 additions & 0 deletions tests/fixtures/misc/checker/logsumexp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import torch
a = torch.randn(5)
b = torch.randn(5)

# logsumexp
y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True))
y = torch.log(torch.sum(torch.exp(2.5 + x), 1))

# not logsumexp
y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True) + 2.5)
y = torch.log(torch.sum(torch.exp(x) + 2.5, 1))
2 changes: 2 additions & 0 deletions tests/fixtures/misc/checker/logsumexp.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
6:5 TOR108 Use numerically stabilized `torch.logsumexp`.
7:5 TOR108 Use numerically stabilized `torch.logsumexp`.
1 change: 1 addition & 0 deletions tests/test_torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def pytest_generate_tests(metafunc):
"TOR105",
"TOR106",
"TOR107",
"TOR108",
},
),
(None, set(GET_ALL_ERROR_CODES()) - exclude_set),
Expand Down
2 changes: 2 additions & 0 deletions torchfix/torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
TorchDeprecatedSymbolsVisitor,
TorchExpm1Visitor,
TorchLog1pVisitor,
TorchLogsumexpVisitor,
TorchNonPublicAliasVisitor,
TorchReentrantCheckpointVisitor,
TorchRequireGradVisitor,
Expand All @@ -32,6 +33,7 @@
TorchDeprecatedSymbolsVisitor,
TorchExpm1Visitor,
TorchLog1pVisitor,
TorchLogsumexpVisitor,
TorchNonPublicAliasVisitor,
TorchRequireGradVisitor,
TorchReentrantCheckpointVisitor,
Expand Down
2 changes: 2 additions & 0 deletions torchfix/visitors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .misc import (
TorchExpm1Visitor,
TorchLog1pVisitor,
TorchLogsumexpVisitor,
TorchReentrantCheckpointVisitor,
TorchRequireGradVisitor,
)
Expand All @@ -19,6 +20,7 @@
"TorchDeprecatedSymbolsVisitor",
"TorchExpm1Visitor",
"TorchLog1pVisitor",
"TorchLogsumexpVisitor",
"TorchNonPublicAliasVisitor",
"TorchReentrantCheckpointVisitor",
"TorchRequireGradVisitor",
Expand Down
40 changes: 38 additions & 2 deletions torchfix/visitors/misc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ class TorchLog1pVisitor(TorchVisitor):

def visit_Call(self, node):
if self.get_qualified_name_for_call(node) == "torch.log":

if m.matches(
node,
m.Call(
Expand All @@ -114,7 +113,6 @@ def visit_Call(self, node):
],
),
):

self.add_violation(
node,
error_code=self.ERRORS[0].error_code,
Expand Down Expand Up @@ -154,3 +152,41 @@ def visit_BinaryOperation(self, node):
message=self.ERRORS[0].message(),
replacement=None,
)


class TorchLogsumexpVisitor(TorchVisitor):
"""
Suggest using `torch.logsumexp(x)` instead of `torch.log(torch.sum(torch.exp(x))`.
"""

ERRORS = [
TorchError(
"TOR108",
("Use numerically stabilized `torch.logsumexp`."),
)
]

def visit_Call(self, node):
if self.get_qualified_name_for_call(node) == "torch.log":
if m.matches(
node,
m.Call(
args=[
m.Arg(m.Call(args=[m.Arg(m.Call()), m.ZeroOrMore()])),
m.ZeroOrMore(),
]
),
):
if self.get_qualified_name_for_call(node.args[0].value) == "torch.sum":
if (
self.get_qualified_name_for_call(
node.args[0].value.args[0].value
)
== "torch.exp"
):
self.add_violation(
node,
error_code=self.ERRORS[0].error_code,
message=self.ERRORS[0].message(),
replacement=None,
)

0 comments on commit d70fc51

Please sign in to comment.