Skip to content

Commit

Permalink
Remove more Torch version comparisons
Browse files Browse the repository at this point in the history
Follows up on ed09971
  • Loading branch information
akx committed Sep 5, 2024
1 parent 1659a1c commit 60ded62
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 69 deletions.
14 changes: 6 additions & 8 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,12 @@ def str2bool(v):
default=False, # TODO: later default to True
help="log to wandb",
)
if version.parse(torch.__version__) >= version.parse("2.0.0"):
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help="single checkpoint file to resume from",
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help="single checkpoint file to resume from",
)
default_args = default_trainer_args()
for key in default_args:
parser.add_argument("--" + key, default=default_args[key])
Expand Down Expand Up @@ -618,7 +617,6 @@ def init_wandb(save_dir, opt, config, group_name, name_str):

# move before model init, in case a torch.compile(...) is called somewhere
if opt.enable_tf32:
# pt_version = version.parse(torch.__version__)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
print(f"Enabling TF32 for PyTorch {torch.__version__}")
Expand Down
4 changes: 1 addition & 3 deletions sgm/models/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import torch
import torch.nn as nn
from einops import rearrange
from packaging import version

from ..modules.autoencoding.regularizers import AbstractRegularizer
from ..modules.ema import LitEma
Expand Down Expand Up @@ -43,8 +42,7 @@ def __init__(
self.model_ema = LitEma(self, decay=ema_decay)
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

if version.parse(torch.__version__) >= version.parse("2.0.0"):
self.automatic_optimization = False
self.automatic_optimization = False # pytorch lightning

def apply_ckpt(self, ckpt: Union[None, str, dict]):
if ckpt is None:
Expand Down
66 changes: 25 additions & 41 deletions sgm/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,42 +9,31 @@
from packaging import version
from torch import nn
from torch.utils.checkpoint import checkpoint
from torch.backends.cuda import SDPBackend, sdp_kernel

logpy = logging.getLogger(__name__)

if version.parse(torch.__version__) >= version.parse("2.0.0"):
SDP_IS_AVAILABLE = True
from torch.backends.cuda import SDPBackend, sdp_kernel

BACKEND_MAP = {
SDPBackend.MATH: {
"enable_math": True,
"enable_flash": False,
"enable_mem_efficient": False,
},
SDPBackend.FLASH_ATTENTION: {
"enable_math": False,
"enable_flash": True,
"enable_mem_efficient": False,
},
SDPBackend.EFFICIENT_ATTENTION: {
"enable_math": False,
"enable_flash": False,
"enable_mem_efficient": True,
},
None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
}
else:
from contextlib import nullcontext

SDP_IS_AVAILABLE = False
sdp_kernel = nullcontext
BACKEND_MAP = {}
logpy.warn(
f"No SDP backend available, likely because you are running in pytorch "
f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
f"You might want to consider upgrading."
)
SDP_IS_AVAILABLE = True


BACKEND_MAP = {
SDPBackend.MATH: {
"enable_math": True,
"enable_flash": False,
"enable_mem_efficient": False,
},
SDPBackend.FLASH_ATTENTION: {
"enable_math": False,
"enable_flash": True,
"enable_mem_efficient": False,
},
SDPBackend.EFFICIENT_ATTENTION: {
"enable_math": False,
"enable_flash": False,
"enable_mem_efficient": True,
},
None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
}

try:
import xformers
Expand Down Expand Up @@ -476,10 +465,8 @@ def __init__(
assert attn_mode in self.ATTENTION_MODES
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
logpy.warn(
f"Attention mode '{attn_mode}' is not available. Falling "
f"back to native attention. This is not a problem in "
f"Pytorch >= 2.0. FYI, you are running with PyTorch "
f"version {torch.__version__}."
f"Attention mode '{attn_mode}' is not available. "
f"Falling back to native attention."
)
attn_mode = "softmax"
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
Expand All @@ -495,10 +482,7 @@ def __init__(
logpy.info("Falling back to xformers efficient attention.")
attn_mode = "softmax-xformers"
attn_cls = self.ATTENTION_MODES[attn_mode]
if version.parse(torch.__version__) >= version.parse("2.0.0"):
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
else:
assert sdp_backend is None
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(
query_dim=dim,
Expand Down
6 changes: 2 additions & 4 deletions sgm/modules/autoencoding/temporal_ae.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Callable, Iterable, Union

import torch
Expand Down Expand Up @@ -260,10 +261,7 @@ def make_time_attn(
f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
)
if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
print(
f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
)
warnings.warn(f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention.")
attn_type = "vanilla"

if attn_type == "vanilla":
Expand Down
7 changes: 2 additions & 5 deletions sgm/modules/diffusionmodules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,9 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
"linear",
"none",
], f"attn_type {attn_type} unknown"
if (
version.parse(torch.__version__) < version.parse("2.0.0")
and attn_type != "none"
):
if attn_type != "none":
assert XFORMERS_IS_AVAILABLE, (
f"We do not support vanilla attention in {torch.__version__} anymore, "
f"We do not support vanilla attention anymore, "
f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
)
attn_type = "vanilla-xformers"
Expand Down
12 changes: 4 additions & 8 deletions sgm/modules/diffusionmodules/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
import torch
import torch.nn as nn
from packaging import version

OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"


class IdentityWrapper(nn.Module):
def __init__(self, diffusion_model, compile_model: bool = False):
super().__init__()
compile = (
torch.compile
if (version.parse(torch.__version__) >= version.parse("2.0.0"))
and compile_model
else lambda x: x
)
self.diffusion_model = compile(diffusion_model)
if compile_model:
self.diffusion_model = torch.compile(diffusion_model)
else:
self.diffusion_model = diffusion_model

def forward(self, *args, **kwargs):
return self.diffusion_model(*args, **kwargs)
Expand Down

0 comments on commit 60ded62

Please sign in to comment.