Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix deprecation warnings #1382

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions flash_attn/fused_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ def scaled_upper_triang_masked_softmax(inputs, _, scale):
# Reshaping input to 3D tensor (attn_batches, sq, sk)
inputs = inputs.view(-1, sq, sk)
args = _cast_if_autocast_enabled(inputs, scale)
with torch.cuda.amp.autocast(enabled=False):
with torch.amp.autocast(device_type="cuda", enabled=False):
probs = ScaledUpperTriangMaskedSoftmax.apply(*args)
return probs.view(b, np, sq, sk)


# NOTE (mkozuki): `ScaledMaskedSoftmax` somehow doesn't work well with `torch.cuda.amp.custom_fwd`.
# NOTE (mkozuki): `ScaledMaskedSoftmax` somehow doesn't work well with `torch.amp.custom_fwd`.
# Without `cast_inputs` kwarg, somehow inputs are not cast to dtype used in the autocast context.
# So I needed to manually write two `torch.autograd.Function` inheritances.
# Fused operation which performs following three operations in sequence
Expand All @@ -88,7 +88,7 @@ def backward(ctx, output_grads):
def scaled_masked_softmax(inputs, mask, scale):
# input is 4D tensor (b, np, sq, sk)
args = _cast_if_autocast_enabled(inputs, mask, scale)
with torch.cuda.amp.autocast(enabled=False):
with torch.amp.autocast(device_type="cuda", enabled=False):
return ScaledMaskedSoftmax.apply(*args)


Expand Down
10 changes: 5 additions & 5 deletions flash_attn/ops/fused_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup

from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd
Expand All @@ -26,7 +26,7 @@

class FusedDenseFunc(torch.autograd.Function):
@staticmethod
@custom_fwd
@custom_fwd(device_type="cuda")
def forward(
ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True
):
Expand Down Expand Up @@ -67,7 +67,7 @@ def forward(
return output if not return_residual else (output, x)

@staticmethod
@custom_bwd
@custom_bwd(device_type="cuda")
def backward(ctx, grad_output, *args):
grad_output = grad_output.contiguous()
if ctx.return_residual:
Expand Down Expand Up @@ -248,7 +248,7 @@ def forward(self, x):

class FusedMLPFunc(torch.autograd.Function):
@staticmethod
@custom_fwd
@custom_fwd(device_type="cuda")
def forward(
ctx,
x,
Expand Down Expand Up @@ -345,7 +345,7 @@ def forward(
return output2 if not return_residual else (output2, x)

@staticmethod
@custom_bwd
@custom_bwd(device_type="cuda")
def backward(ctx, grad_output, *args):
grad_output = grad_output.contiguous()
checkpoint_lvl = ctx.checkpoint_lvl
Expand Down
6 changes: 3 additions & 3 deletions flash_attn/ops/triton/layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_fwd, custom_bwd
from torch.amp import custom_fwd, custom_bwd

import triton
import triton.language as tl
Expand Down Expand Up @@ -981,7 +981,7 @@ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):

class LayerNormLinearFn(torch.autograd.Function):
@staticmethod
@custom_fwd
@custom_fwd(device_type="cuda")
def forward(
ctx,
x,
Expand Down Expand Up @@ -1040,7 +1040,7 @@ def forward(
return out if not prenorm else (out, residual_out.reshape(x_shape_og))

@staticmethod
@custom_bwd
@custom_bwd(device_type="cuda")
def backward(ctx, dout, *args):
x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
dout = dout.reshape(-1, dout.shape[-1])
Expand Down
6 changes: 3 additions & 3 deletions flash_attn/ops/triton/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.amp import custom_bwd, custom_fwd

from flash_attn.ops.activations import sqrelu_bwd, sqrelu_fwd
from flash_attn.ops.triton.linear import triton_dgrad_act, triton_linear_act


class FusedDenseSqreluDenseFunc(torch.autograd.Function):
@staticmethod
@custom_fwd
@custom_fwd(device_type="cuda")
def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0):
"""checkpoint_lvl:
0: no recomputation in the bwd
Expand Down Expand Up @@ -62,7 +62,7 @@ def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0):
return output2.reshape(*batch_shape, output2.shape[-1])

@staticmethod
@custom_bwd
@custom_bwd(device_type="cuda")
def backward(ctx, grad_output):
grad_output = grad_output.contiguous()
checkpoint_lvl = ctx.checkpoint_lvl
Expand Down