diff --git a/.release b/.release index adb070518..e308a340c 100644 --- a/.release +++ b/.release @@ -1 +1 @@ -v22.4.1 +v22.5.0 diff --git a/README.md b/README.md index dc6925941..bddda2d54 100644 --- a/README.md +++ b/README.md @@ -8,9 +8,6 @@ The GUI allows you to set the training parameters and generate and run the requi - [Kohya's GUI](#kohyas-gui) - [Table of Contents](#table-of-contents) - - [Tutorials](#tutorials) - - [About SDXL training](#about-sdxl-training) - - [Tips for SDXL training](#tips-for-sdxl-training) - [🦒 Colab](#-colab) - [Installation](#installation) - [Windows](#windows) @@ -45,159 +42,19 @@ The GUI allows you to set the training parameters and generate and run the requi - [SDXL training](#sdxl-training) - [Training scripts for SDXL](#training-scripts-for-sdxl) - [Utility scripts for SDXL](#utility-scripts-for-sdxl) - - [Tips for SDXL training](#tips-for-sdxl-training-1) + - [Tips for SDXL training](#tips-for-sdxl-training) - [Format of Textual Inversion embeddings for SDXL](#format-of-textual-inversion-embeddings-for-sdxl) - [ControlNet-LLLite](#controlnet-lllite) - [Sample image generation during training](#sample-image-generation-during-training-1) - [Change History](#change-history) - - -## Tutorials - -[How to Create a LoRA Part 1: Dataset Preparation](https://www.youtube.com/watch?v=N4_-fB62Hwk): - -[![LoRA Part 1 Tutorial](https://img.youtube.com/vi/N4_-fB62Hwk/0.jpg)](https://www.youtube.com/watch?v=N4_-fB62Hwk) - -[How to Create a LoRA Part 2: Training the Model](https://www.youtube.com/watch?v=k5imq01uvUY): - -[![LoRA Part 2 Tutorial](https://img.youtube.com/vi/k5imq01uvUY/0.jpg)](https://www.youtube.com/watch?v=k5imq01uvUY) - -[**Generate Studio Quality Realistic Photos By Kohya LoRA Stable Diffusion Training - Full Tutorial**](https://youtu.be/TpuDOsuKIBo) - -[![image](https://cdn-uploads.huggingface.co/production/uploads/6345bd89fe134dfd7a0dba40/QA9woGfjeql37J9JepbrW.png)](https://youtu.be/TpuDOsuKIBo) - -[**First Ever SDXL Training With Kohya LoRA - Stable Diffusion XL Training Will Replace Older Models**](https://youtu.be/AY6DMBCIZ3A) - -[![image](https://cdn-uploads.huggingface.co/production/uploads/6345bd89fe134dfd7a0dba40/mG0CvKAzb8o29nr5ye0Br.png)](https://youtu.be/AY6DMBCIZ3A) - -[**Become A Master Of SDXL Training With Kohya SS LoRAs - Combine Power Of Automatic1111 & SDXL LoRAs**](https://youtu.be/sBFGitIvD2A) - -[![image](https://cdn-uploads.huggingface.co/production/uploads/6345bd89fe134dfd7a0dba40/rXbRquLxFaDGaGlkl-SUp.png)](https://youtu.be/sBFGitIvD2A) - -[**How To Do SDXL LoRA Training On RunPod With Kohya SS GUI Trainer & Use LoRAs With Automatic1111 UI**](https://youtu.be/-xEwaQ54DI4) - -[![image](https://cdn-uploads.huggingface.co/production/uploads/6345bd89fe134dfd7a0dba40/-BQQRjP9Maht_n4UHxgBJ.png)](https://youtu.be/-xEwaQ54DI4) - -[**How To Do SDXL LoRA Training On RunPod With Kohya SS GUI Trainer & Use LoRAs With Automatic1111 UI**](https://youtu.be/JF2P7BIUpIU?feature=shared) - -[![image](https://cdn-uploads.huggingface.co/production/uploads/6345bd89fe134dfd7a0dba40/n82kc7ND2rDmhRmRexLrb.png)](https://youtu.be/JF2P7BIUpIU?feature=shared) - -### About SDXL training - -The feature of SDXL training is now available in sdxl branch as an experimental feature. - -Sep 3, 2023: The feature will be merged into the main branch soon. Following are the changes from the previous version. - -- ControlNet-LLLite is added. See [documentation](./docs/train_lllite_README.md) for details. -- JPEG XL is supported. 
[#786](https://github.com/kohya-ss/sd-scripts/pull/786) -- Peak memory usage is reduced. [#791](https://github.com/kohya-ss/sd-scripts/pull/791) -- Input perturbation noise is added. See [#798](https://github.com/kohya-ss/sd-scripts/pull/798) for details. -- Dataset subset now has `caption_prefix` and `caption_suffix` options. The strings are added to the beginning and the end of the captions before shuffling. You can specify the options in `.toml`. -- Other minor changes. -- Thanks for contributions from Isotr0py, vvern999, lansing and others! - -Aug 13, 2023: - -- LoRA-FA is added experimentally. Specify `--network_module networks.lora_fa` option instead of `--network_module networks.lora`. The trained model can be used as a normal LoRA model. - -Aug 12, 2023: - -- The default value of noise offset when omitted has been changed to 0 from 0.0357. -- The different learning rates for each U-Net block are now supported. Specify with `--block_lr` option. Specify 23 values separated by commas like `--block_lr 1e-3,1e-3 ... 1e-3`. - - 23 values correspond to `0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out`. - -Aug 6, 2023: - -- [SAI Model Spec](https://github.com/Stability-AI/ModelSpec) metadata is now supported partially. `hash_sha256` is not supported yet. - - The main items are set automatically. - - You can set title, author, description, license and tags with `--metadata_xxx` options in each training script. - - Merging scripts also support minimum SAI Model Spec metadata. See the help message for the usage. - - Metadata editor will be available soon. -- SDXL LoRA has `sdxl_base_v1-0` now for `ss_base_model_version` metadata item, instead of `v0-9`. - -Aug 4, 2023: - -- `bitsandbytes` is now optional. Please install it if you want to use it. The instructions are in the later section. -- `albumentations` is not required any more. -- An issue for pooled output for Textual Inversion training is fixed. -- `--v_pred_like_loss ratio` option is added. This option adds the loss like v-prediction loss in SDXL training. `0.1` means that the loss is added 10% of the v-prediction loss. The default value is None (disabled). - - In v-prediction, the loss is higher in the early timesteps (near the noise). This option can be used to increase the loss in the early timesteps. -- Arbitrary options can be used for Diffusers' schedulers. For example `--lr_scheduler_args "lr_end=1e-8"`. -- `sdxl_gen_imgs.py` supports batch size > 1. -- Fix ControlNet to work with attention couple and regional LoRA in `gen_img_diffusers.py`. - -Summary of the feature: - -- `tools/cache_latents.py` is added. This script can be used to cache the latents to disk in advance. - - The options are almost the same as `sdxl_train.py'. See the help message for the usage. - - Please launch the script as follows: - `accelerate launch --num_cpu_threads_per_process 1 tools/cache_latents.py ...` - - This script should work with multi-GPU, but it is not tested in my environment. - -- `tools/cache_text_encoder_outputs.py` is added. This script can be used to cache the text encoder outputs to disk in advance. - - The options are almost the same as `cache_latents.py' and `sdxl_train.py'. See the help message for the usage. - -- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset. - - `--full_bf16` option is added. Thanks to KohakuBlueleaf! 
- - This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage. - - However, bitsandbytes==0.35 doesn't seem to support this. Please use a newer version of bitsandbytes or another optimizer. - - I cannot find bitsandbytes>0.35.0 that works correctly on Windows. - - In addition, the full bfloat16 training might be unstable. Please use it at your own risk. -- `prepare_buckets_latents.py` now supports SDXL fine-tuning. -- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`. -- Both scripts has following additional options: - - `--cache_text_encoder_outputs` and `--cache_text_encoder_outputs_to_disk`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions. - - `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs. -- The image generation during training is now available. `--no_half_vae` option also works to avoid black images. - -- `--weighted_captions` option is not supported yet for both scripts. -- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000. - -- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`. - - `--cache_text_encoder_outputs` is not supported. - - `token_string` must be alphabet only currently, due to the limitation of the open-clip tokenizer. - - There are two options for captions: - 1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens. - 2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored. - - See below for the format of the embeddings. - -- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA. See the help message for the usage. - - Textual Inversion is supported, but the name for the embeds in the caption becomes alphabet only. For example, `neg_hand_v1.safetensors` can be activated with `neghandv`. - -`requirements.txt` is updated to support SDXL training. - -#### Tips for SDXL training - -- The default resolution of SDXL is 1024x1024. -- The fine-tuning can be done with 24GB GPU memory with the batch size of 1. For 24GB GPU, the following options are recommended __for the fine-tuning with 24GB GPU memory__: - - Train U-Net only. - - Use gradient checkpointing. - - Use `--cache_text_encoder_outputs` option and caching latents. - - Use Adafactor optimizer. RMSprop 8bit or Adagrad 8bit may work. AdamW 8bit doesn't seem to work. -- The LoRA training can be done with 8GB GPU memory (10GB recommended). For reducing the GPU memory usage, the following options are recommended: - - Train U-Net only. - - Use gradient checkpointing. - - Use `--cache_text_encoder_outputs` option and caching latents. - - Use one of 8bit optimizers or Adafactor optimizer. - - Use lower dim (-8 for 8GB GPU). -- `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of the training will be unexpected. -- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1. 
-- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training. - -Example of the optimizer settings for Adafactor with the fixed learning rate: - -```toml -optimizer_type = "adafactor" -optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ] -lr_scheduler = "constant_with_warmup" -lr_warmup_steps = 100 -learning_rate = 4e-7 # SDXL original learning rate -``` + - [Jan 15, 2024 / 2024/1/15: v0.8.0](#jan-15-2024--2024115-v080) + - [Naming of LoRA](#naming-of-lora) + - [LoRAの名称について](#loraの名称について) + - [Sample image generation during training](#sample-image-generation-during-training-2) + - [Change History](#change-history-1) ## 🦒 Colab -🚦 WIP 🚦 - This Colab notebook was not created or maintained by me; however, it appears to function effectively. The source can be found at: https://github.com/camenduru/kohya_ss-colab. I would like to express my gratitude to camendutu for their valuable contribution. If you encounter any issues with the Colab notebook, please report them on their repository. @@ -627,7 +484,78 @@ save_file(state_dict, file) ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [documentation](./docs/train_lllite_README.md) for details. +<<<<<<< HEAD ### Sample image generation during training +======= + +## Change History + +### Jan 15, 2024 / 2024/1/15: v0.8.0 + +- Diffusers, Accelerate, Transformers and other related libraries have been updated. Please update the libraries with [Upgrade](#upgrade). + - Some model files (Text Encoder without position_id) based on the latest Transformers can be loaded. +- `torch.compile` is supported (experimental). PR [#1024](https://github.com/kohya-ss/sd-scripts/pull/1024) Thanks to p1atdev! + - This feature works only on Linux or WSL. + - Please specify `--torch_compile` option in each training script. + - You can select the backend with `--dynamo_backend` option. The default is `"inductor"`. `inductor` or `eager` seems to work. + - Please use `--spda` option instead of `--xformers` option. + - PyTorch 2.1 or later is recommended. + - Please see [PR](https://github.com/kohya-ss/sd-scripts/pull/1024) for details. +- The session name for wandb can be specified with `--wandb_run_name` option. PR [#1032](https://github.com/kohya-ss/sd-scripts/pull/1032) Thanks to hopl1t! +- IPEX library is updated. PR [#1030](https://github.com/kohya-ss/sd-scripts/pull/1030) Thanks to Disty0! +- Fixed a bug that Diffusers format model cannot be saved. 
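+
+An illustrative invocation combining the new options (not taken from the upstream documentation; the script name and the omitted training arguments are placeholders, and the exact set of accepted options should be checked with each script's `--help`):
+
+```
+# Hypothetical sketch: enables torch.compile with the default "inductor" backend
+# (Linux or WSL only, PyTorch 2.1 or later recommended) and names the wandb session.
+# All other required training arguments (model, dataset, output folder, ...) are omitted.
+accelerate launch --num_cpu_threads_per_process 1 sdxl_train_network.py \
+    --torch_compile --dynamo_backend "inductor" \
+    --wandb_run_name "my-sdxl-lora-run"
+```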
+ +- Diffusers、Accelerate、Transformers 等の関連ライブラリを更新しました。[Upgrade](#upgrade) を参照し更新をお願いします。 + - 最新の Transformers を前提とした一部のモデルファイル(Text Encoder が position_id を持たないもの)が読み込めるようになりました。 +- `torch.compile` がサポートされしました(実験的)。 PR [#1024](https://github.com/kohya-ss/sd-scripts/pull/1024) p1atdev 氏に感謝します。 + - Linux または WSL でのみ動作します。 + - 各学習スクリプトで `--torch_compile` オプションを指定してください。 + - `--dynamo_backend` オプションで使用される backend を選択できます。デフォルトは `"inductor"` です。 `inductor` または `eager` が動作するようです。 + - `--xformers` オプションとは互換性がありません。 代わりに `--spda` オプションを使用してください。 + - PyTorch 2.1以降を推奨します。 + - 詳細は [PR](https://github.com/kohya-ss/sd-scripts/pull/1024) をご覧ください。 +- wandb 保存時のセッション名が各学習スクリプトの `--wandb_run_name` オプションで指定できるようになりました。 PR [#1032](https://github.com/kohya-ss/sd-scripts/pull/1032) hopl1t 氏に感謝します。 +- IPEX ライブラリが更新されました。[PR #1030](https://github.com/kohya-ss/sd-scripts/pull/1030) Disty0 氏に感謝します。 +- Diffusers 形式でのモデル保存ができなくなっていた不具合を修正しました。 + + +Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates. +最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。 + +### Naming of LoRA + +The LoRA supported by `train_network.py` has been named to avoid confusion. The documentation has been updated. The following are the names of LoRA types in this repository. + +1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers) + + LoRA for Linear layers and Conv2d layers with 1x1 kernel + +2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers) + + In addition to 1., LoRA for Conv2d layers with 3x3 kernel + +LoRA-LierLa is the default LoRA type for `train_network.py` (without `conv_dim` network arg). LoRA-LierLa can be used with [our extension](https://github.com/kohya-ss/sd-webui-additional-networks) for AUTOMATIC1111's Web UI, or with the built-in LoRA feature of the Web UI. + +To use LoRA-C3Lier with Web UI, please use our extension. + +### LoRAの名称について + +`train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。 + +1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます) + + Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA + +2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます) + + 1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA + +LoRA-LierLa は[Web UI向け拡張](https://github.com/kohya-ss/sd-webui-additional-networks)、またはAUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。 + +LoRA-C3Lierを使いWeb UIで生成するには拡張を使用してください。 + +## Sample image generation during training +>>>>>>> 26d35794e3b858e7b5bd20d1e70547c378550b3d A prompt file might look like this, for example ``` @@ -651,6 +579,22 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b ## Change History +* 2024/01/15 (v22.5.0) +- Merged sd-scripts v0.8.0 updates + - Diffusers, Accelerate, Transformers and other related libraries have been updated. Please update the libraries with [Upgrade](#upgrade). + - Some model files (Text Encoder without position_id) based on the latest Transformers can be loaded. + - `torch.compile` is supported (experimental). PR [#1024](https://github.com/kohya-ss/sd-scripts/pull/1024) Thanks to p1atdev! + - This feature works only on Linux or WSL. + - Please specify `--torch_compile` option in each training script. + - You can select the backend with `--dynamo_backend` option. The default is `"inductor"`. `inductor` or `eager` seems to work. + - Please use `--spda` option instead of `--xformers` option. 
+ - PyTorch 2.1 or later is recommended. + - Please see [PR](https://github.com/kohya-ss/sd-scripts/pull/1024) for details. + - The session name for wandb can be specified with `--wandb_run_name` option. PR [#1032](https://github.com/kohya-ss/sd-scripts/pull/1032) Thanks to hopl1t! + - IPEX library is updated. PR [#1030](https://github.com/kohya-ss/sd-scripts/pull/1030) Thanks to Disty0! + - Fixed a bug that Diffusers format model cannot be saved. +- Fix LoRA config display after load that would sometime hide some of the feilds + * 2024/01/02 (v22.4.1) - Minor bug fixed and enhancements. diff --git a/fine_tune.py b/fine_tune.py index f72e618b1..be61b3d16 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -291,6 +291,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs['wandb'] = {'name': args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) diff --git a/finetune_gui.py b/finetune_gui.py index c18355a06..7a9e77cd3 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -485,7 +485,7 @@ def train_model( # run_cmd += f' --flip_aug' if full_path: run_cmd += f' --full_path' - if sdxl_no_half_vae: + if sdxl_checkbox and sdxl_no_half_vae: log.info( 'Using mixed_precision = no because no half vae is selected...' ) @@ -584,11 +584,12 @@ def train_model( if int(max_token_length) > 75: run_cmd += f' --max_token_length={max_token_length}' - if sdxl_cache_text_encoder_outputs: - run_cmd += f' --cache_text_encoder_outputs' + if sdxl_checkbox: + if sdxl_cache_text_encoder_outputs: + run_cmd += f' --cache_text_encoder_outputs' - if sdxl_no_half_vae: - run_cmd += f' --no_half_vae' + if sdxl_no_half_vae: + run_cmd += f' --no_half_vae' run_cmd += run_cmd_training( learning_rate=learning_rate, diff --git a/gui.sh b/gui.sh index da7d4cdf2..022335125 100755 --- a/gui.sh +++ b/gui.sh @@ -72,7 +72,7 @@ fi #Set OneAPI if it's not set by the user if [[ "$@" == *"--use-ipex"* ]] then - if [ -d "$SCRIPT_DIR/venv" ]; then + if [ -d "$SCRIPT_DIR/venv" ] && [[ -z "${DISABLE_VENV_LIBS}" ]]; then export LD_LIBRARY_PATH=$(realpath "$SCRIPT_DIR/venv")/lib/:$LD_LIBRARY_PATH fi export NEOReadDebugKeys=1 @@ -82,7 +82,7 @@ then STARTUP_CMD=ipexrun if [[ -z "$STARTUP_CMD_ARGS" ]] then - STARTUP_CMD_ARGS="--multi-task-manager taskset --memory-allocator jemalloc" + STARTUP_CMD_ARGS="--multi-task-manager taskset --memory-allocator tcmalloc" fi fi fi diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index c78547915..333504935 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -140,6 +140,7 @@ def ipex_init(): # pylint: disable=too-many-statements # C torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream + ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count ipex._C._DeviceProperties.major = 2023 ipex._C._DeviceProperties.minor = 2 diff --git a/library/ipex/attention.py b/library/ipex/attention.py index 2e61f2c90..e98807a84 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -1,41 +1,98 @@ +import os import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +from functools import cache # pylint: disable=protected-access, missing-function-docstring, line-too-long -original_torch_bmm = torch.bmm -def torch_bmm_32_bit(input, 
mat2, *, out=None): - # ARC GPUs can't allocate more than 4GB to a single block, Slice it: - batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2] - block_multiply = input.element_size() - slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply +# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers + +sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4)) +attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) + +# Find something divisible with the input_tokens +@cache +def find_slice_size(slice_size, slice_block_size): + while (slice_size * slice_block_size) > attention_slice_rate: + slice_size = slice_size // 2 + if slice_size <= 1: + slice_size = 1 + break + return slice_size + +# Find slice sizes for SDPA +@cache +def find_sdpa_slice_sizes(query_shape, query_element_size): + if len(query_shape) == 3: + batch_size_attention, query_tokens, shape_three = query_shape + shape_four = 1 + else: + batch_size_attention, query_tokens, shape_three, shape_four = query_shape + + slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size block_size = batch_size_attention * slice_block_size split_slice_size = batch_size_attention - if block_size > 4: + split_2_slice_size = query_tokens + split_3_slice_size = shape_three + + do_split = False + do_split_2 = False + do_split_3 = False + + if block_size > sdpa_slice_trigger_rate: do_split = True - # Find something divisible with the input_tokens - while (split_slice_size * slice_block_size) > 4: - split_slice_size = split_slice_size // 2 - if split_slice_size <= 1: - split_slice_size = 1 - break - split_2_slice_size = input_tokens - if split_slice_size * slice_block_size > 4: - slice_block_size_2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply + split_slice_size = find_slice_size(split_slice_size, slice_block_size) + if split_slice_size * slice_block_size > attention_slice_rate: + slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size do_split_2 = True - # Find something divisible with the input_tokens - while (split_2_slice_size * slice_block_size_2) > 4: - split_2_slice_size = split_2_slice_size // 2 - if split_2_slice_size <= 1: - split_2_slice_size = 1 - break - else: - do_split_2 = False - else: - do_split = False + split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) + if split_2_slice_size * slice_2_block_size > attention_slice_rate: + slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size + do_split_3 = True + split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) + + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + +# Find slice sizes for BMM +@cache +def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape): + batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2] + slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size + block_size = batch_size_attention * slice_block_size + + split_slice_size = batch_size_attention + split_2_slice_size = input_tokens + split_3_slice_size = mat2_atten_shape + + do_split = False + do_split_2 = False + do_split_3 = False + + if block_size > attention_slice_rate: + do_split = True + split_slice_size = find_slice_size(split_slice_size, 
slice_block_size) + if split_slice_size * slice_block_size > attention_slice_rate: + slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size + do_split_2 = True + split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) + if split_2_slice_size * slice_2_block_size > attention_slice_rate: + slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size + do_split_3 = True + split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) + + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + +original_torch_bmm = torch.bmm +def torch_bmm_32_bit(input, mat2, *, out=None): + if input.device.type != "xpu": + return original_torch_bmm(input, mat2, out=out) + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape) + + # Slice BMM if do_split: + batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2] hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype) for i in range(batch_size_attention // split_slice_size): start_idx = i * split_slice_size @@ -44,11 +101,21 @@ def torch_bmm_32_bit(input, mat2, *, out=None): for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name start_idx_2 = i2 * split_2_slice_size end_idx_2 = (i2 + 1) * split_2_slice_size - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( - input[start_idx:end_idx, start_idx_2:end_idx_2], - mat2[start_idx:end_idx, start_idx_2:end_idx_2], - out=out - ) + if do_split_3: + for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm( + input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + out=out + ) + else: + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( + input[start_idx:end_idx, start_idx_2:end_idx_2], + mat2[start_idx:end_idx, start_idx_2:end_idx_2], + out=out + ) else: hidden_states[start_idx:end_idx] = original_torch_bmm( input[start_idx:end_idx], @@ -61,54 +128,13 @@ def torch_bmm_32_bit(input, mat2, *, out=None): original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): - # ARC GPUs can't allocate more than 4GB to a single block, Slice it: - if len(query.shape) == 3: - batch_size_attention, query_tokens, shape_three = query.shape - shape_four = 1 - else: - batch_size_attention, query_tokens, shape_three, shape_four = query.shape - - block_multiply = query.element_size() - slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * block_multiply - block_size = batch_size_attention * slice_block_size - - split_slice_size = batch_size_attention - if block_size > 4: - do_split = True - # Find something divisible with the batch_size_attention - while (split_slice_size * slice_block_size) > 4: - split_slice_size = split_slice_size // 2 - if split_slice_size <= 1: - split_slice_size = 1 - break - split_2_slice_size = query_tokens - if split_slice_size 
* slice_block_size > 4: - slice_block_size_2 = split_slice_size * shape_three * shape_four / 1024 / 1024 * block_multiply - do_split_2 = True - # Find something divisible with the query_tokens - while (split_2_slice_size * slice_block_size_2) > 4: - split_2_slice_size = split_2_slice_size // 2 - if split_2_slice_size <= 1: - split_2_slice_size = 1 - break - split_3_slice_size = shape_three - if split_2_slice_size * slice_block_size_2 > 4: - slice_block_size_3 = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * block_multiply - do_split_3 = True - # Find something divisible with the shape_three - while (split_3_slice_size * slice_block_size_3) > 4: - split_3_slice_size = split_3_slice_size // 2 - if split_3_slice_size <= 1: - split_3_slice_size = 1 - break - else: - do_split_3 = False - else: - do_split_2 = False - else: - do_split = False + if query.device.type != "xpu": + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size()) + # Slice SDPA if do_split: + batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) for i in range(batch_size_attention // split_slice_size): start_idx = i * split_slice_size @@ -145,7 +171,5 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo dropout_p=dropout_p, is_causal=is_causal ) else: - return original_scaled_dot_product_attention( - query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal - ) + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) return hidden_states diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py index c32af507b..47b0375ae 100644 --- a/library/ipex/diffusers.py +++ b/library/ipex/diffusers.py @@ -1,10 +1,62 @@ +import os import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import import diffusers #0.24.0 # pylint: disable=import-error from diffusers.models.attention_processor import Attention +from diffusers.utils import USE_PEFT_BACKEND +from functools import cache # pylint: disable=protected-access, missing-function-docstring, line-too-long +attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) + +@cache +def find_slice_size(slice_size, slice_block_size): + while (slice_size * slice_block_size) > attention_slice_rate: + slice_size = slice_size // 2 + if slice_size <= 1: + slice_size = 1 + break + return slice_size + +@cache +def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None): + if len(query_shape) == 3: + batch_size_attention, query_tokens, shape_three = query_shape + shape_four = 1 + else: + batch_size_attention, query_tokens, shape_three, shape_four = query_shape + if slice_size is not None: + batch_size_attention = slice_size + + slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size + block_size = batch_size_attention * slice_block_size + + split_slice_size = batch_size_attention + split_2_slice_size = query_tokens + split_3_slice_size = shape_three + + do_split = False + do_split_2 = False + do_split_3 = False + + if query_device_type != "xpu": + return do_split, do_split_2, 
do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + + if block_size > attention_slice_rate: + do_split = True + split_slice_size = find_slice_size(split_slice_size, slice_block_size) + if split_slice_size * slice_block_size > attention_slice_rate: + slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size + do_split_2 = True + split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) + if split_2_slice_size * slice_2_block_size > attention_slice_rate: + slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size + do_split_3 = True + split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) + + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + class SlicedAttnProcessor: # pylint: disable=too-few-public-methods r""" Processor for implementing sliced attention. @@ -18,7 +70,9 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods def __init__(self, slice_size): self.slice_size = slice_size - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, + encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches + residual = hidden_states input_ndim = hidden_states.ndim @@ -54,49 +108,61 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype ) - #ARC GPUs can't allocate more than 4GB to a single block, Slice it: - block_multiply = query.element_size() - slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply - block_size = query_tokens * slice_block_size - split_2_slice_size = query_tokens - if block_size > 4: - do_split_2 = True - #Find something divisible with the query_tokens - while (split_2_slice_size * slice_block_size) > 4: - split_2_slice_size = split_2_slice_size // 2 - if split_2_slice_size <= 1: - split_2_slice_size = 1 - break - else: - do_split_2 = False - - for i in range(batch_size_attention // self.slice_size): - start_idx = i * self.slice_size - end_idx = (i + 1) * self.slice_size + #################################################################### + # ARC GPUs can't allocate more than 4GB to a single block, Slice it: + _, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size) + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size if do_split_2: for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name start_idx_2 = i2 * split_2_slice_size end_idx_2 = (i2 + 1) * split_2_slice_size + if do_split_3: + for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] - key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + query_slice = 
query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]) - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice + del attn_slice + else: + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice + del attn_slice else: query_slice = query[start_idx:end_idx] key_slice = key[start_idx:end_idx] attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - + del query_slice + del key_slice + del attn_mask_slice attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) hidden_states[start_idx:end_idx] = attn_slice + del attn_slice + #################################################################### hidden_states = attn.batch_to_head_dim(hidden_states) @@ -115,6 +181,130 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a return hidden_states + +class AttnProcessor: + r""" + Default processor for performing attention-related computations. 
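+    This variant additionally slices the attention computation into smaller blocks
+    (sizes chosen by find_attention_slice_sizes above) when running on XPU devices,
+    because ARC GPUs can't allocate more than 4GB to a single block.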
+ """ + + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, + encoder_hidden_states=None, attention_mask=None, + temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches + + residual = hidden_states + + args = () if USE_PEFT_BACKEND else (scale,) + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states, *args) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states, *args) + value = attn.to_v(encoder_hidden_states, *args) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + #################################################################### + # ARC GPUs can't allocate more than 4GB to a single block, Slice it: + batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] + hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type) + + if do_split: + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + if do_split_3: + for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size + + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice + del attn_slice + else: + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, 
value[start_idx:end_idx, start_idx_2:end_idx_2]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice + del attn_slice + else: + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + del attn_slice + else: + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + #################################################################### + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, *args) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + def ipex_diffusers(): #ARC GPUs can't allocate more than 4GB to a single block: diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor + diffusers.models.attention_processor.AttnProcessor = AttnProcessor diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index eb5f779f9..b6d246dd2 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -1,67 +1,9 @@ import contextlib -import importlib import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return -class CondFunc: # pylint: disable=missing-class-docstring - def __new__(cls, orig_func, sub_func, cond_func): - self = super(CondFunc, cls).__new__(cls) - if isinstance(orig_func, str): - func_path = orig_func.split('.') - for i in range(len(func_path)-1, -1, -1): - try: - resolved_obj = importlib.import_module('.'.join(func_path[:i])) - break - except ImportError: - pass - for attr_name in func_path[i:-1]: - resolved_obj = getattr(resolved_obj, attr_name) - orig_func = getattr(resolved_obj, func_path[-1]) - setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) - self.__init__(orig_func, sub_func, cond_func) - return lambda *args, **kwargs: self(*args, **kwargs) - def __init__(self, orig_func, sub_func, cond_func): - self.__orig_func = orig_func - self.__sub_func = sub_func - self.__cond_func = cond_func - def __call__(self, *args, **kwargs): - if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): - return self.__sub_func(self.__orig_func, *args, **kwargs) - else: - return self.__orig_func(*args, **kwargs) - -_utils = torch.utils.data._utils -def _shutdown_workers(self): - if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None: - return - if hasattr(self, "_shutdown") and not self._shutdown: - self._shutdown = True - try: - if hasattr(self, '_pin_memory_thread'): - self._pin_memory_thread_done_event.set() - self._worker_result_queue.put((None, None)) - self._pin_memory_thread.join() - self._worker_result_queue.cancel_join_thread() - 
self._worker_result_queue.close() - self._workers_done_event.set() - for worker_id in range(len(self._workers)): - if self._persistent_workers or self._workers_status[worker_id]: - self._mark_worker_as_unavailable(worker_id, shutdown=True) - for w in self._workers: # pylint: disable=invalid-name - w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL) - for q in self._index_queues: # pylint: disable=invalid-name - q.cancel_join_thread() - q.close() - finally: - if self._worker_pids_set: - torch.utils.data._utils.signal_handling._remove_worker_pids(id(self)) - self._worker_pids_set = False - for w in self._workers: # pylint: disable=invalid-name - if w.is_alive(): - w.terminate() - class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument if isinstance(device_ids, list) and len(device_ids) > 1: @@ -71,17 +13,18 @@ def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: def return_null_context(*args, **kwargs): # pylint: disable=unused-argument return contextlib.nullcontext() +@property +def is_cuda(self): + return self.device.type == 'xpu' or self.device.type == 'cuda' + def check_device(device): return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int)) def return_xpu(device): return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" -def ipex_no_cuda(orig_func, *args, **kwargs): - torch.cuda.is_available = lambda: False - orig_func(*args, **kwargs) - torch.cuda.is_available = torch.xpu.is_available +# Autocast original_autocast = torch.autocast def ipex_autocast(*args, **kwargs): if len(args) > 0 and args[0] == "cuda": @@ -89,15 +32,7 @@ def ipex_autocast(*args, **kwargs): else: return original_autocast(*args, **kwargs) -# Embedding BF16 -original_torch_cat = torch.cat -def torch_cat(tensor, *args, **kwargs): - if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): - return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) - else: - return original_torch_cat(tensor, *args, **kwargs) - -# Latent antialias: +# Latent Antialias CPU Offload: original_interpolate = torch.nn.functional.interpolate def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments if antialias or align_corners is not None: @@ -109,19 +44,19 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias) -original_linalg_solve = torch.linalg.solve -def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name - if A.device != torch.device("cpu") or B.device != torch.device("cpu"): - return_device = A.device - return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device) +# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit): +original_from_numpy = torch.from_numpy +def from_numpy(ndarray): + if ndarray.dtype == float: + return 
original_from_numpy(ndarray.astype('float32')) else: - return original_linalg_solve(A, B, *args, **kwargs) + return original_from_numpy(ndarray) if torch.xpu.has_fp64_dtype(): original_torch_bmm = torch.bmm original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention else: - # 64 bit attention workarounds for Alchemist: + # 32 bit attention workarounds for Alchemist: try: from .attention import torch_bmm_32_bit as original_torch_bmm from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention @@ -129,7 +64,8 @@ def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name original_torch_bmm = torch.bmm original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention -# dtype errors: + +# Data Type Errors: def torch_bmm(input, mat2, *, out=None): if input.dtype != mat2.dtype: mat2 = mat2.to(input.dtype) @@ -142,111 +78,171 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. value = value.to(dtype=query.dtype) return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) -@property -def is_cuda(self): - return self.device.type == 'xpu' +# A1111 FP16 +original_functional_group_norm = torch.nn.functional.group_norm +def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05): + if weight is not None and input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps) + +# A1111 BF16 +original_functional_layer_norm = torch.nn.functional.layer_norm +def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05): + if weight is not None and input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps) + +# Training +original_functional_linear = torch.nn.functional.linear +def functional_linear(input, weight, bias=None): + if input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_linear(input, weight, bias=bias) + +original_functional_conv2d = torch.nn.functional.conv2d +def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + +# A1111 Embedding BF16 +original_torch_cat = torch.cat +def torch_cat(tensor, *args, **kwargs): + if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): + return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) + else: + return original_torch_cat(tensor, *args, **kwargs) -def ipex_hijacks(): - CondFunc('torch.tensor', - 
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.Tensor.to', - lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), - lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) - CondFunc('torch.Tensor.cuda', - lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), - lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) - CondFunc('torch.UntypedStorage.__init__', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.UntypedStorage.cuda', - lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), - lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) - CondFunc('torch.empty', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.randn', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.ones', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.zeros', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.linspace', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.load', - lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: - orig_func(orig_func, f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs), - lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location)) - if hasattr(torch.xpu, "Generator"): - CondFunc('torch.Generator', - lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)), - lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") - else: - CondFunc('torch.Generator', - lambda orig_func, device=None: orig_func(return_xpu(device)), - lambda orig_func, device=None: check_device(device)) - - # TiledVAE and ControlNet: - CondFunc('torch.batch_norm', - lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, - weight if weight is not None else torch.ones(input.size()[1], device=input.device), - bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), - lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) - CondFunc('torch.instance_norm', - lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, - weight if weight is not None else torch.ones(input.size()[1], device=input.device), - bias if bias is not None else torch.zeros(input.size()[1], 
device=input.device), *args, **kwargs), - lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) - - # Functions with dtype errors: - CondFunc('torch.nn.modules.GroupNorm.forward', - lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), - lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - # Training: - CondFunc('torch.nn.modules.linear.Linear.forward', - lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), - lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - CondFunc('torch.nn.modules.conv.Conv2d.forward', - lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), - lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - # BF16: - CondFunc('torch.nn.functional.layer_norm', - lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: - orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), - lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: - weight is not None and input.dtype != weight.data.dtype) - # SwinIR BF16: - CondFunc('torch.nn.functional.pad', - lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16), - lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16) - - # Diffusers Float64 (Alchemist GPUs doesn't support 64 bit): - if not torch.xpu.has_fp64_dtype(): - CondFunc('torch.from_numpy', - lambda orig_func, ndarray: orig_func(ndarray.astype('float32')), - lambda orig_func, ndarray: ndarray.dtype == float) +# SwinIR BF16: +original_functional_pad = torch.nn.functional.pad +def functional_pad(input, pad, mode='constant', value=None): + if mode == 'reflect' and input.dtype == torch.bfloat16: + return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16) + else: + return original_functional_pad(input, pad, mode=mode, value=value) - # Broken functions when torch.cuda.is_available is True: - # Pin Memory: - CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__', - lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs), - lambda orig_func, *args, **kwargs: True) - # Functions that make compile mad with CondFunc: - torch.nn.DataParallel = DummyDataParallel - torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers +original_torch_tensor = torch.tensor +def torch_tensor(*args, device=None, **kwargs): + if check_device(device): + return original_torch_tensor(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_tensor(*args, device=device, **kwargs) + +original_Tensor_to = torch.Tensor.to +def Tensor_to(self, device=None, *args, **kwargs): + if check_device(device): + return original_Tensor_to(self, return_xpu(device), *args, **kwargs) + else: + return original_Tensor_to(self, device, *args, **kwargs) + +original_Tensor_cuda = torch.Tensor.cuda +def Tensor_cuda(self, device=None, *args, **kwargs): + if check_device(device): + return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs) + else: + return original_Tensor_cuda(self, device, *args, **kwargs) + +original_UntypedStorage_init = torch.UntypedStorage.__init__ +def UntypedStorage_init(*args, device=None, **kwargs): + if check_device(device): + return original_UntypedStorage_init(*args, 
device=return_xpu(device), **kwargs) + else: + return original_UntypedStorage_init(*args, device=device, **kwargs) + +original_UntypedStorage_cuda = torch.UntypedStorage.cuda +def UntypedStorage_cuda(self, device=None, *args, **kwargs): + if check_device(device): + return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs) + else: + return original_UntypedStorage_cuda(self, device, *args, **kwargs) + +original_torch_empty = torch.empty +def torch_empty(*args, device=None, **kwargs): + if check_device(device): + return original_torch_empty(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_empty(*args, device=device, **kwargs) + +original_torch_randn = torch.randn +def torch_randn(*args, device=None, **kwargs): + if check_device(device): + return original_torch_randn(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_randn(*args, device=device, **kwargs) + +original_torch_ones = torch.ones +def torch_ones(*args, device=None, **kwargs): + if check_device(device): + return original_torch_ones(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_ones(*args, device=device, **kwargs) + +original_torch_zeros = torch.zeros +def torch_zeros(*args, device=None, **kwargs): + if check_device(device): + return original_torch_zeros(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_zeros(*args, device=device, **kwargs) + +original_torch_linspace = torch.linspace +def torch_linspace(*args, device=None, **kwargs): + if check_device(device): + return original_torch_linspace(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_linspace(*args, device=device, **kwargs) + +original_torch_Generator = torch.Generator +def torch_Generator(device=None): + if check_device(device): + return original_torch_Generator(return_xpu(device)) + else: + return original_torch_Generator(device) + +original_torch_load = torch.load +def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs): + if check_device(map_location): + return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) + else: + return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) + +# Hijack Functions: +def ipex_hijacks(): + torch.tensor = torch_tensor + torch.Tensor.to = Tensor_to + torch.Tensor.cuda = Tensor_cuda + torch.UntypedStorage.__init__ = UntypedStorage_init + torch.UntypedStorage.cuda = UntypedStorage_cuda + torch.empty = torch_empty + torch.randn = torch_randn + torch.ones = torch_ones + torch.zeros = torch_zeros + torch.linspace = torch_linspace + torch.Generator = torch_Generator + torch.load = torch_load - torch.autocast = ipex_autocast torch.backends.cuda.sdp_kernel = return_null_context + torch.nn.DataParallel = DummyDataParallel torch.UntypedStorage.is_cuda = is_cuda + torch.autocast = ipex_autocast + torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention + torch.nn.functional.group_norm = functional_group_norm + torch.nn.functional.layer_norm = functional_layer_norm + torch.nn.functional.linear = functional_linear + torch.nn.functional.conv2d = functional_conv2d torch.nn.functional.interpolate = interpolate - torch.linalg.solve = linalg_solve + torch.nn.functional.pad = functional_pad torch.bmm = torch_bmm torch.cat = torch_cat - 
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
+    if not torch.xpu.has_fp64_dtype():
+        torch.from_numpy = from_numpy
diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py
index 4f9408352..91210129a 100644
--- a/library/lpw_stable_diffusion.py
+++ b/library/lpw_stable_diffusion.py
@@ -516,13 +516,13 @@ def __init__(
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
         scheduler: SchedulerMixin,
+        # clip_skip: int,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
-        image_encoder: CLIPVisionModelWithProjection = None,  # Incluindo o image_encoder
         requires_safety_checker: bool = True,
+        image_encoder: CLIPVisionModelWithProjection = None,
         clip_skip: int = 1,
     ):
-        self._clip_skip_internal = clip_skip
         super().__init__(
             vae=vae,
             text_encoder=text_encoder,
@@ -531,47 +531,12 @@ def __init__(
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
-            image_encoder=image_encoder,
             requires_safety_checker=requires_safety_checker,
+            image_encoder=image_encoder,
         )
+        self.custom_clip_skip = clip_skip
         self.__init__additional__()

-    @property
-    def clip_skip(self):
-        return self._clip_skip_internal
-
-    @clip_skip.setter
-    def clip_skip(self, value):
-        self._clip_skip_internal = value
-
-    def __setattr__(self, name: str, value):
-        if name == "clip_skip":
-            object.__setattr__(self, "_clip_skip_internal", value)
-        else:
-            super().__setattr__(name, value)
-
-    # else:
-    #     def __init__(
-    #         self,
-    #         vae: AutoencoderKL,
-    #         text_encoder: CLIPTextModel,
-    #         tokenizer: CLIPTokenizer,
-    #         unet: UNet2DConditionModel,
-    #         scheduler: SchedulerMixin,
-    #         safety_checker: StableDiffusionSafetyChecker,
-    #         feature_extractor: CLIPFeatureExtractor,
-    #     ):
-    #         super().__init__(
-    #             vae=vae,
-    #             text_encoder=text_encoder,
-    #             tokenizer=tokenizer,
-    #             unet=unet,
-    #             scheduler=scheduler,
-    #             safety_checker=safety_checker,
-    #             feature_extractor=feature_extractor,
-    #         )
-    #         self.__init__additional__()
-
     def __init__additional__(self):
         if not hasattr(self, "vae_scale_factor"):
             setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
@@ -639,7 +604,7 @@ def _encode_prompt(
             prompt=prompt,
             uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
             max_embeddings_multiples=max_embeddings_multiples,
-            clip_skip=self.clip_skip,
+            clip_skip=self.custom_clip_skip,
         )
         bs_embed, seq_len, _ = text_embeddings.shape
         text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
@@ -1266,4 +1231,4 @@ def inpaint(
         callback=callback,
         is_cancelled_callback=is_cancelled_callback,
         callback_steps=callback_steps,
-    )
+    )
\ No newline at end of file
diff --git a/library/model_util.py b/library/model_util.py
index a577b97d4..1f40ce324 100644
--- a/library/model_util.py
+++ b/library/model_util.py
@@ -4,10 +4,13 @@
 import math
 import os
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
@@ -571,9 +574,9 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
         if key.startswith("cond_stage_model.transformer"):
             text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]

-    # support checkpoint without position_ids (invalid checkpoint)
-    if "text_model.embeddings.position_ids" not in text_model_dict:
-        text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)  # 77 is the max length of the text
+    # remove position_ids for newer transformer, which causes error :(
+    if "text_model.embeddings.position_ids" in text_model_dict:
+        text_model_dict.pop("text_model.embeddings.position_ids")

     return text_model_dict

@@ -1242,8 +1245,13 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
     if vae is None:
         vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")

+    # original U-Net cannot be saved, so we need to convert it to the Diffusers version
+    # TODO this consumes a lot of memory
+    diffusers_unet = diffusers.UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
+    diffusers_unet.load_state_dict(unet.state_dict())
+
     pipeline = StableDiffusionPipeline(
-        unet=unet,
+        unet=diffusers_unet,
         text_encoder=text_encoder,
         vae=vae,
         scheduler=scheduler,
diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py
index a844927cd..08b90c393 100644
--- a/library/sdxl_model_util.py
+++ b/library/sdxl_model_util.py
@@ -100,7 +100,7 @@ def convert_key(key):
             key = key.replace(".ln_final", ".final_layer_norm")
         # ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
         elif ".embeddings.position_ids" in key:
-            key = None  # remove this key: make position_ids by ourselves
+            key = None  # remove this key: position_ids is not used in newer transformers
         return key

     keys = list(checkpoint.keys())
@@ -126,10 +126,6 @@ def convert_key(key):
             new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
             new_sd[key_pfx + "v_proj" + key_suffix] = values[2]

-    # original SD にはないので、position_idsを追加
-    position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
-    new_sd["text_model.embeddings.position_ids"] = position_ids
-
     # logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
     logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)

@@ -265,9 +261,9 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
         elif k.startswith("conditioner.embedders.1.model."):
             te2_sd[k] = state_dict.pop(k)

-    # 一部のposition_idsがないモデルへの対応 / add position_ids for some models
-    if "text_model.embeddings.position_ids" not in te1_sd:
-        te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
+    # 最新の transformers では position_ids を含むとエラーになるので削除 / remove position_ids for latest transformers
+    if "text_model.embeddings.position_ids" in te1_sd:
+        te1_sd.pop("text_model.embeddings.position_ids")

     info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location)  # remain fp32
     print("text encoder 1:", info1)
diff --git a/library/train_util.py b/library/train_util.py
index 3c850019e..ff161feab 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2848,6 +2848,17 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         action="store_true",
         help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
     )
+    parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う")
+    parser.add_argument(
+        "--dynamo_backend",
+        type=str,
+        default="inductor",
+        # available backends:
+        # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5
+        # https://pytorch.org/docs/stable/torch.compiler.html
+        choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
+        help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)"
+    )
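Note: the two options added above are consumed by prepare_accelerator() further down in this patch, which maps them onto Accelerate's dynamo_backend setting. The snippet below is only an illustrative sketch of that wiring, not part of the patch; the hard-coded argv values are examples, and in the real scripts args comes from the full training parser.

import argparse
from accelerate import Accelerator

parser = argparse.ArgumentParser()
parser.add_argument("--torch_compile", action="store_true")
parser.add_argument("--dynamo_backend", type=str, default="inductor")
args = parser.parse_args(["--torch_compile", "--dynamo_backend", "inductor"])  # example values only

# "NO" tells Accelerate not to wrap models with torch.compile
dynamo_backend = args.dynamo_backend if args.torch_compile else "NO"
accelerator = Accelerator(dynamo_backend=dynamo_backend)

In practice this means training can be launched with e.g. --torch_compile --dynamo_backend inductor (torch.compile requires PyTorch 2.0 or later).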
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") parser.add_argument( "--sdpa", @@ -2935,6 +2946,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名", ) + parser.add_argument( + "--wandb_run_name", + type=str, + default=None, + help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前", + ) parser.add_argument( "--log_tracker_config", type=str, @@ -3869,6 +3886,11 @@ def prepare_accelerator(args: argparse.Namespace): os.environ["WANDB_DIR"] = logging_dir if args.wandb_api_key is not None: wandb.login(key=args.wandb_api_key) + + # torch.compile のオプション。 NO の場合は torch.compile は使わない + dynamo_backend = "NO" + if args.torch_compile: + dynamo_backend = args.dynamo_backend kwargs_handlers = ( InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None, @@ -3883,6 +3905,7 @@ def prepare_accelerator(args: argparse.Namespace): log_with=log_with, project_dir=logging_dir, kwargs_handlers=kwargs_handlers, + dynamo_backend=dynamo_backend, ) return accelerator diff --git a/lora_gui.py b/lora_gui.py index fcfc3303a..1e403ff32 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -125,7 +125,8 @@ def save_configuration( caption_dropout_rate, optimizer, optimizer_args, - lr_scheduler_args,max_grad_norm, + lr_scheduler_args, + max_grad_norm, noise_offset_type, noise_offset, adaptive_noise_scale, @@ -133,7 +134,13 @@ def save_configuration( multires_noise_discount, LoRA_type, factor, - use_cp,use_tucker,use_scalar,rank_dropout_scale,constrain,rescaled,train_norm, + use_cp, + use_tucker, + use_scalar, + rank_dropout_scale, + constrain, + rescaled, + train_norm, decompose_both, train_on_input, conv_dim, @@ -280,7 +287,8 @@ def open_configuration( caption_dropout_rate, optimizer, optimizer_args, - lr_scheduler_args,max_grad_norm, + lr_scheduler_args, + max_grad_norm, noise_offset_type, noise_offset, adaptive_noise_scale, @@ -288,7 +296,13 @@ def open_configuration( multires_noise_discount, LoRA_type, factor, - use_cp,use_tucker,use_scalar,rank_dropout_scale,constrain,rescaled,train_norm, + use_cp, + use_tucker, + use_scalar, + rank_dropout_scale, + constrain, + rescaled, + train_norm, decompose_both, train_on_input, conv_dim, @@ -378,7 +392,18 @@ def open_configuration( values.append(json_value if json_value is not None else value) # This next section is about making the LoCon parameters visible if LoRA_type = 'Standard' - if my_data.get("LoRA_type", "Standard") == "LoCon": + if my_data.get("LoRA_type", "Standard") in { + "LoCon", + "Kohya DyLoRA", + "Kohya LoCon", + "LoRA-FA", + "LyCORIS/Diag-OFT", + "LyCORIS/DyLoRA", + "LyCORIS/LoHa", + "LyCORIS/LoKr", + "LyCORIS/LoCon", + "LyCORIS/GLoRA", + }: values.append(gr.Row.update(visible=True)) else: values.append(gr.Row.update(visible=False)) @@ -455,7 +480,8 @@ def train_model( caption_dropout_rate, optimizer, optimizer_args, - lr_scheduler_args,max_grad_norm, + lr_scheduler_args, + max_grad_norm, noise_offset_type, noise_offset, adaptive_noise_scale, @@ -463,7 +489,13 @@ def train_model( multires_noise_discount, LoRA_type, factor, - use_cp,use_tucker,use_scalar,rank_dropout_scale,constrain,rescaled,train_norm, + use_cp, + use_tucker, + use_scalar, + rank_dropout_scale, + constrain, + rescaled, + train_norm, decompose_both, train_on_input, conv_dim, @@ -825,9 +857,7 @@ def 
train_model( ) return run_cmd += f" --network_module=lycoris.kohya" - run_cmd += ( - f' --network_args "preset={LyCORIS_preset}" "rank_dropout={rank_dropout}" "module_dropout={module_dropout}" "use_tucker={use_tucker}" "use_scalar={use_scalar}" "rank_dropout_scale={rank_dropout_scale}" "algo=full" "train_norm={train_norm}"' - ) + run_cmd += f' --network_args "preset={LyCORIS_preset}" "rank_dropout={rank_dropout}" "module_dropout={module_dropout}" "use_tucker={use_tucker}" "use_scalar={use_scalar}" "rank_dropout_scale={rank_dropout_scale}" "algo=full" "train_norm={train_norm}"' # This is a hack to fix a train_network LoHA logic issue if not network_dropout > 0.0: run_cmd += f' --network_dropout="{network_dropout}"' @@ -974,11 +1004,12 @@ def train_model( if network_dropout > 0.0: run_cmd += f' --network_dropout="{network_dropout}"' - if sdxl_cache_text_encoder_outputs: - run_cmd += f" --cache_text_encoder_outputs" + if sdxl: + if sdxl_cache_text_encoder_outputs: + run_cmd += f" --cache_text_encoder_outputs" - if sdxl_no_half_vae: - run_cmd += f" --no_half_vae" + if sdxl_no_half_vae: + run_cmd += f" --no_half_vae" if full_bf16: run_cmd += f" --full_bf16" @@ -1503,18 +1534,28 @@ def update_LoRA_settings( "gr_type": gr.Slider, "update_params": { "maximum": 100000 - if LoRA_type in {"LyCORIS/LoHa", "LyCORIS/LoKr", "LyCORIS/Diag-OFT"} + if LoRA_type + in { + "LyCORIS/LoHa", + "LyCORIS/LoKr", + "LyCORIS/Diag-OFT", + } else 512, - "value": 512, # if conv_dim > 512 else conv_dim, + "value": conv_dim, # if conv_dim > 512 else conv_dim, }, }, "network_dim": { "gr_type": gr.Slider, "update_params": { "maximum": 100000 - if LoRA_type in {"LyCORIS/LoHa", "LyCORIS/LoKr", "LyCORIS/Diag-OFT"} + if LoRA_type + in { + "LyCORIS/LoHa", + "LyCORIS/LoKr", + "LyCORIS/Diag-OFT", + } else 512, - "value": 512, # if network_dim > 512 else network_dim, + "value": network_dim, # if network_dim > 512 else network_dim, }, }, "use_cp": { @@ -1789,7 +1830,13 @@ def update_LoRA_settings( factor, conv_dim, network_dim, - use_cp,use_tucker,use_scalar,rank_dropout_scale,constrain,rescaled,train_norm, + use_cp, + use_tucker, + use_scalar, + rank_dropout_scale, + constrain, + rescaled, + train_norm, decompose_both, train_on_input, scale_weight_norms, @@ -1912,7 +1959,13 @@ def update_LoRA_settings( advanced_training.multires_noise_discount, LoRA_type, factor, - use_cp,use_tucker,use_scalar,rank_dropout_scale,constrain,rescaled,train_norm, + use_cp, + use_tucker, + use_scalar, + rank_dropout_scale, + constrain, + rescaled, + train_norm, decompose_both, train_on_input, conv_dim, diff --git a/presets/lora/sd15 - LoKr v2.0.json b/presets/lora/sd15 - LoKr v2.0.json new file mode 100644 index 000000000..e65637dab --- /dev/null +++ b/presets/lora/sd15 - LoKr v2.0.json @@ -0,0 +1,107 @@ +{ + "LoRA_type": "LyCORIS/LoKr", + "LyCORIS_preset": "full", + "adaptive_noise_scale": 0, + "additional_parameters": "--lr_scheduler_type \"CosineAnnealingLR\" --lr_scheduler_args \"T_max=1000\" \"eta_min=0e-0\"", + "block_alphas": "", + "block_dims": "", + "block_lr_zero_threshold": "", + "bucket_no_upscale": true, + "bucket_reso_steps": 1, + "cache_latents": true, + "cache_latents_to_disk": true, + "caption_dropout_every_n_epochs": 0.0, + "caption_dropout_rate": 0.1, + "caption_extension": ".txt", + "clip_skip": "1", + "color_aug": false, + "constrain": 0.0, + "conv_alpha": 1, + "conv_block_alphas": "", + "conv_block_dims": "", + "conv_dim": 100000, + "debiased_estimation_loss": false, + "decompose_both": false, + "dim_from_weights": false, + 
"down_lr_weight": "", + "enable_bucket": true, + "epoch": 150, + "factor": 6, + "flip_aug": false, + "full_bf16": false, + "full_fp16": false, + "gradient_accumulation_steps": 1, + "gradient_checkpointing": false, + "keep_tokens": 1, + "learning_rate": 1.0, + "lora_network_weights": "", + "lr_scheduler": "cosine", + "lr_scheduler_args": "", + "lr_scheduler_num_cycles": "", + "lr_scheduler_power": "", + "lr_warmup": 0, + "max_bucket_reso": 2048, + "max_data_loader_n_workers": "0", + "max_grad_norm": 1, + "max_resolution": "512,512", + "max_timestep": 1000, + "max_token_length": "75", + "max_train_epochs": "", + "max_train_steps": "", + "mem_eff_attn": false, + "mid_lr_weight": "", + "min_bucket_reso": 256, + "min_snr_gamma": 5, + "min_timestep": 0, + "mixed_precision": "bf16", + "module_dropout": 0, + "multires_noise_discount": 0.1, + "multires_noise_iterations": 6, + "network_alpha": 1, + "network_dim": 100000, + "network_dropout": 0, + "no_token_padding": false, + "noise_offset": 0, + "noise_offset_type": "Multires", + "num_cpu_threads_per_process": 2, + "optimizer": "Prodigy", + "optimizer_args": "\"d0=1e-5\" \"d_coef=1.0\" \"weight_decay=0.4\" \"decouple=True\" \"safeguard_warmup=True\" \"use_bias_correction=True\"", + "persistent_data_loader_workers": false, + "prior_loss_weight": 1.0, + "random_crop": false, + "rank_dropout": 0, + "rank_dropout_scale": false, + "rescaled": false, + "save_every_n_epochs": 15, + "save_every_n_steps": 0, + "save_last_n_steps": 0, + "save_last_n_steps_state": 0, + "save_precision": "bf16", + "scale_v_pred_loss_like_noise_pred": false, + "scale_weight_norms": 0, + "sdxl": false, + "sdxl_cache_text_encoder_outputs": false, + "sdxl_no_half_vae": true, + "seed": "", + "shuffle_caption": true, + "stop_text_encoder_training": 0, + "text_encoder_lr": 1.0, + "train_batch_size": 2, + "train_norm": false, + "train_on_input": false, + "training_comment": "KoopaTroopa", + "unet_lr": 1.0, + "unit": 1, + "up_lr_weight": "", + "use_cp": false, + "use_scalar": false, + "use_tucker": false, + "use_wandb": false, + "v2": false, + "v_parameterization": false, + "v_pred_like_loss": 0, + "vae": "", + "vae_batch_size": 0, + "weighted_captions": false, + "xformers": "xformers" +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 5dc8a40ee..c2be38700 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,15 @@ -accelerate==0.23.0 +accelerate==0.25.0 # albumentations==1.3.0 aiofiles==23.2.1 altair==4.2.2 dadaptation==3.1 -diffusers[torch]==0.24.0 +diffusers[torch]==0.25.0 easygui==0.98.3 -einops==0.6.0 +einops==0.6.1 fairscale==0.4.13 ftfy==6.1.1 gradio==3.36.1 -huggingface-hub==0.19.4 +huggingface-hub==0.20.1 # for loading Diffusers' SDXL invisible-watermark==0.2.0 lion-pytorch==0.0.6 @@ -38,7 +38,7 @@ safetensors==0.3.1 timm==0.6.12 tk==0.1.0 toml==0.10.2 -transformers==4.30.2 +transformers==4.36.2 voluptuous==0.13.1 wandb==0.15.11 scipy==1.11.4 diff --git a/sdxl_train.py b/sdxl_train.py index 8983673d2..b4ce2770e 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -457,6 +457,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs['wandb'] = {'name': args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) diff --git a/sdxl_train_control_net_lllite.py 
b/sdxl_train_control_net_lllite.py index 18c6bd053..4436dd3cd 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -342,6 +342,8 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs['wandb'] = {'name': args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index 1b68f9a35..7e9d7c7b9 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -601,7 +601,7 @@ def train_model( if int(gradient_accumulation_steps) > 1: run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}' - if sdxl_no_half_vae: + if sdxl and sdxl_no_half_vae: run_cmd += f' --no_half_vae' run_cmd += run_cmd_training( diff --git a/train_controlnet.py b/train_controlnet.py index 1f3dbae30..cc0eaab7a 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -336,6 +336,8 @@ def train(args): ) if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs['wandb'] = {'name': args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( diff --git a/train_db.py b/train_db.py index 5518740f1..14d9dff13 100644 --- a/train_db.py +++ b/train_db.py @@ -268,6 +268,8 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs['wandb'] = {'name': args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) diff --git a/train_network.py b/train_network.py index 9cba78da0..a75299cda 100644 --- a/train_network.py +++ b/train_network.py @@ -684,6 +684,8 @@ def train(self, args): if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs['wandb'] = {'name': args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 877ac838e..0e3912b1d 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -441,9 +441,10 @@ def train(self, args): # Freeze all parameters except for the token embeddings in text encoder text_encoder.requires_grad_(True) - text_encoder.text_model.encoder.requires_grad_(False) - text_encoder.text_model.final_layer_norm.requires_grad_(False) - text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) + unwrapped_text_encoder = accelerator.unwrap_model(text_encoder) + unwrapped_text_encoder.text_model.encoder.requires_grad_(False) + unwrapped_text_encoder.text_model.final_layer_norm.requires_grad_(False) + unwrapped_text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) unet.requires_grad_(False) @@ -503,6 +504,8 @@ def train(self, args): if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs['wandb'] = {'name': args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( @@ -603,7 +606,7 @@ def remove_model(old_ckpt_name): accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: - 
params_to_clip = text_encoder.get_input_embeddings().parameters()
+                    params_to_clip = accelerator.unwrap_model(text_encoder).get_input_embeddings().parameters()
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                 optimizer.step()
@@ -615,9 +618,11 @@ def remove_model(old_ckpt_name):
                     for text_encoder, orig_embeds_params, index_no_updates in zip(
                         text_encoders, orig_embeds_params_list, index_no_updates_list
                     ):
-                        accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
+                        # if full_fp16/bf16, input_embeddings_weight is fp16/bf16, orig_embeds_params is fp32
+                        input_embeddings_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
+                        input_embeddings_weight[index_no_updates] = orig_embeds_params.to(input_embeddings_weight.dtype)[
                             index_no_updates
-                        ] = orig_embeds_params[index_no_updates]
+                        ]

             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 42d69d2de..71b43549d 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -394,6 +394,8 @@ def train(args):

     if accelerator.is_main_process:
         init_kwargs = {}
+        if args.wandb_run_name:
+            init_kwargs['wandb'] = {'name': args.wandb_run_name}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
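The --wandb_run_name handling added to train_textual_inversion_XTI.py above is applied identically to the other training scripts in this patch (sdxl_train.py, sdxl_train_control_net_lllite.py, train_controlnet.py, train_db.py, train_network.py, train_textual_inversion.py). Condensed, the pattern is the sketch below; args fields and the tracker name are per-script, and accelerator is the already-constructed Accelerator, so this is an illustration rather than standalone code.

import toml

init_kwargs = {}
if args.wandb_run_name:
    init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
    # note: this reassigns init_kwargs from the TOML file rather than merging,
    # so the run name from --wandb_run_name is not applied when --log_tracker_config is given
    init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("textual_inversion", init_kwargs=init_kwargs)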