Skip to content

Commit

Permalink
Add rules for deprecated AMP APIs (#87)
Browse files Browse the repository at this point in the history
Add codemods for `torch.cuda.amp.autocast` and `torch.cpu.amp.autocast`,
and checkers for `torch.cuda.amp.custom_fwd` and
`torch.cuda.amp.custom_bwd`.
  • Loading branch information
kit1980 authored Dec 13, 2024
1 parent 86186f4 commit 4ff3caf
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 6 deletions.
10 changes: 10 additions & 0 deletions tests/fixtures/deprecated_symbols/checker/amp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch

torch.cuda.amp.autocast()
torch.cuda.amp.custom_fwd()
torch.cuda.amp.custom_bwd()

dtype = torch.float32
maybe_autocast = torch.cpu.amp.autocast()
maybe_autocast = torch.cpu.amp.autocast(dtype=torch.bfloat16)
maybe_autocast = torch.cpu.amp.autocast(dtype=dtype)
6 changes: 6 additions & 0 deletions tests/fixtures/deprecated_symbols/checker/amp.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
3:1 TOR101 Use of deprecated function torch.cuda.amp.autocast
4:1 TOR101 Use of deprecated function torch.cuda.amp.custom_fwd
5:1 TOR101 Use of deprecated function torch.cuda.amp.custom_bwd
8:18 TOR101 Use of deprecated function torch.cpu.amp.autocast
9:18 TOR101 Use of deprecated function torch.cpu.amp.autocast
10:18 TOR101 Use of deprecated function torch.cpu.amp.autocast
11 changes: 11 additions & 0 deletions tests/fixtures/deprecated_symbols/codemod/amp.in.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import torch

dtype = torch.float32

maybe_autocast = torch.cuda.amp.autocast()
maybe_autocast = torch.cuda.amp.autocast(dtype=torch.bfloat16)
maybe_autocast = torch.cuda.amp.autocast(dtype=dtype)

maybe_autocast = torch.cpu.amp.autocast()
maybe_autocast = torch.cpu.amp.autocast(dtype=torch.bfloat16)
maybe_autocast = torch.cpu.amp.autocast(dtype=dtype)
11 changes: 11 additions & 0 deletions tests/fixtures/deprecated_symbols/codemod/amp.out.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import torch

dtype = torch.float32

maybe_autocast = torch.amp.autocast("cuda")
maybe_autocast = torch.amp.autocast("cuda", dtype=torch.bfloat16)
maybe_autocast = torch.amp.autocast("cuda", dtype=dtype)

maybe_autocast = torch.amp.autocast("cpu")
maybe_autocast = torch.amp.autocast("cpu", dtype=torch.bfloat16)
maybe_autocast = torch.amp.autocast("cpu", dtype=dtype)
16 changes: 16 additions & 0 deletions torchfix/deprecated_symbols.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,22 @@
remove_pr:
reference: https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel

- name: torch.cuda.amp.autocast
deprecate_pr: TBA
remove_pr:

- name: torch.cuda.amp.custom_fwd
deprecate_pr: TBA
remove_pr:

- name: torch.cuda.amp.custom_bwd
deprecate_pr: TBA
remove_pr:

- name: torch.cpu.amp.autocast
deprecate_pr: TBA
remove_pr:

# functorch
- name: functorch.vmap
deprecate_pr: TBA
Expand Down
17 changes: 11 additions & 6 deletions torchfix/visitors/deprecated_symbols/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import libcst as cst
import pkgutil
from typing import List, Optional

import libcst as cst
import yaml
from typing import Optional, List

from ...common import (
TorchVisitor,
TorchError,
call_with_name_changes,
check_old_names_in_import_from,
TorchError,
TorchVisitor,
)

from .range import call_replacement_range
from .cholesky import call_replacement_cholesky
from .amp import call_replacement_cpu_amp_autocast, call_replacement_cuda_amp_autocast
from .chain_matmul import call_replacement_chain_matmul
from .cholesky import call_replacement_cholesky
from .qr import call_replacement_qr

from .range import call_replacement_range


class TorchDeprecatedSymbolsVisitor(TorchVisitor):
ERRORS: List[TorchError] = [
Expand Down Expand Up @@ -49,6 +52,8 @@ def _call_replacement(
"torch.range": call_replacement_range,
"torch.chain_matmul": call_replacement_chain_matmul,
"torch.qr": call_replacement_qr,
"torch.cuda.amp.autocast": call_replacement_cuda_amp_autocast,
"torch.cpu.amp.autocast": call_replacement_cpu_amp_autocast,
}
replacement = None

Expand Down
26 changes: 26 additions & 0 deletions torchfix/visitors/deprecated_symbols/amp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import libcst as cst

from ...common import get_module_name


def call_replacement_cpu_amp_autocast(node: cst.Call) -> cst.CSTNode:
return _call_replacement_amp(node, "cpu")


def call_replacement_cuda_amp_autocast(node: cst.Call) -> cst.CSTNode:
return _call_replacement_amp(node, "cuda")


def _call_replacement_amp(node: cst.Call, device: str) -> cst.CSTNode:
"""
Replace `torch.cuda.amp.autocast()` with `torch.amp.autocast("cuda")` and
Replace `torch.cpu.amp.autocast()` with `torch.amp.autocast("cpu")`.
"""
device_arg = cst.ensure_type(cst.parse_expression(f'f("{device}")'), cst.Call).args[
0
]

module_name = get_module_name(node, "torch")
replacement = cst.parse_expression(f"{module_name}.amp.autocast(args)")
replacement = replacement.with_changes(args=(device_arg, *node.args))
return replacement

0 comments on commit 4ff3caf

Please sign in to comment.