diff --git a/dalle2_pytorch/__init__.py b/dalle2_pytorch/__init__.py index 53ebb340..ec5b567b 100644 --- a/dalle2_pytorch/__init__.py +++ b/dalle2_pytorch/__init__.py @@ -1,3 +1,10 @@ +import torch +from packaging import version + +if version.parse(torch.__version__) >= version.parse('2.0.0'): + from einops._torch_specific import allow_ops_in_compiled_graph + allow_ops_in_compiled_graph() + from dalle2_pytorch.version import __version__ from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter, OpenClipAdapter diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index e4f2ad49..654b4b95 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.14.0' +__version__ = '1.14.2' diff --git a/setup.py b/setup.py index cf94e635..a8caad01 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ 'clip-anytorch>=2.5.2', 'coca-pytorch>=0.0.5', 'ema-pytorch>=0.0.7', - 'einops>=0.6', + 'einops>=0.6.1', 'embedding-reader', 'kornia>=0.5.4', 'numpy',