diff --git a/.dockerignore b/.dockerignore index 7e9e5b444..9849d33f3 100644 --- a/.dockerignore +++ b/.dockerignore @@ -3,6 +3,7 @@ cudnn_windows/ bitsandbytes_windows/ bitsandbytes_windows_deprecated/ dataset/ +models/ __pycache__/ venv/ **/.hadolint.yml diff --git a/.github/workflows/docker_publish.yml b/.github/workflows/docker_publish.yml index 520045d86..ac198d1cb 100644 --- a/.github/workflows/docker_publish.yml +++ b/.github/workflows/docker_publish.yml @@ -71,7 +71,7 @@ jobs: password: ${{ secrets.GITHUB_TOKEN }} - name: Build and push - uses: docker/build-push-action@v5 + uses: docker/build-push-action@v6 id: publish with: context: . diff --git a/.release b/.release index 9edcada1f..dc864a052 100644 --- a/.release +++ b/.release @@ -1 +1 @@ -v24.1.7 \ No newline at end of file +v24.2.0 diff --git a/README.md b/README.md index 45f4aed92..3fa142929 100644 --- a/README.md +++ b/README.md @@ -46,13 +46,19 @@ The GUI allows you to set the training parameters and generate and run the requi - [Potential Solutions](#potential-solutions) - [SDXL training](#sdxl-training) - [Masked loss](#masked-loss) + - [Guides](#guides) + - [Using Accelerate Lora Tab to Select GPU ID](#using-accelerate-lora-tab-to-select-gpu-id) + - [Starting Accelerate in GUI](#starting-accelerate-in-gui) + - [Running Multiple Instances (linux)](#running-multiple-instances-linux) + - [Monitoring Processes](#monitoring-processes) + - [Interesting Forks](#interesting-forks) - [Change History](#change-history) ## 🦒 Colab This Colab notebook was not created or maintained by me; however, it appears to function effectively. The source can be found at: . -I would like to express my gratitude to camendutu for their valuable contribution. If you encounter any issues with the Colab notebook, please report them on their repository. +I would like to express my gratitude to camenduru for their valuable contribution. If you encounter any issues with the Colab notebook, please report them on their repository. 
| Colab | Info | | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------ | @@ -69,7 +75,7 @@ To install the necessary dependencies on a Windows system, follow these steps: 1. Install [Python 3.10.11](https://www.python.org/ftp/python/3.10.11/python-3.10.11-amd64.exe). - During the installation process, ensure that you select the option to add Python to the 'PATH' environment variable. -2. Install [CUDA 11.8 toolkit](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Windows&target_arch=x86_64). +2. Install [CUDA 12.4 toolkit](https://developer.nvidia.com/cuda-12-4-0-download-archive?target_os=Windows&target_arch=x86_64). 3. Install [Git](https://git-scm.com/download/win). @@ -127,7 +133,7 @@ To install the necessary dependencies on a Linux system, ensure that you fulfill apt install python3.10-venv ``` -- Install the CUDA 11.8 Toolkit by following the instructions provided in [this link](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Linux&target_arch=x86_64). +- Install the CUDA 12.4 Toolkit by following the instructions provided in [this link](https://developer.nvidia.com/cuda-12-4-0-download-archive?target_os=Linux&target_arch=x86_64). - Make sure you have Python version 3.10.9 or higher (but lower than 3.11.0) installed on your system. @@ -329,13 +335,27 @@ To upgrade your installation on Linux or macOS, follow these steps: To launch the GUI service, you can use the provided scripts or run the `kohya_gui.py` script directly. Use the command line arguments listed below to configure the underlying service. ```text ---listen: Specify the IP address to listen on for connections to Gradio. ---username: Set a username for authentication. ---password: Set a password for authentication. ---server_port: Define the port to run the server listener on. 
---inbrowser: Open the Gradio UI in a web browser. ---share: Share the Gradio UI. ---language: Set custom language + --help show this help message and exit + --config CONFIG Path to the toml config file for interface defaults + --debug Debug on + --listen LISTEN IP to listen on for connections to Gradio + --username USERNAME Username for authentication + --password PASSWORD Password for authentication + --server_port SERVER_PORT + Port to run the server listener on + --inbrowser Open in browser + --share Share the gradio UI + --headless Is the server headless + --language LANGUAGE Set custom language + --use-ipex Use IPEX environment + --use-rocm Use ROCm environment + --do_not_use_shell Enforce not to use shell=True when running external commands + --do_not_share Do not share the gradio UI + --requirements REQUIREMENTS + requirements file to use for validation + --root_path ROOT_PATH + `root_path` for Gradio to enable reverse proxy support. e.g. /kohya_ss + --noverify Disable requirements verification ``` ### Launching the GUI on Windows @@ -438,6 +458,37 @@ The feature is not fully tested, so there may be bugs. If you find any issues, p ControlNet dataset is used to specify the mask. The mask images should be the RGB images. The pixel value 255 in R channel is treated as the mask (the loss is calculated only for the pixels with the mask), and 0 is treated as the non-mask. The pixel values 0-255 are converted to 0-1 (i.e., the pixel value 128 is treated as the half weight of the loss). See details for the dataset specification in the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset). +## Guides + +The following guides were extracted from issue discussions. + +### Using Accelerate Lora Tab to Select GPU ID + +#### Starting Accelerate in GUI + +- Open the kohya GUI on your desired port. +- Open the `Accelerate launch` tab. +- Ensure the Multi-GPU checkbox is unchecked. +- Set GPU IDs to the desired GPU (e.g., 1).
+ +#### Running Multiple Instances (linux) + +- For tracking multiple processes, use separate kohya GUI instances on different ports (e.g., 7860, 7861). +- Start instances using `nohup ./gui.sh --listen 0.0.0.0 --server_port <port> --headless > log.log 2>&1 &`. + +#### Monitoring Processes + +- Open each GUI in a separate browser tab. +- For terminal access, use SSH and tools like `tmux` or `screen`. + +For more details, visit the [GitHub issue](https://github.com/bmaltais/kohya_ss/issues/2577). + +## Interesting Forks + +To finetune HunyuanDiT models or create LoRAs, visit this [fork](https://github.com/Tencent/HunyuanDiT/tree/main/kohya_ss-hydit). + ## Change History -See release information. +Added support for SD3 (Dreambooth and Finetuning) and Flux.1 (Dreambooth, LoRA and Finetuning). + +See for more details. diff --git a/_typos.toml b/_typos.toml index d73875a92..28ddf851f 100644 --- a/_typos.toml +++ b/_typos.toml @@ -9,6 +9,7 @@ parms="parms" nin="nin" extention="extention" # Intentionally left nd="nd" +pn="pn" shs="shs" sts="sts" scs="scs" diff --git a/assets/style.css b/assets/style.css index 939ac937f..f8cfe112b 100644 --- a/assets/style.css +++ b/assets/style.css @@ -1,4 +1,4 @@ -#open_folder_small{ +#open_folder_small { min-width: auto; flex-grow: 0; padding-left: 0.25em; @@ -7,14 +7,14 @@ font-size: 1.5em; } -#open_folder{ +#open_folder { height: auto; flex-grow: 0; padding-left: 0.25em; padding-right: 0.25em; } -#number_input{ +#number_input { min-width: min-content; flex-grow: 0.3; padding-left: 0.75em; @@ -22,7 +22,7 @@ } .ver-class { - color: #808080; + color: #6d6d6d; /* Neutral dark gray */ font-size: small; text-align: right; padding-right: 1em; @@ -35,13 +35,212 @@ } #myTensorButton { - background: radial-gradient(ellipse, #3a99ff, #52c8ff); + background: #555c66; /* Muted dark gray */ color: white; - border: #296eb8; + border: none; + border-radius: 4px; + padding: 0.5em 1em; + /* box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); Subtle shadow */ + /*
transition: box-shadow 0.3s ease; */ +} + +#myTensorButton:hover { + /* box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15); Slightly increased shadow on hover */ } #myTensorButtonStop { - background: radial-gradient(ellipse, #52c8ff, #3a99ff); - color: black; - border: #296eb8; + background: #777d85; /* Lighter muted gray */ + color: white; + border: none; + border-radius: 4px; + padding: 0.5em 1em; + box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); + /* transition: box-shadow 0.3s ease; */ +} + +#myTensorButtonStop:hover { + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15); +} + +.advanced_background { + background: #f4f4f4; /* Light neutral gray */ + padding: 1em; + border-radius: 8px; + transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease; /* Added transition for smooth shadow effect */ +} + +.advanced_background:hover { + background-color: #ebebeb; /* Slightly darker background on hover */ + border: 1px solid #ccc; /* Add a subtle border */ + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */ +} + +.basic_background { + background: #eaeff1; /* Muted cool gray */ + padding: 1em; + border-radius: 8px; + transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease; +} + +.basic_background:hover { + background-color: #dfe4e7; /* Slightly darker cool gray on hover */ + border: 1px solid #ccc; + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */ +} + +.huggingface_background { + background: #e0e4e7; /* Light gray with a hint of blue */ + padding: 1em; + border-radius: 8px; + transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease; +} + +.huggingface_background:hover { + background-color: #d6dce0; /* Slightly darker on hover */ + border: 1px solid #bbb; + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */ +} + +.flux1_background { + background: #ece9e6; /* Light beige tone */ + padding: 1em; + border-radius: 8px; + transition: background-color 0.3s ease, border 0.3s 
ease, box-shadow 0.3s ease; +} + +.flux1_background:hover { + background-color: #e2dfdb; /* Slightly darker beige on hover */ + border: 1px solid #ccc; + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */ +} + +.preset_background { + background: #f0f0f0; /* Light gray */ + padding: 1em; + border-radius: 8px; + transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease; +} + +.preset_background:hover { + background-color: #e6e6e6; /* Slightly darker on hover */ + border: 1px solid #ccc; + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */ +} + +.samples_background { + background: #d9dde1; /* Soft muted gray-blue */ + padding: 1em; + border-radius: 8px; + transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease; +} + +.samples_background:hover { + background-color: #cfd3d8; /* Slightly darker on hover */ + border: 1px solid #bbb; + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */ +} + +/* Dark mode styles */ +.dark .advanced_background { + background: #172029; /* Slightly darker gradio dark theme */ + padding: 1em; + border-radius: 8px; + transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease; /* Added transition for smooth shadow effect */ +} + +.dark .advanced_background:hover { + background-color: #121920; /* Slightly darker background on hover */ + border: 1px solid #000000; /* Add a subtle border */ + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */ +} + +.dark .basic_background { + background: #172029; + padding: 1em; + border-radius: 8px; + transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease; +} + +.dark .basic_background:hover { + background-color: #11181e; + border: 1px solid #000000; + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */ +} + +.dark .huggingface_background { + background: #131c25; + padding: 1em; + border-radius: 8px; + transition: background-color 
0.3s ease, border 0.3s ease, box-shadow 0.3s ease; +} + +.dark .huggingface_background:hover { + background-color: #131c25; + border: 1px solid #000000; + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */ +} + +.dark .flux1_background { + background: #131c25; + padding: 1em; + border-radius: 8px; + transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease; +} + +.dark .flux1_background:hover { + background-color: #131c25; + border: 1px solid #000000; + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */ +} + +.dark .preset_background { + background: #191d25; + padding: 1em; + border-radius: 8px; + transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease; +} + +.dark .preset_background:hover { + background-color: #212530; + border: 1px solid #000000; + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */ +} + +.dark .samples_background { + background: #101e2c; + padding: 1em; + border-radius: 8px; + transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease; +} + +.dark .samples_background:hover { + background-color: #17293a; + border: 1px solid #000000; + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */ +} + +.flux1_rank_layers_background { + background: #ece9e6; /* White background for clear theme */ + padding: 1em; + border-radius: 8px; + transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease; +} + +.flux1_rank_layers_background:hover { + background-color: #dddad7; /* Slightly darker on hover */ + border: 1px solid #ccc; + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */ +} + +.dark .flux1_rank_layers_background { + background: #131c25; /* Dark background for dark theme */ + padding: 1em; + border-radius: 8px; + transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease; +} + +.dark .flux1_rank_layers_background:hover { + background-color: #131c25; /* 
Slightly darker on hover */ + border: 1px solid #000000; + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */ } \ No newline at end of file diff --git a/config example.toml b/config example.toml index c137d4387..fff12555b 100644 --- a/config example.toml +++ b/config example.toml @@ -48,6 +48,7 @@ learning_rate_te1 = 0.0001 # Learning rate text encoder 1 learning_rate_te2 = 0.0001 # Learning rate text encoder 2 lr_scheduler = "cosine" # LR Scheduler lr_scheduler_args = "" # LR Scheduler args +lr_scheduler_type = "" # LR Scheduler type lr_warmup = 0 # LR Warmup (% of total steps) lr_scheduler_num_cycles = 1 # LR Scheduler num cycles lr_scheduler_power = 1.0 # LR Scheduler power @@ -150,6 +151,9 @@ sample_prompts = "" # Sample prompts sample_sampler = "euler_a" # Sampler to use for image sampling [sdxl] +disable_mmap_load_safetensors = false # Disable mmap load safe tensors +fused_backward_pass = false # Fused backward pass +fused_optimizer_groups = 0 # Fused optimizer groups sdxl_cache_text_encoder_outputs = false # Cache text encoder outputs sdxl_no_half_vae = true # No half VAE diff --git a/docker-compose.yaml b/docker-compose.yaml index 4932bcee2..cadffb0fa 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -20,11 +20,12 @@ services: - /tmp volumes: - /tmp/.X11-unix:/tmp/.X11-unix + - ./models:/app/models - ./dataset:/dataset - ./dataset/images:/app/data - ./dataset/logs:/app/logs - ./dataset/outputs:/app/outputs - ./dataset/regularization:/app/regularization - ./.cache/config:/app/config - ./.cache/user:/home/1000/.cache - ./.cache/triton:/home/1000/.triton diff --git a/examples/pull kohya_ss sd-scripts updates in.md b/examples/pull kohya_ss sd-scripts updates in.md index 47b1c79ad..6c94fb18b 100644 --- a/examples/pull kohya_ss sd-scripts updates in.md +++ b/examples/pull kohya_ss sd-scripts updates in.md @@ -1,32 +1,27 @@ -## Updating a Local Branch with the Latest sd-scripts Changes +## Updating a Local
Submodule with the Latest sd-scripts Changes To update your local branch with the most recent changes from kohya/sd-scripts, follow these steps: -1. Add sd-scripts as an alternative remote by executing the following command: +1. When you wish to perform an update of the dev branch, execute the following commands: - ``` - git remote add sd-scripts https://github.com/kohya-ss/sd-scripts.git - ``` - -2. When you wish to perform an update, execute the following commands: - - ``` - git checkout dev - git pull sd-scripts main - ``` - - Alternatively, if you want to obtain the latest code, even if it may be unstable: - - ``` + ```bash + cd sd-scripts + git fetch git checkout dev - git pull sd-scripts dev + git pull origin dev + cd .. + git add sd-scripts + git commit -m "Update sd-scripts submodule to the latest on dev" ``` -3. If you encounter a conflict with the Readme file, you can resolve it by taking the following steps: + Alternatively, if you want to obtain the latest code from main: + ```bash + cd sd-scripts + git fetch + git checkout main + git pull origin main + cd .. + git add sd-scripts + git commit -m "Update sd-scripts submodule to the latest on main" ``` - git add README.md - git merge --continue - ``` - - This may open a text editor for a commit message, but you can simply save and close it to proceed. Following these steps should resolve the conflict. If you encounter additional merge conflicts, consider them as valuable learning opportunities for personal growth. 
\ No newline at end of file diff --git a/gui.bat b/gui.bat index e5e206db0..74034b9c5 100644 --- a/gui.bat +++ b/gui.bat @@ -7,11 +7,13 @@ call .\venv\Scripts\deactivate.bat :: Activate the virtual environment call .\venv\Scripts\activate.bat + +:: Update pip to latest version +python -m pip install --upgrade pip -q + set PATH=%PATH%;%~dp0venv\Lib\site-packages\torch\lib -:: Validate requirements -python.exe .\setup\validate_requirements.py -if %errorlevel% neq 0 exit /b %errorlevel% +echo Starting the GUI... this might take some time... :: If the exit code is 0, run the kohya_gui.py script with the command-line arguments if %errorlevel% equ 0 ( diff --git a/gui.ps1 b/gui.ps1 index 9e9a441de..47e69aca5 100644 --- a/gui.ps1 +++ b/gui.ps1 @@ -7,28 +7,18 @@ if ($env:VIRTUAL_ENV) { # Activate the virtual environment # Write-Host "Activating the virtual environment..." & .\venv\Scripts\activate -$env:PATH += ";$($MyInvocation.MyCommand.Path)\venv\Lib\site-packages\torch\lib" -# Debug info about system -# python.exe .\setup\debug_info.py +python.exe -m pip install --upgrade pip -q + +$env:PATH += ";$($MyInvocation.MyCommand.Path)\venv\Lib\site-packages\torch\lib" -# Validate the requirements and store the exit code -python.exe .\setup\validate_requirements.py +Write-Host "Starting the GUI... this might take some time..." -# Check the exit code and stop execution if it is not 0 -if ($LASTEXITCODE -ne 0) { - Write-Host "Failed to validate requirements. Exiting script..." 
- exit $LASTEXITCODE +$argsFromFile = @() +if (Test-Path .\gui_parameters.txt) { + $argsFromFile = Get-Content .\gui_parameters.txt -Encoding UTF8 | Where-Object { $_ -notmatch "^#" } | Foreach-Object { $_ -split " " } } +$args_combo = $argsFromFile + $args +# Write-Host "The arguments passed to this script were: $args_combo" +python.exe kohya_gui.py $args_combo -# If the exit code is 0, read arguments from gui_parameters.txt (if it exists) -# and run the kohya_gui.py script with the command-line arguments -if ($LASTEXITCODE -eq 0) { - $argsFromFile = @() - if (Test-Path .\gui_parameters.txt) { - $argsFromFile = Get-Content .\gui_parameters.txt -Encoding UTF8 | Where-Object { $_ -notmatch "^#" } | Foreach-Object { $_ -split " " } - } - $args_combo = $argsFromFile + $args - # Write-Host "The arguments passed to this script were: $args_combo" - python.exe kohya_gui.py $args_combo -} diff --git a/gui.sh b/gui.sh index 17c5207a4..c6502d3ec 100755 --- a/gui.sh +++ b/gui.sh @@ -111,10 +111,4 @@ then STARTUP_CMD=python fi -# Validate the requirements and run the script if successful -if python "$SCRIPT_DIR/setup/validate_requirements.py" -r "$REQUIREMENTS_FILE"; then - "${STARTUP_CMD}" $STARTUP_CMD_ARGS "$SCRIPT_DIR/kohya_gui.py" "$@" -else - echo "Validation failed. Exiting..." 
- exit 1 -fi +"${STARTUP_CMD}" $STARTUP_CMD_ARGS "$SCRIPT_DIR/kohya_gui.py" "--requirements=""$REQUIREMENTS_FILE" "$@" diff --git a/kohya_gui.py b/kohya_gui.py index f586f0a29..e8d64a884 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -1,6 +1,10 @@ -import gradio as gr import os +import sys import argparse +import subprocess +import contextlib +import gradio as gr + from kohya_gui.class_gui_config import KohyaSSGUIConfig from kohya_gui.dreambooth_gui import dreambooth_tab from kohya_gui.finetune_gui import finetune_tab @@ -8,71 +12,43 @@ from kohya_gui.utilities import utilities_tab from kohya_gui.lora_gui import lora_tab from kohya_gui.class_lora_tab import LoRATools - from kohya_gui.custom_logging import setup_logging from kohya_gui.localization_ext import add_javascript - -def UI(**kwargs): - add_javascript(kwargs.get("language")) - css = "" - - headless = kwargs.get("headless", False) - log.info(f"headless: {headless}") - - if os.path.exists("./assets/style.css"): - with open(os.path.join("./assets/style.css"), "r", encoding="utf8") as file: - log.debug("Load CSS...") - css += file.read() + "\n" - - if os.path.exists("./.release"): - with open(os.path.join("./.release"), "r", encoding="utf8") as file: - release = file.read() - - if os.path.exists("./README.md"): - with open(os.path.join("./README.md"), "r", encoding="utf8") as file: - README = file.read() - - interface = gr.Blocks( - css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default() - ) - - config = KohyaSSGUIConfig(config_file_path=kwargs.get("config")) - - if config.is_config_loaded(): - log.info(f"Loaded default GUI values from '{kwargs.get('config')}'...") - - use_shell_flag = True - # if os.name == "posix": - # use_shell_flag = True - - use_shell_flag = config.get("settings.use_shell", use_shell_flag) - - if kwargs.get("do_not_use_shell", False): - use_shell_flag = False - - if use_shell_flag: - log.info("Using shell=True when running external commands...") - - with interface: +PYTHON = 
sys.executable +project_dir = os.path.dirname(os.path.abspath(__file__)) + +# Function to read file content, suppressing any FileNotFoundError +def read_file_content(file_path): + with contextlib.suppress(FileNotFoundError): + with open(file_path, "r", encoding="utf8") as file: + return file.read() + return "" + +# Function to initialize the Gradio UI interface +def initialize_ui_interface(config, headless, use_shell, release_info, readme_content): + # Load custom CSS if available + css = read_file_content("./assets/style.css") + + # Create the main Gradio Blocks interface + ui_interface = gr.Blocks(css=css, title=f"Kohya_ss GUI {release_info}", theme=gr.themes.Default()) + with ui_interface: + # Create tabs for different functionalities with gr.Tab("Dreambooth"): ( train_data_dir_input, reg_data_dir_input, output_dir_input, logging_dir_input, - ) = dreambooth_tab( - headless=headless, config=config, use_shell_flag=use_shell_flag - ) + ) = dreambooth_tab(headless=headless, config=config, use_shell_flag=use_shell) with gr.Tab("LoRA"): - lora_tab(headless=headless, config=config, use_shell_flag=use_shell_flag) + lora_tab(headless=headless, config=config, use_shell_flag=use_shell) with gr.Tab("Textual Inversion"): - ti_tab(headless=headless, config=config, use_shell_flag=use_shell_flag) + ti_tab(headless=headless, config=config, use_shell_flag=use_shell) with gr.Tab("Finetuning"): - finetune_tab( - headless=headless, config=config, use_shell_flag=use_shell_flag - ) + finetune_tab(headless=headless, config=config, use_shell_flag=use_shell) with gr.Tab("Utilities"): + # Utilities tab requires inputs from the Dreambooth tab utilities_tab( train_data_dir_input=train_data_dir_input, reg_data_dir_input=reg_data_dir_input, @@ -84,102 +60,97 @@ def UI(**kwargs): with gr.Tab("LoRA"): _ = LoRATools(headless=headless) with gr.Tab("About"): - gr.Markdown(f"kohya_ss GUI release {release}") + # About tab to display release information and README content + gr.Markdown(f"kohya_ss GUI 
release {release_info}") with gr.Tab("README"): - gr.Markdown(README) - - htmlStr = f""" - - -
<div class="ver-class">{release}</div>
- - - """ - gr.HTML(htmlStr) - # Show the interface - launch_kwargs = {} - username = kwargs.get("username") - password = kwargs.get("password") - server_port = kwargs.get("server_port", 0) - inbrowser = kwargs.get("inbrowser", False) - share = kwargs.get("share", False) - do_not_share = kwargs.get("do_not_share", False) - server_name = kwargs.get("listen") - root_path = kwargs.get("root_path", None) - - launch_kwargs["server_name"] = server_name - if username and password: - launch_kwargs["auth"] = (username, password) - if server_port > 0: - launch_kwargs["server_port"] = server_port - if inbrowser: - launch_kwargs["inbrowser"] = inbrowser - if do_not_share: - launch_kwargs["share"] = False - else: - if share: - launch_kwargs["share"] = share - if root_path: - launch_kwargs["root_path"] = root_path - launch_kwargs["debug"] = True - interface.launch(**launch_kwargs) + gr.Markdown(readme_content) + # Display release information in a div element + gr.Markdown(f"
<div class='ver-class'>{release_info}</div>
") -if __name__ == "__main__": - # torch.cuda.set_per_process_memory_fraction(0.48) + return ui_interface + +# Function to configure and launch the UI +def UI(**kwargs): + # Add custom JavaScript if specified + add_javascript(kwargs.get("language")) + log.info(f"headless: {kwargs.get('headless', False)}") + + # Load release and README information + release_info = read_file_content("./.release") + readme_content = read_file_content("./README.md") + + # Load configuration from the specified file + config = KohyaSSGUIConfig(config_file_path=kwargs.get("config")) + if config.is_config_loaded(): + log.info(f"Loaded default GUI values from '{kwargs.get('config')}'...") + + # Determine if shell should be used for running external commands + use_shell = not kwargs.get("do_not_use_shell", False) and config.get("settings.use_shell", True) + if use_shell: + log.info("Using shell=True when running external commands...") + + # Initialize the Gradio UI interface + ui_interface = initialize_ui_interface(config, kwargs.get("headless", False), use_shell, release_info, readme_content) + + # Build the launch parameters; options that were not set are left as None + launch_params = { + "server_name": kwargs.get("listen"), + "auth": (kwargs["username"], kwargs["password"]) if kwargs.get("username") and kwargs.get("password") else None, + "server_port": kwargs.get("server_port", 0) if kwargs.get("server_port", 0) > 0 else None, + "inbrowser": kwargs.get("inbrowser", False), + "share": False if kwargs.get("do_not_share", False) else kwargs.get("share", False), + "root_path": kwargs.get("root_path", None), + "debug": kwargs.get("debug", False), + } + + # Drop entries whose value is None so that only explicitly set parameters are passed to `launch`
+ launch_params = {k: v for k, v in launch_params.items() if v is not None} + + # Launch the Gradio interface with the specified parameters + ui_interface.launch(**launch_params) + +# Function to initialize argument parser for command-line arguments +def initialize_arg_parser(): parser = argparse.ArgumentParser() - parser.add_argument( - "--config", - type=str, - default="./config.toml", - help="Path to the toml config file for interface defaults", - ) + parser.add_argument("--config", type=str, default="./config.toml", help="Path to the toml config file for interface defaults") parser.add_argument("--debug", action="store_true", help="Debug on") - parser.add_argument( - "--listen", - type=str, - default="127.0.0.1", - help="IP to listen on for connections to Gradio", - ) - parser.add_argument( - "--username", type=str, default="", help="Username for authentication" - ) - parser.add_argument( - "--password", type=str, default="", help="Password for authentication" - ) - parser.add_argument( - "--server_port", - type=int, - default=0, - help="Port to run the server listener on", - ) + parser.add_argument("--listen", type=str, default="127.0.0.1", help="IP to listen on for connections to Gradio") + parser.add_argument("--username", type=str, default="", help="Username for authentication") + parser.add_argument("--password", type=str, default="", help="Password for authentication") + parser.add_argument("--server_port", type=int, default=0, help="Port to run the server listener on") parser.add_argument("--inbrowser", action="store_true", help="Open in browser") parser.add_argument("--share", action="store_true", help="Share the gradio UI") - parser.add_argument( - "--headless", action="store_true", help="Is the server headless" - ) - parser.add_argument( - "--language", type=str, default=None, help="Set custom language" - ) - + parser.add_argument("--headless", action="store_true", help="Is the server headless") + parser.add_argument("--language", type=str, 
default=None, help="Set custom language") parser.add_argument("--use-ipex", action="store_true", help="Use IPEX environment") parser.add_argument("--use-rocm", action="store_true", help="Use ROCm environment") + parser.add_argument("--do_not_use_shell", action="store_true", help="Enforce not to use shell=True when running external commands") + parser.add_argument("--do_not_share", action="store_true", help="Do not share the gradio UI") + parser.add_argument("--requirements", type=str, default=None, help="requirements file to use for validation") + parser.add_argument("--root_path", type=str, default=None, help="`root_path` for Gradio to enable reverse proxy support. e.g. /kohya_ss") + parser.add_argument("--noverify", action="store_true", help="Disable requirements verification") + return parser - parser.add_argument( - "--do_not_use_shell", action="store_true", help="Enforce not to use shell=True when running external commands" - ) - - parser.add_argument( - "--do_not_share", action="store_true", help="Do not share the gradio UI" - ) - - parser.add_argument( - "--root_path", type=str, default=None, help="`root_path` for Gradio to enable reverse proxy support. e.g. 
/kohya_ss" - ) - +if __name__ == "__main__": + # Initialize argument parser and parse arguments + parser = initialize_arg_parser() args = parser.parse_args() - # Set up logging + # Set up logging based on the debug flag log = setup_logging(debug=args.debug) - UI(**vars(args)) + # Verify requirements unless `noverify` flag is set + if args.noverify: + log.warning("Skipping requirements verification.") + else: + # Run the validation command to verify requirements + validation_command = [PYTHON, os.path.join(project_dir, "setup", "validate_requirements.py")] + + if args.requirements is not None: + validation_command.append(f"--requirements={args.requirements}") + + subprocess.run(validation_command, check=True) + + # Launch the UI with the provided arguments + UI(**vars(args)) \ No newline at end of file diff --git a/kohya_gui/basic_caption_gui.py b/kohya_gui/basic_caption_gui.py index d352954a1..ee834a39c 100644 --- a/kohya_gui/basic_caption_gui.py +++ b/kohya_gui/basic_caption_gui.py @@ -102,7 +102,7 @@ def caption_images( postfix=postfix, ) # Replace specified text in caption files if find and replace text is provided - if find_text and replace_text: + if find_text: find_replace( folder_path=images_dir, caption_file_ext=caption_ext, diff --git a/kohya_gui/blip2_caption_gui.py b/kohya_gui/blip2_caption_gui.py index 5429db0b6..b3263227d 100644 --- a/kohya_gui/blip2_caption_gui.py +++ b/kohya_gui/blip2_caption_gui.py @@ -42,7 +42,7 @@ def get_images_in_directory(directory_path): import os # List of common image file extensions to look for - image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif"] + image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"] # Generate a list of image file paths in the directory image_files = [ diff --git a/kohya_gui/class_accelerate_launch.py b/kohya_gui/class_accelerate_launch.py index 912337bdd..6cf70fb92 100644 --- a/kohya_gui/class_accelerate_launch.py +++ b/kohya_gui/class_accelerate_launch.py @@ -3,6 +3,10 @@ 
import shlex from .class_gui_config import KohyaSSGUIConfig +from .custom_logging import setup_logging + +# Set up logging +log = setup_logging() class AccelerateLaunch: @@ -79,12 +83,16 @@ def __init__( ) self.dynamo_use_fullgraph = gr.Checkbox( label="Dynamo use fullgraph", - value=self.config.get("accelerate_launch.dynamo_use_fullgraph", False), + value=self.config.get( + "accelerate_launch.dynamo_use_fullgraph", False + ), info="Whether to use full graph mode for dynamo or it is ok to break model into several subgraphs", ) self.dynamo_use_dynamic = gr.Checkbox( label="Dynamo use dynamic", - value=self.config.get("accelerate_launch.dynamo_use_dynamic", False), + value=self.config.get( + "accelerate_launch.dynamo_use_dynamic", False + ), info="Whether to enable dynamic shape tracing.", ) @@ -103,6 +111,24 @@ def __init__( placeholder="example: 0,1", info=" What GPUs (by id) should be used for training on this machine as a comma-separated list", ) + + def validate_gpu_ids(value): + if value == "": + return + if not ( + value.isdigit() and int(value) >= 0 and int(value) <= 128 + ): + log.error("GPU IDs must be an integer between 0 and 128") + return + else: + for id in value.split(","): + if not id.isdigit() or int(id) < 0 or int(id) > 128: + log.error( + "GPU IDs must be an integer between 0 and 128" + ) + + self.gpu_ids.blur(fn=validate_gpu_ids, inputs=self.gpu_ids) + self.main_process_port = gr.Number( label="Main process port", value=self.config.get("accelerate_launch.main_process_port", 0), @@ -136,9 +162,14 @@ def run_cmd(run_cmd: list, **kwargs): if "dynamo_use_dynamic" in kwargs and kwargs.get("dynamo_use_dynamic"): run_cmd.append("--dynamo_use_dynamic") - - if "extra_accelerate_launch_args" in kwargs and kwargs["extra_accelerate_launch_args"] != "": - extra_accelerate_launch_args = kwargs["extra_accelerate_launch_args"].replace('"', "") + + if ( + "extra_accelerate_launch_args" in kwargs + and kwargs["extra_accelerate_launch_args"] != "" + ): + 
extra_accelerate_launch_args = kwargs[ + "extra_accelerate_launch_args" + ].replace('"', "") for arg in extra_accelerate_launch_args.split(): run_cmd.append(shlex.quote(arg)) diff --git a/kohya_gui/class_advanced_training.py b/kohya_gui/class_advanced_training.py index c9784c304..0aa9e0429 100644 --- a/kohya_gui/class_advanced_training.py +++ b/kohya_gui/class_advanced_training.py @@ -146,7 +146,7 @@ def list_vae_files(path): with gr.Row(): self.loss_type = gr.Dropdown( label="Loss type", - choices=["huber", "smooth_l1", "l2"], + choices=["huber", "smooth_l1", "l1", "l2"], value=self.config.get("advanced.loss_type", "l2"), info="The type of loss to use and whether it's scheduled based on the timestep", ) @@ -188,6 +188,18 @@ def list_vae_files(path): precision=0, info="(Optional) Save only the specified number of states (old models will be deleted)", ) + self.save_last_n_epochs = gr.Number( + label="Save last N epochs", + value=self.config.get("advanced.save_last_n_epochs", 0), + precision=0, + info="(Optional) Save only the specified number of epochs (old epochs will be deleted)", + ) + self.save_last_n_epochs_state = gr.Number( + label="Save last N epochs state", + value=self.config.get("advanced.save_last_n_epochs_state", 0), + precision=0, + info="(Optional) Save only the specified number of epochs states (old models will be deleted)", + ) with gr.Row(): def full_options_update(full_fp16, full_bf16): @@ -228,12 +240,16 @@ def full_options_update(full_fp16, full_bf16): ) with gr.Row(): - if training_type == "lora": - self.fp8_base = gr.Checkbox( - label="fp8 base training (experimental)", - info="U-Net and Text Encoder can be trained with fp8 (experimental)", - value=self.config.get("advanced.fp8_base", False), - ) + self.fp8_base = gr.Checkbox( + label="fp8 base", + info="Use fp8 for base model", + value=self.config.get("advanced.fp8_base", False), + ) + self.fp8_base_unet = gr.Checkbox( + label="fp8 base unet", + info="Flux can be trained with fp8, and CLIP-L 
can be trained with bf16/fp16.", + value=self.config.get("advanced.fp8_base_unet", False), + ) self.full_fp16 = gr.Checkbox( label="Full fp16 training (experimental)", value=self.config.get("advanced.full_fp16", False), @@ -254,6 +270,25 @@ def full_options_update(full_fp16, full_bf16): inputs=[self.full_fp16, self.full_bf16], outputs=[self.full_fp16, self.full_bf16], ) + + with gr.Row(): + self.highvram = gr.Checkbox( + label="highvram", + value=self.config.get("advanced.highvram", False), + info="Disable low VRAM optimization. e.g. do not clear CUDA cache after each latent caching (for machines which have bigger VRAM)", + interactive=True, + ) + self.lowvram = gr.Checkbox( + label="lowvram", + value=self.config.get("advanced.lowvram", False), + info="Enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle)", + interactive=True, + ) + self.skip_cache_check = gr.Checkbox( + label="Skip cache check", + value=self.config.get("advanced.skip_cache_check", False), + info="Skip cache check for faster training start", + ) with gr.Row(): self.gradient_checkpointing = gr.Checkbox( @@ -534,6 +569,11 @@ def list_log_tracker_config_files(path): self.current_log_tracker_config_dir = path if not path == "" else "." 
return list(list_files(path, exts=[".json"], all=True)) + self.log_config = gr.Checkbox( + label="Log config", + value=self.config.get("advanced.log_config", False), + info="Log training parameter to WANDB", + ) self.log_tracker_name = gr.Textbox( label="Log tracker name", value=self.config.get("advanced.log_tracker_name", ""), diff --git a/kohya_gui/class_basic_training.py b/kohya_gui/class_basic_training.py index ee22d552b..0d03769cf 100644 --- a/kohya_gui/class_basic_training.py +++ b/kohya_gui/class_basic_training.py @@ -25,6 +25,7 @@ def __init__( learning_rate_value: float = "1e-6", lr_scheduler_value: str = "constant", lr_warmup_value: float = "0", + lr_warmup_steps_value: int = 0, finetuning: bool = False, dreambooth: bool = False, config: dict = {}, @@ -44,10 +45,14 @@ def __init__( self.learning_rate_value = learning_rate_value self.lr_scheduler_value = lr_scheduler_value self.lr_warmup_value = lr_warmup_value + self.lr_warmup_steps_value= lr_warmup_steps_value self.finetuning = finetuning self.dreambooth = dreambooth self.config = config + + # Initialize old_lr_warmup and old_lr_warmup_steps with default values self.old_lr_warmup = 0 + self.old_lr_warmup_steps = 0 # Initialize the UI components self.initialize_ui_components() @@ -162,20 +167,37 @@ def init_lr_and_optimizer_controls(self) -> None: "cosine", "cosine_with_restarts", "linear", + "piecewise_constant", "polynomial", + "COSINE_WITH_MIN_LR", + "INVERSE_SQRT", + "WARMUP_STABLE_DECAY", ], value=self.config.get("basic.lr_scheduler", self.lr_scheduler_value), ) - + # Initialize the learning rate scheduler type dropdown + self.lr_scheduler_type = gr.Dropdown( + label="LR Scheduler type", + info="(Optional) custom scheduler module name", + choices=[ + "", + "CosineAnnealingLR", + ], + value=self.config.get("basic.lr_scheduler_type", ""), + allow_custom_value=True, + ) # Initialize the optimizer dropdown self.optimizer = gr.Dropdown( label="Optimizer", choices=[ "AdamW", + "AdamWScheduleFree", 
"AdamW8bit", "Adafactor", + "bitsandbytes.optim.AdEMAMix8bit", + "bitsandbytes.optim.PagedAdEMAMix8bit", "DAdaptation", "DAdaptAdaGrad", "DAdaptAdam", @@ -192,6 +214,7 @@ def init_lr_and_optimizer_controls(self) -> None: "Prodigy", "SGDNesterov", "SGDNesterov8bit", + "SGDScheduleFree", ], value=self.config.get("basic.optimizer", "AdamW8bit"), interactive=True, @@ -240,7 +263,7 @@ def init_learning_rate_controls(self) -> None: self.learning_rate = gr.Number( label=lr_label, value=self.config.get("basic.learning_rate", self.learning_rate_value), - minimum=0, + minimum=-1, maximum=1, info="Set to 0 to not train the Unet", ) @@ -251,7 +274,7 @@ def init_learning_rate_controls(self) -> None: "basic.learning_rate_te", self.learning_rate_value ), visible=self.finetuning or self.dreambooth, - minimum=0, + minimum=-1, maximum=1, info="Set to 0 to not train the Text Encoder", ) @@ -262,7 +285,7 @@ def init_learning_rate_controls(self) -> None: "basic.learning_rate_te1", self.learning_rate_value ), visible=False, - minimum=0, + minimum=-1, maximum=1, info="Set to 0 to not train the Text Encoder 1", ) @@ -273,7 +296,7 @@ def init_learning_rate_controls(self) -> None: "basic.learning_rate_te2", self.learning_rate_value ), visible=False, - minimum=0, + minimum=-1, maximum=1, info="Set to 0 to not train the Text Encoder 2", ) @@ -285,25 +308,37 @@ def init_learning_rate_controls(self) -> None: maximum=100, step=1, ) + # Initialize the learning rate warmup steps override + self.lr_warmup_steps = gr.Number( + label="LR warmup steps (override)", + value=self.config.get("basic.lr_warmup_steps", self.lr_warmup_steps_value), + minimum=0, + step=1, + ) - def lr_scheduler_changed(scheduler, value): + def lr_scheduler_changed(scheduler, value, value_lr_warmup_steps): if scheduler == "constant": self.old_lr_warmup = value + self.old_lr_warmup_steps = value_lr_warmup_steps value = 0 + value_lr_warmup_steps = 0 interactive=False info="Can't use LR warmup with LR Scheduler constant... 
setting to 0 and disabling field..." else: if self.old_lr_warmup != 0: value = self.old_lr_warmup self.old_lr_warmup = 0 + if self.old_lr_warmup_steps != 0: + value_lr_warmup_steps = self.old_lr_warmup_steps + self.old_lr_warmup_steps = 0 interactive=True info="" - return gr.Slider(value=value, interactive=interactive, info=info) + return gr.Slider(value=value, interactive=interactive, info=info), gr.Number(value=value_lr_warmup_steps, interactive=interactive, info=info) self.lr_scheduler.change( lr_scheduler_changed, - inputs=[self.lr_scheduler, self.lr_warmup], - outputs=self.lr_warmup, + inputs=[self.lr_scheduler, self.lr_warmup, self.lr_warmup_steps], + outputs=[self.lr_warmup, self.lr_warmup_steps], ) def init_scheduler_controls(self) -> None: diff --git a/kohya_gui/class_command_executor.py b/kohya_gui/class_command_executor.py index f18e97a32..ba6f9c8e1 100644 --- a/kohya_gui/class_command_executor.py +++ b/kohya_gui/class_command_executor.py @@ -48,7 +48,7 @@ def execute_command(self, run_cmd: str, **kwargs): # Execute the command securely self.process = subprocess.Popen(run_cmd, **kwargs) - log.info("Command executed.") + log.debug("Command executed.") def kill_command(self): """ diff --git a/kohya_gui/class_flux1.py b/kohya_gui/class_flux1.py new file mode 100644 index 000000000..547e51934 --- /dev/null +++ b/kohya_gui/class_flux1.py @@ -0,0 +1,336 @@ +import gradio as gr +from typing import Tuple +from .common_gui import ( + get_any_file_path, + document_symbol, +) + + +class flux1Training: + def __init__( + self, + headless: bool = False, + finetuning: bool = False, + training_type: str = "", + config: dict = {}, + flux1_checkbox: gr.Checkbox = False, + ) -> None: + self.headless = headless + self.finetuning = finetuning + self.training_type = training_type + self.config = config + self.flux1_checkbox = flux1_checkbox + + # Define the behavior for changing noise offset type. 
+ def noise_offset_type_change( + noise_offset_type: str, + ) -> Tuple[gr.Group, gr.Group]: + if noise_offset_type == "Original": + return (gr.Group(visible=True), gr.Group(visible=False)) + else: + return (gr.Group(visible=False), gr.Group(visible=True)) + + with gr.Accordion( + "Flux.1", open=True, visible=False, elem_classes=["flux1_background"] + ) as flux1_accordion: + with gr.Group(): + with gr.Row(): + self.ae = gr.Textbox( + label="VAE Path", + placeholder="Path to VAE model", + value=self.config.get("flux1.ae", ""), + interactive=True, + ) + self.ae_button = gr.Button( + document_symbol, + elem_id="open_folder_small", + visible=(not headless), + interactive=True, + ) + self.ae_button.click( + get_any_file_path, + outputs=self.ae, + show_progress=False, + ) + + self.clip_l = gr.Textbox( + label="CLIP-L Path", + placeholder="Path to CLIP-L model", + value=self.config.get("flux1.clip_l", ""), + interactive=True, + ) + self.clip_l_button = gr.Button( + document_symbol, + elem_id="open_folder_small", + visible=(not headless), + interactive=True, + ) + self.clip_l_button.click( + get_any_file_path, + outputs=self.clip_l, + show_progress=False, + ) + + self.t5xxl = gr.Textbox( + label="T5-XXL Path", + placeholder="Path to T5-XXL model", + value=self.config.get("flux1.t5xxl", ""), + interactive=True, + ) + self.t5xxl_button = gr.Button( + document_symbol, + elem_id="open_folder_small", + visible=(not headless), + interactive=True, + ) + self.t5xxl_button.click( + get_any_file_path, + outputs=self.t5xxl, + show_progress=False, + ) + + with gr.Row(): + + self.discrete_flow_shift = gr.Number( + label="Discrete Flow Shift", + value=self.config.get("flux1.discrete_flow_shift", 3.0), + info="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0", + minimum=-1024, + maximum=1024, + step=0.01, + interactive=True, + ) + self.model_prediction_type = gr.Dropdown( + label="Model Prediction Type", + choices=["raw", "additive", "sigma_scaled"], + 
value=self.config.get( + "flux1.model_prediction_type", "sigma_scaled" + ), + interactive=True, + ) + self.timestep_sampling = gr.Dropdown( + label="Timestep Sampling", + choices=["flux_shift", "sigma", "shift", "sigmoid", "uniform"], + value=self.config.get("flux1.timestep_sampling", "sigma"), + interactive=True, + ) + self.apply_t5_attn_mask = gr.Checkbox( + label="Apply T5 Attention Mask", + value=self.config.get("flux1.apply_t5_attn_mask", False), + info="Apply attention mask to T5-XXL encode and FLUX double blocks ", + interactive=True, + ) + with gr.Row(visible=True if not finetuning else False): + self.split_mode = gr.Checkbox( + label="Split Mode", + value=self.config.get("flux1.split_mode", False), + info="Split mode for Flux1", + interactive=True, + ) + self.train_blocks = gr.Dropdown( + label="Train Blocks", + choices=["all", "double", "single"], + value=self.config.get("flux1.train_blocks", "all"), + interactive=True, + ) + self.split_qkv = gr.Checkbox( + label="Split QKV", + value=self.config.get("flux1.split_qkv", False), + info="Split the projection layers of q/k/v/txt in the attention", + interactive=True, + ) + self.train_t5xxl = gr.Checkbox( + label="Train T5-XXL", + value=self.config.get("flux1.train_t5xxl", False), + info="Train T5-XXL model", + interactive=True, + ) + self.cpu_offload_checkpointing = gr.Checkbox( + label="CPU Offload Checkpointing", + value=self.config.get("flux1.cpu_offload_checkpointing", False), + info="[Experimental] Enable offloading of tensors to CPU during checkpointing", + interactive=True, + ) + with gr.Row(): + self.guidance_scale = gr.Number( + label="Guidance Scale", + value=self.config.get("flux1.guidance_scale", 3.5), + info="Guidance scale for Flux1", + minimum=0, + maximum=1024, + step=0.1, + interactive=True, + ) + self.t5xxl_max_token_length = gr.Number( + label="T5-XXL Max Token Length", + value=self.config.get("flux1.t5xxl_max_token_length", 512), + info="Max token length for T5-XXL", + minimum=0, + 
maximum=4096, + step=1, + interactive=True, + ) + self.enable_all_linear = gr.Checkbox( + label="Enable All Linear", + value=self.config.get("flux1.enable_all_linear", False), + info="(Only applicable to 'Flux1 OFT' LoRA) Target all linear connections in the MLP layer. The default is False, which targets only attention.", + interactive=True, + ) + + with gr.Row(): + self.flux1_cache_text_encoder_outputs = gr.Checkbox( + label="Cache Text Encoder Outputs", + value=self.config.get( + "flux1.cache_text_encoder_outputs", False + ), + info="Cache text encoder outputs to speed up inference", + interactive=True, + ) + self.flux1_cache_text_encoder_outputs_to_disk = gr.Checkbox( + label="Cache Text Encoder Outputs to Disk", + value=self.config.get( + "flux1.cache_text_encoder_outputs_to_disk", False + ), + info="Cache text encoder outputs to disk to speed up inference", + interactive=True, + ) + self.mem_eff_save = gr.Checkbox( + label="Memory Efficient Save", + value=self.config.get("flux1.mem_eff_save", False), + info="[Experimental] Enable memory efficient save. We do not recommend using it unless you are familiar with the code.", + interactive=True, + ) + + with gr.Row(visible=True if finetuning else False): + self.blocks_to_swap = gr.Slider( + label="Blocks to swap", + value=self.config.get("flux1.blocks_to_swap", 0), + info="The number of blocks to swap. The default is None (no swap). These options must be combined with --fused_backward_pass or --blockwise_fused_optimizers. 
The recommended maximum value is 36.", + minimum=0, + maximum=57, + step=1, + interactive=True, + ) + self.single_blocks_to_swap = gr.Slider( + label="Single Blocks to swap (deprecated)", + value=self.config.get("flux1.single_blocks_to_swap", 0), + info="[Experimental] Sets the number of 'single_blocks' (~320MB) to swap during the forward and backward passes.", + minimum=0, + maximum=19, + step=1, + interactive=True, + ) + self.double_blocks_to_swap = gr.Slider( + label="Double Blocks to swap (deprecated)", + value=self.config.get("flux1.double_blocks_to_swap", 0), + info="[Experimental] Sets the number of 'double_blocks' (~640MB) to swap during the forward and backward passes.", + minimum=0, + maximum=38, + step=1, + interactive=True, + ) + + with gr.Row(visible=True if finetuning else False): + self.blockwise_fused_optimizers = gr.Checkbox( + label="Blockwise Fused Optimizer", + value=self.config.get( + "flux1.blockwise_fused_optimizers", False + ), + info="Enable blockwise optimizers for fused backward pass and optimizer step. Any optimizer can be used.", + interactive=True, + ) + self.cpu_offload_checkpointing = gr.Checkbox( + label="CPU Offload Checkpointing", + value=self.config.get("flux1.cpu_offload_checkpointing", False), + info="[Experimental] Enable offloading of tensors to CPU during checkpointing", + interactive=True, + ) + self.flux_fused_backward_pass = gr.Checkbox( + label="Fused Backward Pass", + value=self.config.get("flux1.fused_backward_pass", False), + info="Enables the fusing of the optimizer step into the backward pass for each parameter. 
Only Adafactor optimizer is supported.", + interactive=True, + ) + + with gr.Accordion( + "Blocks to train", + open=True, + visible=False if finetuning else True, + elem_classes=["flux1_blocks_to_train_background"], + ): + with gr.Row(): + self.train_double_block_indices = gr.Textbox( + label="train_double_block_indices", + info="The indices are specified as a list of integers or a range of integers, like '0,1,5,8' or '0,1,4-5,7' or 'all' or 'none'. The number of double blocks is 19.", + value=self.config.get("flux1.train_double_block_indices", "all"), + interactive=True, + ) + self.train_single_block_indices = gr.Textbox( + label="train_single_block_indices", + info="The indices are specified as a list of integers or a range of integers, like '0,1,5,8' or '0,1,4-5,7' or 'all' or 'none'. The number of single blocks is 38.", + value=self.config.get("flux1.train_single_block_indices", "all"), + interactive=True, + ) + + with gr.Accordion( + "Rank for layers", + open=False, + visible=False if finetuning else True, + elem_classes=["flux1_rank_layers_background"], + ): + with gr.Row(): + self.img_attn_dim = gr.Textbox( + label="img_attn_dim", + value=self.config.get("flux1.img_attn_dim", ""), + interactive=True, + ) + self.img_mlp_dim = gr.Textbox( + label="img_mlp_dim", + value=self.config.get("flux1.img_mlp_dim", ""), + interactive=True, + ) + self.img_mod_dim = gr.Textbox( + label="img_mod_dim", + value=self.config.get("flux1.img_mod_dim", ""), + interactive=True, + ) + self.single_dim = gr.Textbox( + label="single_dim", + value=self.config.get("flux1.single_dim", ""), + interactive=True, + ) + with gr.Row(): + self.txt_attn_dim = gr.Textbox( + label="txt_attn_dim", + value=self.config.get("flux1.txt_attn_dim", ""), + interactive=True, + ) + self.txt_mlp_dim = gr.Textbox( + label="txt_mlp_dim", + value=self.config.get("flux1.txt_mlp_dim", ""), + interactive=True, + ) + self.txt_mod_dim = gr.Textbox( + label="txt_mod_dim", + value=self.config.get("flux1.txt_mod_dim", 
""), + interactive=True, + ) + self.single_mod_dim = gr.Textbox( + label="single_mod_dim", + value=self.config.get("flux1.single_mod_dim", ""), + interactive=True, + ) + with gr.Row(): + self.in_dims = gr.Textbox( + label="in_dims", + value=self.config.get("flux1.in_dims", ""), + placeholder="e.g., [4,0,0,0,4]", + info="Each number corresponds to img_in, time_in, vector_in, guidance_in, txt_in. The above example applies LoRA to all conditioning layers, with rank 4 for img_in, 2 for time_in, vector_in, guidance_in, and 4 for txt_in.", + interactive=True, + ) + + self.flux1_checkbox.change( + lambda flux1_checkbox: gr.Accordion(visible=flux1_checkbox), + inputs=[self.flux1_checkbox], + outputs=[flux1_accordion], + ) diff --git a/kohya_gui/class_lora_tab.py b/kohya_gui/class_lora_tab.py index 487ff5cbc..c2b9a8016 100644 --- a/kohya_gui/class_lora_tab.py +++ b/kohya_gui/class_lora_tab.py @@ -4,10 +4,12 @@ from .verify_lora_gui import gradio_verify_lora_tab from .resize_lora_gui import gradio_resize_lora_tab from .extract_lora_gui import gradio_extract_lora_tab +from .flux_extract_lora_gui import gradio_flux_extract_lora_tab from .convert_lcm_gui import gradio_convert_lcm_tab from .extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab from .extract_lora_from_dylora_gui import gradio_extract_dylora_tab from .merge_lycoris_gui import gradio_merge_lycoris_tab +from .flux_merge_lora_gui import GradioFluxMergeLoRaTab class LoRATools: @@ -19,9 +21,11 @@ def __init__( gradio_extract_dylora_tab(headless=headless) gradio_convert_lcm_tab(headless=headless) gradio_extract_lora_tab(headless=headless) + gradio_flux_extract_lora_tab(headless=headless) gradio_extract_lycoris_locon_tab(headless=headless) gradio_merge_lora_tab = GradioMergeLoRaTab() gradio_merge_lycoris_tab(headless=headless) gradio_svd_merge_lora_tab(headless=headless) gradio_resize_lora_tab(headless=headless) gradio_verify_lora_tab(headless=headless) + GradioFluxMergeLoRaTab(headless=headless) diff --git 
a/kohya_gui/class_sample_images.py b/kohya_gui/class_sample_images.py index 8f69a2ec6..807c8b449 100644 --- a/kohya_gui/class_sample_images.py +++ b/kohya_gui/class_sample_images.py @@ -28,7 +28,10 @@ def create_prompt_file(sample_prompts, output_dir): Returns: str: The path to the prompt file. """ - sample_prompts_path = os.path.join(output_dir, "prompt.txt") + sample_prompts_path = os.path.join(output_dir, "sample/prompt.txt") + + if not os.path.exists(os.path.dirname(sample_prompts_path)): + os.makedirs(os.path.dirname(sample_prompts_path)) with open(sample_prompts_path, "w", encoding="utf-8") as f: f.write(sample_prompts) diff --git a/kohya_gui/class_sd3.py b/kohya_gui/class_sd3.py new file mode 100644 index 000000000..d5dae715f --- /dev/null +++ b/kohya_gui/class_sd3.py @@ -0,0 +1,203 @@ +import gradio as gr +from typing import Tuple +from .common_gui import ( + get_folder_path, + get_any_file_path, + list_files, + list_dirs, + create_refresh_button, + document_symbol, +) + + +class sd3Training: + """ + This class configures and initializes the advanced training settings for a machine learning model, + including options for headless operation, fine-tuning, training type selection, and default directory paths. + + Attributes: + headless (bool): If True, run without the Gradio interface. + finetuning (bool): If True, enables fine-tuning of the model. + training_type (str): Specifies the type of training to perform. + no_token_padding (gr.Checkbox): Checkbox to disable token padding. + gradient_accumulation_steps (gr.Slider): Slider to set the number of gradient accumulation steps. + weighted_captions (gr.Checkbox): Checkbox to enable weighted captions. + """ + + def __init__( + self, + headless: bool = False, + finetuning: bool = False, + training_type: str = "", + config: dict = {}, + sd3_checkbox: gr.Checkbox = False, + ) -> None: + """ + Initializes the AdvancedTraining class with given settings. 
+ + Parameters: + headless (bool): Run in headless mode without GUI. + finetuning (bool): Enable model fine-tuning. + training_type (str): The type of training to be performed. + config (dict): Configuration options for the training process. + """ + self.headless = headless + self.finetuning = finetuning + self.training_type = training_type + self.config = config + self.sd3_checkbox = sd3_checkbox + + # Define the behavior for changing noise offset type. + def noise_offset_type_change( + noise_offset_type: str, + ) -> Tuple[gr.Group, gr.Group]: + """ + Returns a tuple of Gradio Groups with visibility set based on the noise offset type. + + Parameters: + noise_offset_type (str): The selected noise offset type. + + Returns: + Tuple[gr.Group, gr.Group]: A tuple containing two Gradio Group elements with their visibility set. + """ + if noise_offset_type == "Original": + return (gr.Group(visible=True), gr.Group(visible=False)) + else: + return (gr.Group(visible=False), gr.Group(visible=True)) + + with gr.Accordion( + "SD3", open=False, elem_id="sd3_tab", visible=False + ) as sd3_accordion: + with gr.Group(): + gr.Markdown("### SD3 Specific Parameters") + with gr.Row(): + self.weighting_scheme = gr.Dropdown( + label="Weighting Scheme", + choices=["logit_normal", "sigma_sqrt", "mode", "cosmap"], + value=self.config.get("sd3.weighting_scheme", "logit_normal"), + interactive=True, + ) + self.logit_mean = gr.Number( + label="Logit Mean", + value=self.config.get("sd3.logit_mean", 0.0), + interactive=True, + ) + self.logit_std = gr.Number( + label="Logit Std", + value=self.config.get("sd3.logit_std", 1.0), + interactive=True, + ) + self.mode_scale = gr.Number( + label="Mode Scale", + value=self.config.get("sd3.mode_scale", 1.29), + interactive=True, + ) + + with gr.Row(): + self.clip_l = gr.Textbox( + label="CLIP-L Path", + placeholder="Path to CLIP-L model", + value=self.config.get("sd3.clip_l", ""), + interactive=True, + ) + self.clip_l_button = gr.Button( + document_symbol, 
+ elem_id="open_folder_small", + visible=(not headless), + interactive=True, + ) + self.clip_l_button.click( + get_any_file_path, + outputs=self.clip_l, + show_progress=False, + ) + + self.clip_g = gr.Textbox( + label="CLIP-G Path", + placeholder="Path to CLIP-G model", + value=self.config.get("sd3.clip_g", ""), + interactive=True, + ) + self.clip_g_button = gr.Button( + document_symbol, + elem_id="open_folder_small", + visible=(not headless), + interactive=True, + ) + self.clip_g_button.click( + get_any_file_path, + outputs=self.clip_g, + show_progress=False, + ) + + self.t5xxl = gr.Textbox( + label="T5-XXL Path", + placeholder="Path to T5-XXL model", + value=self.config.get("sd3.t5xxl", ""), + interactive=True, + ) + self.t5xxl_button = gr.Button( + document_symbol, + elem_id="open_folder_small", + visible=(not headless), + interactive=True, + ) + self.t5xxl_button.click( + get_any_file_path, + outputs=self.t5xxl, + show_progress=False, + ) + + with gr.Row(): + self.save_clip = gr.Checkbox( + label="Save CLIP models", + value=self.config.get("sd3.save_clip", False), + interactive=True, + ) + self.save_t5xxl = gr.Checkbox( + label="Save T5-XXL model", + value=self.config.get("sd3.save_t5xxl", False), + interactive=True, + ) + + with gr.Row(): + self.t5xxl_device = gr.Textbox( + label="T5-XXL Device", + placeholder="Device for T5-XXL (e.g., cuda:0)", + value=self.config.get("sd3.t5xxl_device", ""), + interactive=True, + ) + self.t5xxl_dtype = gr.Dropdown( + label="T5-XXL Dtype", + choices=["float32", "fp16", "bf16"], + value=self.config.get("sd3.t5xxl_dtype", "bf16"), + interactive=True, + ) + self.sd3_text_encoder_batch_size = gr.Number( + label="Text Encoder Batch Size", + value=self.config.get("sd3.text_encoder_batch_size", 1), + minimum=1, + maximum=1024, + step=1, + interactive=True, + ) + self.sd3_cache_text_encoder_outputs = gr.Checkbox( + label="Cache Text Encoder Outputs", + value=self.config.get("sd3.cache_text_encoder_outputs", False), + info="Cache text 
encoder outputs to speed up inference", + interactive=True, + ) + self.sd3_cache_text_encoder_outputs_to_disk = gr.Checkbox( + label="Cache Text Encoder Outputs to Disk", + value=self.config.get( + "sd3.cache_text_encoder_outputs_to_disk", False + ), + info="Cache text encoder outputs to disk to speed up inference", + interactive=True, + ) + + self.sd3_checkbox.change( + lambda sd3_checkbox: gr.Accordion(visible=sd3_checkbox), + inputs=[self.sd3_checkbox], + outputs=[sd3_accordion], + ) diff --git a/kohya_gui/class_sdxl_parameters.py b/kohya_gui/class_sdxl_parameters.py index b0098d2a3..e1141668c 100644 --- a/kohya_gui/class_sdxl_parameters.py +++ b/kohya_gui/class_sdxl_parameters.py @@ -7,10 +7,12 @@ def __init__( sdxl_checkbox: gr.Checkbox, show_sdxl_cache_text_encoder_outputs: bool = True, config: KohyaSSGUIConfig = {}, + trainer: str = "", ): self.sdxl_checkbox = sdxl_checkbox self.show_sdxl_cache_text_encoder_outputs = show_sdxl_cache_text_encoder_outputs self.config = config + self.trainer = trainer self.initialize_accordion() @@ -30,6 +32,41 @@ def initialize_accordion(self): info="Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.", value=self.config.get("sdxl.sdxl_no_half_vae", False), ) + self.fused_backward_pass = gr.Checkbox( + label="Fused backward pass", + info="Enable fused backward pass. This option is useful to reduce the GPU memory usage. Can't be used if Fused optimizer groups is > 0. Only AdaFactor is supported", + value=self.config.get("sdxl.fused_backward_pass", False), + visible=self.trainer == "finetune" or self.trainer == "dreambooth", + ) + self.fused_optimizer_groups = gr.Number( + label="Fused optimizer groups", + info="Number of optimizer groups to fuse. This option is useful to reduce the GPU memory usage. Can't be used if Fused backward pass is enabled. 
Since the effect is limited to a certain number, it is recommended to specify 4-10.", + value=self.config.get("sdxl.fused_optimizer_groups", 0), + minimum=0, + step=1, + visible=self.trainer == "finetune" or self.trainer == "dreambooth", + ) + self.disable_mmap_load_safetensors = gr.Checkbox( + label="Disable mmap load safe tensors", + info="Disable memory mapping when loading the model's .safetensors in SDXL.", + value=self.config.get("sdxl.disable_mmap_load_safetensors", False), + ) + + self.fused_backward_pass.change( + lambda fused_backward_pass: gr.Number( + interactive=not fused_backward_pass + ), + inputs=[self.fused_backward_pass], + outputs=[self.fused_optimizer_groups], + ) + self.fused_optimizer_groups.change( + lambda fused_optimizer_groups: gr.Checkbox( + interactive=fused_optimizer_groups == 0 + ), + inputs=[self.fused_optimizer_groups], + outputs=[self.fused_backward_pass], + ) + self.sdxl_checkbox.change( lambda sdxl_checkbox: gr.Accordion(visible=sdxl_checkbox), diff --git a/kohya_gui/class_source_model.py b/kohya_gui/class_source_model.py index 4b081f677..f9ece6577 100644 --- a/kohya_gui/class_source_model.py +++ b/kohya_gui/class_source_model.py @@ -26,7 +26,6 @@ "stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned", "stabilityai/stable-diffusion-2-1", "stabilityai/stable-diffusion-2", - "runwayml/stable-diffusion-v1-5", "CompVis/stable-diffusion-v1-4", ] @@ -245,19 +244,88 @@ def list_dataset_config_dirs(path: str) -> list: with gr.Column(): with gr.Row(): self.v2 = gr.Checkbox( - label="v2", value=False, visible=False, min_width=60 + label="v2", value=False, visible=False, min_width=60, + interactive=True, ) self.v_parameterization = gr.Checkbox( label="v_parameterization", value=False, visible=False, min_width=130, + interactive=True, ) self.sdxl_checkbox = gr.Checkbox( label="SDXL", value=False, visible=False, min_width=60, + interactive=True, + ) + self.sd3_checkbox = gr.Checkbox( + label="SD3", + value=False, + visible=False, + 
min_width=60, + interactive=True, + ) + self.flux1_checkbox = gr.Checkbox( + label="Flux.1", + value=False, + visible=False, + min_width=60, + interactive=True, + ) + + def toggle_checkboxes(v2, v_parameterization, sdxl_checkbox, sd3_checkbox, flux1_checkbox): + # Check if all checkboxes are unchecked + if not v2 and not v_parameterization and not sdxl_checkbox and not sd3_checkbox and not flux1_checkbox: + # If all unchecked, return new interactive checkboxes + return ( + gr.Checkbox(interactive=True), # v2 checkbox + gr.Checkbox(interactive=True), # v_parameterization checkbox + gr.Checkbox(interactive=True), # sdxl_checkbox + gr.Checkbox(interactive=True), # sd3_checkbox + gr.Checkbox(interactive=True), # flux1_checkbox + ) + else: + # If any checkbox is checked, return checkboxes with current interactive state + return ( + gr.Checkbox(interactive=v2), # v2 checkbox + gr.Checkbox(interactive=v_parameterization), # v_parameterization checkbox + gr.Checkbox(interactive=sdxl_checkbox), # sdxl_checkbox + gr.Checkbox(interactive=sd3_checkbox), # sd3_checkbox + gr.Checkbox(interactive=flux1_checkbox), # flux1_checkbox + ) + + self.v2.change( + fn=toggle_checkboxes, + inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox], + outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox], + show_progress=False, + ) + self.v_parameterization.change( + fn=toggle_checkboxes, + inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox], + outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox], + show_progress=False, + ) + self.sdxl_checkbox.change( + fn=toggle_checkboxes, + inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox], + outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox], + show_progress=False, 
+ ) + self.sd3_checkbox.change( + fn=toggle_checkboxes, + inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox], + outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox], + show_progress=False, + ) + self.flux1_checkbox.change( + fn=toggle_checkboxes, + inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox], + outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox], + show_progress=False, ) with gr.Column(): gr.Group(visible=False) @@ -294,6 +362,8 @@ def list_dataset_config_dirs(path: str) -> list: self.v2, self.v_parameterization, self.sdxl_checkbox, + self.sd3_checkbox, + self.flux1_checkbox, ], show_progress=False, ) diff --git a/kohya_gui/class_tensorboard.py b/kohya_gui/class_tensorboard.py index b9a9a9c4b..001c894da 100644 --- a/kohya_gui/class_tensorboard.py +++ b/kohya_gui/class_tensorboard.py @@ -20,6 +20,7 @@ class TensorboardManager: DEFAULT_TENSORBOARD_PORT = 6006 + DEFAULT_TENSORBOARD_HOST = "0.0.0.0" def __init__(self, logging_dir, headless: bool = False, wait_time=5): self.logging_dir = logging_dir @@ -29,6 +30,9 @@ def __init__(self, logging_dir, headless: bool = False, wait_time=5): self.tensorboard_port = os.environ.get( "TENSORBOARD_PORT", self.DEFAULT_TENSORBOARD_PORT ) + self.tensorboard_host = os.environ.get( + "TENSORBOARD_HOST", self.DEFAULT_TENSORBOARD_HOST + ) self.log = setup_logging() self.thread = None self.stop_event = Event() @@ -64,7 +68,7 @@ def start_tensorboard(self, logging_dir=None): "--logdir", logging_dir, "--host", - "0.0.0.0", + self.tensorboard_host, "--port", str(self.tensorboard_port), ] diff --git a/kohya_gui/common_gui.py b/kohya_gui/common_gui.py index 0ca334eb6..8823cb78f 100644 --- a/kohya_gui/common_gui.py +++ b/kohya_gui/common_gui.py @@ -5,6 +5,7 @@ from easygui import msgbox, ynbox from typing import Optional from 
.custom_logging import setup_logging +from .sd_modeltype import SDModelType import os import re @@ -327,7 +328,6 @@ def update_my_data(my_data): # Convert values to int if they are strings for key in [ - "adaptive_noise_scale", "clip_skip", "epoch", "gradient_accumulation_steps", @@ -379,7 +379,13 @@ def update_my_data(my_data): my_data[key] = int(75) # Convert values to float if they are strings, correctly handling float representations - for key in ["noise_offset", "learning_rate", "text_encoder_lr", "unet_lr"]: + for key in [ + "adaptive_noise_scale", + "noise_offset", + "learning_rate", + "text_encoder_lr", + "unet_lr", + ]: value = my_data.get(key) if value is not None: try: @@ -956,11 +962,15 @@ def set_pretrained_model_name_or_path_input( v2 = gr.Checkbox(value=False, visible=False) v_parameterization = gr.Checkbox(value=False, visible=False) sdxl = gr.Checkbox(value=True, visible=False) + sd3 = gr.Checkbox(value=False, visible=False) + flux1 = gr.Checkbox(value=False, visible=False) return ( gr.Dropdown(), v2, v_parameterization, sdxl, + sd3, + flux1, ) # Check if the given pretrained_model_name_or_path is in the list of V2 base models @@ -969,11 +979,15 @@ def set_pretrained_model_name_or_path_input( v2 = gr.Checkbox(value=True, visible=False) v_parameterization = gr.Checkbox(value=False, visible=False) sdxl = gr.Checkbox(value=False, visible=False) + sd3 = gr.Checkbox(value=False, visible=False) + flux1 = gr.Checkbox(value=False, visible=False) return ( gr.Dropdown(), v2, v_parameterization, sdxl, + sd3, + flux1, ) # Check if the given pretrained_model_name_or_path is in the list of V parameterization models @@ -984,11 +998,15 @@ def set_pretrained_model_name_or_path_input( v2 = gr.Checkbox(value=True, visible=False) v_parameterization = gr.Checkbox(value=True, visible=False) sdxl = gr.Checkbox(value=False, visible=False) + sd3 = gr.Checkbox(value=False, visible=False) + flux1 = gr.Checkbox(value=False, visible=False) return ( gr.Dropdown(), v2, 
v_parameterization, sdxl, + sd3, + flux1, ) # Check if the given pretrained_model_name_or_path is in the list of V1 models @@ -997,17 +1015,32 @@ def set_pretrained_model_name_or_path_input( v2 = gr.Checkbox(value=False, visible=False) v_parameterization = gr.Checkbox(value=False, visible=False) sdxl = gr.Checkbox(value=False, visible=False) + sd3 = gr.Checkbox(value=False, visible=False) + flux1 = gr.Checkbox(value=False, visible=False) return ( gr.Dropdown(), v2, v_parameterization, sdxl, + sd3, + flux1, ) # Check if the model_list is set to 'custom' v2 = gr.Checkbox(visible=True) v_parameterization = gr.Checkbox(visible=True) sdxl = gr.Checkbox(visible=True) + sd3 = gr.Checkbox(visible=True) + flux1 = gr.Checkbox(visible=True) + + # Auto-detect model type if safetensors file path is given + if pretrained_model_name_or_path.lower().endswith(".safetensors"): + detect = SDModelType(pretrained_model_name_or_path) + v2 = gr.Checkbox(value=detect.Is_SD2(), visible=True) + sdxl = gr.Checkbox(value=detect.Is_SDXL(), visible=True) + sd3 = gr.Checkbox(value=detect.Is_SD3(), visible=True) + flux1 = gr.Checkbox(value=detect.Is_FLUX1(), visible=True) + #TODO: v_parameterization # If a refresh method is provided, use it to update the choices for the Dropdown widget if refresh_method is not None: @@ -1021,6 +1054,8 @@ def set_pretrained_model_name_or_path_input( v2, v_parameterization, sdxl, + sd3, + flux1, ) @@ -1369,7 +1404,11 @@ def validate_file_path(file_path: str) -> bool: return True -def validate_folder_path(folder_path: str, can_be_written_to: bool = False, create_if_not_exists: bool = False) -> bool: +def validate_folder_path( + folder_path: str, + can_be_written_to: bool = False, + create_if_not_exists: bool = False, +) -> bool: if folder_path == "": return True msg = f"Validating {folder_path} existence{' and writability' if can_be_written_to else ''}..." 
@@ -1387,6 +1426,7 @@ def validate_folder_path(folder_path: str, can_be_written_to: bool = False, crea log.info(f"{msg} SUCCESS") return True + def validate_toml_file(file_path: str) -> bool: if file_path == "": return True @@ -1394,7 +1434,7 @@ def validate_toml_file(file_path: str) -> bool: if not os.path.isfile(file_path): log.error(f"{msg} FAILED: does not exist") return False - + try: toml.load(file_path) except: @@ -1425,11 +1465,14 @@ def validate_model_path(pretrained_model_name_or_path: str) -> bool: log.info(f"{msg} SUCCESS") else: # If not one of the default models, check if it's a valid local path - if not validate_file_path(pretrained_model_name_or_path) and not validate_folder_path(pretrained_model_name_or_path): + if not validate_file_path( + pretrained_model_name_or_path + ) and not validate_folder_path(pretrained_model_name_or_path): log.info(f"{msg} FAILURE: not a valid file or folder") return False return True + def is_file_writable(file_path: str) -> bool: """ Checks if a file is writable. @@ -1450,8 +1493,9 @@ def is_file_writable(file_path: str) -> bool: pass # If the file can be opened, it is considered writable return True - except IOError: + except IOError as e: # If an IOError occurs, the file cannot be written to + log.info(f"Error: {e}. 
File '{file_path}' is not writable.") return False @@ -1462,7 +1506,7 @@ def print_command_and_toml(run_cmd, tmpfilename): # Reconstruct the safe command string for display command_to_run = " ".join(run_cmd) - log.info(command_to_run) + print(command_to_run) print("") log.info(f"Showing toml config file: {tmpfilename}") @@ -1489,10 +1533,11 @@ def validate_args_setting(input_string): ) return False + def setup_environment(): env = os.environ.copy() env["PYTHONPATH"] = ( - fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" ) env["TF_ENABLE_ONEDNN_OPTS"] = "0" diff --git a/kohya_gui/dataset_balancing_gui.py b/kohya_gui/dataset_balancing_gui.py index 8d644d1c1..eb6c2ff61 100644 --- a/kohya_gui/dataset_balancing_gui.py +++ b/kohya_gui/dataset_balancing_gui.py @@ -10,6 +10,11 @@ log = setup_logging() +import os +import re +import logging as log +from easygui import msgbox + def dataset_balancing(concept_repeats, folder, insecure): if not concept_repeats > 0: @@ -78,7 +83,11 @@ def dataset_balancing(concept_repeats, folder, insecure): old_name = os.path.join(folder, subdir) new_name = os.path.join(folder, f"{repeats}_{subdir}") - os.rename(old_name, new_name) + # Check if the new folder name already exists + if os.path.exists(new_name): + log.warning(f"Destination folder {new_name} already exists. Skipping...") + else: + os.rename(old_name, new_name) else: log.info( f"Skipping folder {subdir} because it does not match kohya_ss expected syntax..." 
@@ -87,6 +96,7 @@ def dataset_balancing(concept_repeats, folder, insecure): msgbox("Dataset balancing completed...") + def warning(insecure): if insecure: if boolbox( diff --git a/kohya_gui/dreambooth_gui.py b/kohya_gui/dreambooth_gui.py index a38230a21..55454b526 100644 --- a/kohya_gui/dreambooth_gui.py +++ b/kohya_gui/dreambooth_gui.py @@ -17,7 +17,9 @@ SaveConfigFile, scriptdir, update_my_data, - validate_file_path, validate_folder_path, validate_model_path, + validate_file_path, + validate_folder_path, + validate_model_path, validate_args_setting, setup_environment, ) @@ -27,10 +29,13 @@ from .class_source_model import SourceModel from .class_basic_training import BasicTraining from .class_advanced_training import AdvancedTraining +from .class_sd3 import sd3Training from .class_folders import Folders from .class_command_executor import CommandExecutor from .class_huggingface import HuggingFace from .class_metadata import MetaData +from .class_sdxl_parameters import SDXLParameters +from .class_flux1 import flux1Training from .dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -60,6 +65,7 @@ def save_configuration( v2, v_parameterization, sdxl, + flux1_checkbox, logging_dir, train_data_dir, reg_data_dir, @@ -72,6 +78,7 @@ def save_configuration( learning_rate_te2, lr_scheduler, lr_warmup, + lr_warmup_steps, train_batch_size, epoch, save_every_n_epochs, @@ -84,6 +91,7 @@ def save_configuration( caption_extension, enable_bucket, gradient_checkpointing, + fp8_base, full_fp16, full_bf16, no_token_padding, @@ -134,6 +142,7 @@ def save_configuration( optimizer, optimizer_args, lr_scheduler_args, + lr_scheduler_type, noise_offset_type, noise_offset, noise_offset_random_strength, @@ -156,12 +165,21 @@ def save_configuration( save_every_n_steps, save_last_n_steps, save_last_n_steps_state, + save_last_n_epochs, + save_last_n_epochs_state, + skip_cache_check, log_with, wandb_api_key, wandb_run_name, log_tracker_name, log_tracker_config, + 
log_config, scale_v_pred_loss_like_noise_pred, + disable_mmap_load_safetensors, + fused_backward_pass, + fused_optimizer_groups, + sdxl_cache_text_encoder_outputs, + sdxl_no_half_vae, min_timestep, max_timestep, debiased_estimation_loss, @@ -178,6 +196,43 @@ def save_configuration( metadata_license, metadata_tags, metadata_title, + # SD3 parameters + sd3_cache_text_encoder_outputs, + sd3_cache_text_encoder_outputs_to_disk, + clip_g, + clip_l, + logit_mean, + logit_std, + mode_scale, + save_clip, + save_t5xxl, + t5xxl, + t5xxl_device, + t5xxl_dtype, + sd3_text_encoder_batch_size, + weighting_scheme, + sd3_checkbox, + # Flux.1 + flux1_cache_text_encoder_outputs, + flux1_cache_text_encoder_outputs_to_disk, + ae, + flux1_clip_l, + flux1_t5xxl, + discrete_flow_shift, + model_prediction_type, + timestep_sampling, + split_mode, + train_blocks, + t5xxl_max_token_length, + guidance_scale, + blockwise_fused_optimizers, + flux_fused_backward_pass, + cpu_offload_checkpointing, + blocks_to_swap, + single_blocks_to_swap, + double_blocks_to_swap, + mem_eff_save, + apply_t5_attn_mask, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -218,6 +273,7 @@ def open_configuration( v2, v_parameterization, sdxl, + flux1_checkbox, logging_dir, train_data_dir, reg_data_dir, @@ -230,6 +286,7 @@ def open_configuration( learning_rate_te2, lr_scheduler, lr_warmup, + lr_warmup_steps, train_batch_size, epoch, save_every_n_epochs, @@ -242,6 +299,7 @@ def open_configuration( caption_extension, enable_bucket, gradient_checkpointing, + fp8_base, full_fp16, full_bf16, no_token_padding, @@ -292,6 +350,7 @@ def open_configuration( optimizer, optimizer_args, lr_scheduler_args, + lr_scheduler_type, noise_offset_type, noise_offset, noise_offset_random_strength, @@ -314,12 +373,21 @@ def open_configuration( save_every_n_steps, save_last_n_steps, save_last_n_steps_state, + save_last_n_epochs, + save_last_n_epochs_state, + skip_cache_check, log_with, wandb_api_key, 
wandb_run_name, log_tracker_name, log_tracker_config, + log_config, scale_v_pred_loss_like_noise_pred, + disable_mmap_load_safetensors, + fused_backward_pass, + fused_optimizer_groups, + sdxl_cache_text_encoder_outputs, + sdxl_no_half_vae, min_timestep, max_timestep, debiased_estimation_loss, @@ -336,6 +404,43 @@ def open_configuration( metadata_license, metadata_tags, metadata_title, + # SD3 parameters + sd3_cache_text_encoder_outputs, + sd3_cache_text_encoder_outputs_to_disk, + clip_g, + clip_l, + logit_mean, + logit_std, + mode_scale, + save_clip, + save_t5xxl, + t5xxl, + t5xxl_device, + t5xxl_dtype, + sd3_text_encoder_batch_size, + weighting_scheme, + sd3_checkbox, + # Flux.1 + flux1_cache_text_encoder_outputs, + flux1_cache_text_encoder_outputs_to_disk, + ae, + flux1_clip_l, + flux1_t5xxl, + discrete_flow_shift, + model_prediction_type, + timestep_sampling, + split_mode, + train_blocks, + t5xxl_max_token_length, + guidance_scale, + blockwise_fused_optimizers, + flux_fused_backward_pass, + cpu_offload_checkpointing, + blocks_to_swap, + single_blocks_to_swap, + double_blocks_to_swap, + mem_eff_save, + apply_t5_attn_mask, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -371,6 +476,7 @@ def train_model( v2, v_parameterization, sdxl, + flux1_checkbox, logging_dir, train_data_dir, reg_data_dir, @@ -383,6 +489,7 @@ def train_model( learning_rate_te2, lr_scheduler, lr_warmup, + lr_warmup_steps, train_batch_size, epoch, save_every_n_epochs, @@ -395,6 +502,7 @@ def train_model( caption_extension, enable_bucket, gradient_checkpointing, + fp8_base, full_fp16, full_bf16, no_token_padding, @@ -445,6 +553,7 @@ def train_model( optimizer, optimizer_args, lr_scheduler_args, + lr_scheduler_type, noise_offset_type, noise_offset, noise_offset_random_strength, @@ -467,12 +576,21 @@ def train_model( save_every_n_steps, save_last_n_steps, save_last_n_steps_state, + save_last_n_epochs, + save_last_n_epochs_state, + skip_cache_check, log_with, 
wandb_api_key, wandb_run_name, log_tracker_name, log_tracker_config, + log_config, scale_v_pred_loss_like_noise_pred, + disable_mmap_load_safetensors, + fused_backward_pass, + fused_optimizer_groups, + sdxl_cache_text_encoder_outputs, + sdxl_no_half_vae, min_timestep, max_timestep, debiased_estimation_loss, @@ -489,6 +607,43 @@ def train_model( metadata_license, metadata_tags, metadata_title, + # SD3 parameters + sd3_cache_text_encoder_outputs, + sd3_cache_text_encoder_outputs_to_disk, + clip_g, + clip_l, + logit_mean, + logit_std, + mode_scale, + save_clip, + save_t5xxl, + t5xxl, + t5xxl_device, + t5xxl_dtype, + sd3_text_encoder_batch_size, + weighting_scheme, + sd3_checkbox, + # Flux.1 + flux1_cache_text_encoder_outputs, + flux1_cache_text_encoder_outputs_to_disk, + ae, + flux1_clip_l, + flux1_t5xxl, + discrete_flow_shift, + model_prediction_type, + timestep_sampling, + split_mode, + train_blocks, + t5xxl_max_token_length, + guidance_scale, + blockwise_fused_optimizers, + flux_fused_backward_pass, + cpu_offload_checkpointing, + blocks_to_swap, + single_blocks_to_swap, + double_blocks_to_swap, + mem_eff_save, + apply_t5_attn_mask, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -509,61 +664,50 @@ def train_model( log.info(f"Validating lr scheduler arguments...") if not validate_args_setting(lr_scheduler_args): return - + log.info(f"Validating optimizer arguments...") if not validate_args_setting(optimizer_args): return TRAIN_BUTTON_VISIBLE # # Validate paths - # - + # + if not validate_file_path(dataset_config): return TRAIN_BUTTON_VISIBLE - + if not validate_file_path(log_tracker_config): return TRAIN_BUTTON_VISIBLE - - if not validate_folder_path(logging_dir, can_be_written_to=True, create_if_not_exists=True): + + if not validate_folder_path( + logging_dir, can_be_written_to=True, create_if_not_exists=True + ): return TRAIN_BUTTON_VISIBLE - - if not validate_folder_path(output_dir, can_be_written_to=True, 
create_if_not_exists=True): + + if not validate_folder_path( + output_dir, can_be_written_to=True, create_if_not_exists=True + ): return TRAIN_BUTTON_VISIBLE - + if not validate_model_path(pretrained_model_name_or_path): return TRAIN_BUTTON_VISIBLE - + if not validate_folder_path(reg_data_dir): return TRAIN_BUTTON_VISIBLE - + if not validate_folder_path(resume): return TRAIN_BUTTON_VISIBLE - + if not validate_folder_path(train_data_dir): return TRAIN_BUTTON_VISIBLE - + if not validate_model_path(vae): return TRAIN_BUTTON_VISIBLE + # # End of path validation # - # This function validates files or folder paths. Simply add new variables containing file of folder path - # to validate below - # if not validate_paths( - # dataset_config=dataset_config, - # headless=headless, - # log_tracker_config=log_tracker_config, - # logging_dir=logging_dir, - # output_dir=output_dir, - # pretrained_model_name_or_path=pretrained_model_name_or_path, - # reg_data_dir=reg_data_dir, - # resume=resume, - # train_data_dir=train_data_dir, - # vae=vae, - # ): - # return TRAIN_BUTTON_VISIBLE - if not print_only and check_if_model_exist( output_name, output_dir, save_model_as, headless=headless ): @@ -573,15 +717,6 @@ def train_model( log.info( "Dataset config toml file used, skipping total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps calculations..." ) - if max_train_steps > 0: - if lr_warmup != 0: - lr_warmup_steps = round( - float(int(lr_warmup) * int(max_train_steps) / 100) - ) - else: - lr_warmup_steps = 0 - else: - lr_warmup_steps = 0 if max_train_steps == 0: max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required." @@ -640,11 +775,11 @@ def train_model( reg_factor = 1 else: log.warning( - "Regularisation images are used... Will double the number of steps required..." + "Regularization images are used... Will double the number of steps required..." 
) reg_factor = 2 - log.info(f"Regulatization factor: {reg_factor}") + log.info(f"Regularization factor: {reg_factor}") if max_train_steps == 0: # calculate max_train_steps @@ -664,13 +799,18 @@ def train_model( else: max_train_steps_info = f"Max train steps: {max_train_steps}" - if lr_warmup != 0: - lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) - else: - lr_warmup_steps = 0 - log.info(f"Total steps: {total_steps}") + # Calculate lr_warmup_steps + if lr_warmup_steps > 0: + lr_warmup_steps = int(lr_warmup_steps) + if lr_warmup > 0: + log.warning("Both lr_warmup and lr_warmup_steps are set. lr_warmup_steps will be used.") + elif lr_warmup != 0: + lr_warmup_steps = lr_warmup / 100 + else: + lr_warmup_steps = 0 + log.info(f"Train batch size: {train_batch_size}") log.info(f"Gradient accumulation steps: {gradient_accumulation_steps}") log.info(f"Epoch: {epoch}") @@ -682,7 +822,7 @@ def train_model( log.error("accelerate not found") return TRAIN_BUTTON_VISIBLE - run_cmd = [rf'{accelerate_path}', "launch"] + run_cmd = [rf"{accelerate_path}", "launch"] run_cmd = AccelerateLaunch.run_cmd( run_cmd=run_cmd, @@ -701,10 +841,23 @@ def train_model( ) if sdxl: - run_cmd.append(rf'{scriptdir}/sd-scripts/sdxl_train.py') + run_cmd.append(rf"{scriptdir}/sd-scripts/sdxl_train.py") + elif sd3_checkbox: + run_cmd.append(rf"{scriptdir}/sd-scripts/sd3_train.py") + elif flux1_checkbox: + run_cmd.append(rf"{scriptdir}/sd-scripts/flux_train.py") else: run_cmd.append(rf"{scriptdir}/sd-scripts/train_db.py") + cache_text_encoder_outputs = ( + (sdxl and sdxl_cache_text_encoder_outputs) + or (sd3_checkbox and sd3_cache_text_encoder_outputs) + or (flux1_checkbox and flux1_cache_text_encoder_outputs) + ) + cache_text_encoder_outputs_to_disk = ( + sd3_checkbox and sd3_cache_text_encoder_outputs_to_disk + ) or (flux1_checkbox and flux1_cache_text_encoder_outputs_to_disk) + no_half_vae = sdxl and sdxl_no_half_vae if max_data_loader_n_workers == "" or None: 
max_data_loader_n_workers = 0 else: @@ -715,6 +868,11 @@ def train_model( else: max_train_steps = int(max_train_steps) + if sdxl: + train_text_encoder = (learning_rate_te1 != None and learning_rate_te1 > 0) or ( + learning_rate_te2 != None and learning_rate_te2 > 0 + ) + # def save_huggingface_to_toml(self, toml_file_path: str): config_toml_data = { # Update the values in the TOML data @@ -724,19 +882,28 @@ def train_model( "bucket_reso_steps": bucket_reso_steps, "cache_latents": cache_latents, "cache_latents_to_disk": cache_latents_to_disk, + "cache_text_encoder_outputs": cache_text_encoder_outputs, + "cache_text_encoder_outputs_to_disk": cache_text_encoder_outputs_to_disk, "caption_dropout_every_n_epochs": int(caption_dropout_every_n_epochs), "caption_dropout_rate": caption_dropout_rate, "caption_extension": caption_extension, + "clip_l": flux1_clip_l if flux1_checkbox else clip_l if sd3_checkbox else None, "clip_skip": clip_skip if clip_skip != 0 else None, "color_aug": color_aug, "dataset_config": dataset_config, "debiased_estimation_loss": debiased_estimation_loss, + "disable_mmap_load_safetensors": disable_mmap_load_safetensors, "dynamo_backend": dynamo_backend, "enable_bucket": enable_bucket, "epoch": int(epoch), "flip_aug": flip_aug, + "fp8_base": fp8_base, "full_bf16": full_bf16, "full_fp16": full_fp16, + "fused_backward_pass": fused_backward_pass if not flux1_checkbox else flux_fused_backward_pass, + "fused_optimizer_groups": ( + int(fused_optimizer_groups) if fused_optimizer_groups > 0 else None + ), "gradient_accumulation_steps": int(gradient_accumulation_steps), "gradient_checkpointing": gradient_checkpointing, "huber_c": huber_c, @@ -750,16 +917,11 @@ def train_model( "ip_noise_gamma_random_strength": ip_noise_gamma_random_strength, "keep_tokens": int(keep_tokens), "learning_rate": learning_rate, # both for sd1.5 and sdxl - "learning_rate_te": ( - learning_rate_te if not sdxl and not 0 else None - ), # only for sd1.5 and not 0 - "learning_rate_te1": ( 
- learning_rate_te1 if sdxl and not 0 else None - ), # only for sdxl and not 0 - "learning_rate_te2": ( - learning_rate_te2 if sdxl and not 0 else None - ), # only for sdxl and not 0 + "learning_rate_te": learning_rate_te if not sdxl else None, # only for sd1.5 + "learning_rate_te1": learning_rate_te1 if sdxl else None, # only for sdxl + "learning_rate_te2": learning_rate_te2 if sdxl else None, # only for sdxl "logging_dir": logging_dir, + "log_config": log_config, "log_tracker_config": log_tracker_config, "log_tracker_name": log_tracker_name, "log_with": log_with, @@ -767,15 +929,20 @@ def train_model( "lr_scheduler": lr_scheduler, "lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(), "lr_scheduler_num_cycles": ( - int(lr_scheduler_num_cycles) if lr_scheduler_num_cycles != "" else int(epoch) + int(lr_scheduler_num_cycles) + if lr_scheduler_num_cycles != "" + else int(epoch) ), "lr_scheduler_power": lr_scheduler_power, + "lr_scheduler_type": lr_scheduler_type if lr_scheduler_type != "" else None, "lr_warmup_steps": lr_warmup_steps, "masked_loss": masked_loss, "max_bucket_reso": max_bucket_reso, "max_timestep": max_timestep if max_timestep != 0 else None, "max_token_length": int(max_token_length), - "max_train_epochs": int(max_train_epochs) if int(max_train_epochs) != 0 else None, + "max_train_epochs": ( + int(max_train_epochs) if int(max_train_epochs) != 0 else None + ), "max_train_steps": int(max_train_steps) if int(max_train_steps) != 0 else None, "mem_eff_attn": mem_eff_attn, "metadata_author": metadata_author, @@ -789,6 +956,7 @@ def train_model( "mixed_precision": mixed_precision, "multires_noise_discount": multires_noise_discount, "multires_noise_iterations": multires_noise_iterations if not 0 else None, + "no_half_vae": no_half_vae, "no_token_padding": no_token_padding, "noise_offset": noise_offset if not 0 else None, "noise_offset_random_strength": noise_offset_random_strength, @@ -825,6 +993,10 @@ def train_model( "save_last_n_steps_state": 
( save_last_n_steps_state if save_last_n_steps_state != 0 else None ), + "save_last_n_epochs": save_last_n_epochs if save_last_n_epochs != 0 else None, + "save_last_n_epochs_state": ( + save_last_n_epochs_state if save_last_n_epochs_state != 0 else None + ), "save_model_as": save_model_as, "save_precision": save_precision, "save_state": save_state, @@ -834,20 +1006,65 @@ def train_model( "sdpa": True if xformers == "sdpa" else None, "seed": int(seed) if int(seed) != 0 else None, "shuffle_caption": shuffle_caption, + "skip_cache_check": skip_cache_check, "stop_text_encoder_training": ( stop_text_encoder_training if stop_text_encoder_training != 0 else None ), + "t5xxl": t5xxl if sd3_checkbox else flux1_t5xxl if flux1_checkbox else None, "train_batch_size": train_batch_size, "train_data_dir": train_data_dir, + "train_text_encoder": train_text_encoder if sdxl else None, "v2": v2, "v_parameterization": v_parameterization, "v_pred_like_loss": v_pred_like_loss if v_pred_like_loss != 0 else None, "vae": vae, "vae_batch_size": vae_batch_size if vae_batch_size != 0 else None, "wandb_api_key": wandb_api_key, - "wandb_run_name": wandb_run_name, + "wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name, "weighted_captions": weighted_captions, "xformers": True if xformers == "xformers" else None, + # SD3 only Parameters + # "cache_text_encoder_outputs": see previous assignment above for code + # "cache_text_encoder_outputs_to_disk": see previous assignment above for code + "clip_g": clip_g if sd3_checkbox else None, + # "clip_l": see previous assignment above for code + "logit_mean": logit_mean if sd3_checkbox else None, + "logit_std": logit_std if sd3_checkbox else None, + "mode_scale": mode_scale if sd3_checkbox else None, + "save_clip": save_clip if sd3_checkbox else None, + "save_t5xxl": save_t5xxl if sd3_checkbox else None, + # "t5xxl": see previous assignment above for code + "t5xxl_device": t5xxl_device if sd3_checkbox else None, + "t5xxl_dtype": 
t5xxl_dtype if sd3_checkbox else None, + "text_encoder_batch_size": ( + sd3_text_encoder_batch_size if sd3_checkbox else None + ), + "weighting_scheme": weighting_scheme if sd3_checkbox else None, + # Flux.1 specific parameters + # "cache_text_encoder_outputs": see previous assignment above for code + # "cache_text_encoder_outputs_to_disk": see previous assignment above for code + "ae": ae if flux1_checkbox else None, + # "clip_l": see previous assignment above for code + # "t5xxl": see previous assignment above for code + "discrete_flow_shift": discrete_flow_shift if flux1_checkbox else None, + "model_prediction_type": model_prediction_type if flux1_checkbox else None, + "timestep_sampling": timestep_sampling if flux1_checkbox else None, + "split_mode": split_mode if flux1_checkbox else None, + "train_blocks": train_blocks if flux1_checkbox else None, + "t5xxl_max_token_length": t5xxl_max_token_length if flux1_checkbox else None, + "guidance_scale": guidance_scale if flux1_checkbox else None, + "blockwise_fused_optimizers": ( + blockwise_fused_optimizers if flux1_checkbox else None + ), + # "flux_fused_backward_pass": see previous assignment of fused_backward_pass in above code + "cpu_offload_checkpointing": ( + cpu_offload_checkpointing if flux1_checkbox else None + ), + "blocks_to_swap": blocks_to_swap if flux1_checkbox else None, + "single_blocks_to_swap": single_blocks_to_swap if flux1_checkbox else None, + "double_blocks_to_swap": double_blocks_to_swap if flux1_checkbox else None, + "mem_eff_save": mem_eff_save if flux1_checkbox else None, + "apply_t5_attn_mask": apply_t5_attn_mask if flux1_checkbox else None, } # Given dictionary `config_toml_data` @@ -855,7 +1072,7 @@ def train_model( config_toml_data = { key: value for key, value in config_toml_data.items() - if value not in ["", False, None] + if not any([value == "", value is False, value is None]) } config_toml_data["max_data_loader_n_workers"] = int(max_data_loader_n_workers) @@ -865,8 +1082,8 @@ def 
train_model( current_datetime = datetime.now() formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") - tmpfilename = fr"{output_dir}/config_dreambooth-{formatted_datetime}.toml" - + tmpfilename = rf"{output_dir}/config_dreambooth-{formatted_datetime}.toml" + # Save the updated TOML data back to the file with open(tmpfilename, "w", encoding="utf-8") as toml_file: toml.dump(config_toml_data, toml_file) @@ -875,7 +1092,7 @@ def train_model( log.error(f"Failed to write TOML file: {toml_file.name}") run_cmd.append(f"--config_file") - run_cmd.append(rf'{tmpfilename}') + run_cmd.append(rf"{tmpfilename}") # Initialize a dictionary with always-included keyword arguments kwargs_for_training = { @@ -981,6 +1198,26 @@ def dreambooth_tab( config=config, ) + # Add SDXL Parameters + sdxl_params = SDXLParameters( + source_model.sdxl_checkbox, + config=config, + trainer="finetune", + ) + + # Add FLUX1 Parameters + flux1_training = flux1Training( + headless=headless, + config=config, + flux1_checkbox=source_model.flux1_checkbox, + finetuning=True, + ) + + # Add SD3 Parameters + sd3_training = sd3Training( + headless=headless, config=config, sd3_checkbox=source_model.sd3_checkbox + ) + with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"): advanced_training = AdvancedTraining(headless=headless, config=config) advanced_training.color_aug.change( @@ -1011,6 +1248,7 @@ def dreambooth_tab( source_model.v2, source_model.v_parameterization, source_model.sdxl_checkbox, + source_model.flux1_checkbox, folders.logging_dir, source_model.train_data_dir, folders.reg_data_dir, @@ -1023,6 +1261,7 @@ def dreambooth_tab( basic_training.learning_rate_te2, basic_training.lr_scheduler, basic_training.lr_warmup, + basic_training.lr_warmup_steps, basic_training.train_batch_size, basic_training.epoch, basic_training.save_every_n_epochs, @@ -1035,6 +1274,7 @@ def dreambooth_tab( basic_training.caption_extension, basic_training.enable_bucket, advanced_training.gradient_checkpointing, + 
advanced_training.fp8_base, advanced_training.full_fp16, advanced_training.full_bf16, advanced_training.no_token_padding, @@ -1084,6 +1324,7 @@ def dreambooth_tab( basic_training.optimizer, basic_training.optimizer_args, basic_training.lr_scheduler_args, + basic_training.lr_scheduler_type, advanced_training.noise_offset_type, advanced_training.noise_offset, advanced_training.noise_offset_random_strength, @@ -1106,12 +1347,21 @@ def dreambooth_tab( advanced_training.save_every_n_steps, advanced_training.save_last_n_steps, advanced_training.save_last_n_steps_state, + advanced_training.save_last_n_epochs, + advanced_training.save_last_n_epochs_state, + advanced_training.skip_cache_check, advanced_training.log_with, advanced_training.wandb_api_key, advanced_training.wandb_run_name, advanced_training.log_tracker_name, advanced_training.log_tracker_config, + advanced_training.log_config, advanced_training.scale_v_pred_loss_like_noise_pred, + sdxl_params.disable_mmap_load_safetensors, + sdxl_params.fused_backward_pass, + sdxl_params.fused_optimizer_groups, + sdxl_params.sdxl_cache_text_encoder_outputs, + sdxl_params.sdxl_no_half_vae, advanced_training.min_timestep, advanced_training.max_timestep, advanced_training.debiased_estimation_loss, @@ -1128,6 +1378,43 @@ def dreambooth_tab( metadata.metadata_license, metadata.metadata_tags, metadata.metadata_title, + # SD3 Parameters + sd3_training.sd3_cache_text_encoder_outputs, + sd3_training.sd3_cache_text_encoder_outputs_to_disk, + sd3_training.clip_g, + sd3_training.clip_l, + sd3_training.logit_mean, + sd3_training.logit_std, + sd3_training.mode_scale, + sd3_training.save_clip, + sd3_training.save_t5xxl, + sd3_training.t5xxl, + sd3_training.t5xxl_device, + sd3_training.t5xxl_dtype, + sd3_training.sd3_text_encoder_batch_size, + sd3_training.weighting_scheme, + source_model.sd3_checkbox, + # Flux1 parameters + flux1_training.flux1_cache_text_encoder_outputs, + flux1_training.flux1_cache_text_encoder_outputs_to_disk, + 
flux1_training.ae, + flux1_training.clip_l, + flux1_training.t5xxl, + flux1_training.discrete_flow_shift, + flux1_training.model_prediction_type, + flux1_training.timestep_sampling, + flux1_training.split_mode, + flux1_training.train_blocks, + flux1_training.t5xxl_max_token_length, + flux1_training.guidance_scale, + flux1_training.blockwise_fused_optimizers, + flux1_training.flux_fused_backward_pass, + flux1_training.cpu_offload_checkpointing, + flux1_training.blocks_to_swap, + flux1_training.single_blocks_to_swap, + flux1_training.double_blocks_to_swap, + flux1_training.mem_eff_save, + flux1_training.apply_t5_attn_mask, ] configuration.button_open_config.click( diff --git a/kohya_gui/extract_lora_gui.py b/kohya_gui/extract_lora_gui.py index 62b12fd9f..f1650e7f6 100644 --- a/kohya_gui/extract_lora_gui.py +++ b/kohya_gui/extract_lora_gui.py @@ -12,6 +12,7 @@ ) from .custom_logging import setup_logging +from .sd_modeltype import SDModelType # Set up logging log = setup_logging() @@ -337,6 +338,19 @@ def change_sdxl(sdxl): outputs=[load_tuned_model_to, load_original_model_to], ) + #secondary event on model_tuned for auto-detection of v2/SDXL + def change_modeltype_model_tuned(path): + detect = SDModelType(path) + v2 = gr.Checkbox(value=detect.Is_SD2()) + sdxl = gr.Checkbox(value=detect.Is_SDXL()) + return v2, sdxl + + model_tuned.change( + change_modeltype_model_tuned, + inputs=model_tuned, + outputs=[v2, sdxl] + ) + extract_button = gr.Button("Extract LoRA model") extract_button.click( diff --git a/kohya_gui/finetune_gui.py b/kohya_gui/finetune_gui.py index c84922c0e..77351bbf9 100644 --- a/kohya_gui/finetune_gui.py +++ b/kohya_gui/finetune_gui.py @@ -18,14 +18,18 @@ SaveConfigFile, scriptdir, update_my_data, - validate_file_path, validate_folder_path, validate_model_path, - validate_args_setting, setup_environment, + validate_file_path, + validate_folder_path, + validate_model_path, + validate_args_setting, + setup_environment, ) from .class_accelerate_launch import 
AccelerateLaunch from .class_configuration_file import ConfigurationFile from .class_source_model import SourceModel from .class_basic_training import BasicTraining from .class_advanced_training import AdvancedTraining +from .class_sd3 import sd3Training from .class_folders import Folders from .class_sdxl_parameters import SDXLParameters from .class_command_executor import CommandExecutor @@ -34,6 +38,7 @@ from .class_huggingface import HuggingFace from .class_metadata import MetaData from .class_gui_config import KohyaSSGUIConfig +from .class_flux1 import flux1Training from .custom_logging import setup_logging @@ -65,6 +70,7 @@ def save_configuration( v2, v_parameterization, sdxl_checkbox, + flux1_checkbox, train_dir, image_folder, output_dir, @@ -82,6 +88,7 @@ def save_configuration( learning_rate, lr_scheduler, lr_warmup, + lr_warmup_steps, dataset_repeats, train_batch_size, epoch, @@ -116,6 +123,7 @@ def save_configuration( save_state_on_train_end, resume, gradient_checkpointing, + fp8_base, gradient_accumulation_steps, block_lr, mem_eff_attn, @@ -142,6 +150,7 @@ def save_configuration( optimizer, optimizer_args, lr_scheduler_args, + lr_scheduler_type, noise_offset_type, noise_offset, noise_offset_random_strength, @@ -164,12 +173,19 @@ def save_configuration( save_every_n_steps, save_last_n_steps, save_last_n_steps_state, + save_last_n_epochs, + save_last_n_epochs_state, + skip_cache_check, log_with, wandb_api_key, wandb_run_name, log_tracker_name, log_tracker_config, + log_config, scale_v_pred_loss_like_noise_pred, + disable_mmap_load_safetensors, + fused_backward_pass, + fused_optimizer_groups, sdxl_cache_text_encoder_outputs, sdxl_no_half_vae, min_timestep, @@ -188,6 +204,43 @@ def save_configuration( metadata_license, metadata_tags, metadata_title, + # SD3 parameters + sd3_cache_text_encoder_outputs, + sd3_cache_text_encoder_outputs_to_disk, + clip_g, + clip_l, + logit_mean, + logit_std, + mode_scale, + save_clip, + save_t5xxl, + t5xxl, + t5xxl_device, + 
t5xxl_dtype, + sd3_text_encoder_batch_size, + weighting_scheme, + sd3_checkbox, + # Flux.1 + flux1_cache_text_encoder_outputs, + flux1_cache_text_encoder_outputs_to_disk, + ae, + flux1_clip_l, + flux1_t5xxl, + discrete_flow_shift, + model_prediction_type, + timestep_sampling, + split_mode, + train_blocks, + t5xxl_max_token_length, + guidance_scale, + blockwise_fused_optimizers, + flux_fused_backward_pass, + cpu_offload_checkpointing, + blocks_to_swap, + single_blocks_to_swap, + double_blocks_to_swap, + mem_eff_save, + apply_t5_attn_mask, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -231,6 +284,7 @@ def open_configuration( v2, v_parameterization, sdxl_checkbox, + flux1_checkbox, train_dir, image_folder, output_dir, @@ -248,6 +302,7 @@ def open_configuration( learning_rate, lr_scheduler, lr_warmup, + lr_warmup_steps, dataset_repeats, train_batch_size, epoch, @@ -282,6 +337,7 @@ def open_configuration( save_state_on_train_end, resume, gradient_checkpointing, + fp8_base, gradient_accumulation_steps, block_lr, mem_eff_attn, @@ -308,6 +364,7 @@ def open_configuration( optimizer, optimizer_args, lr_scheduler_args, + lr_scheduler_type, noise_offset_type, noise_offset, noise_offset_random_strength, @@ -330,12 +387,19 @@ def open_configuration( save_every_n_steps, save_last_n_steps, save_last_n_steps_state, + save_last_n_epochs, + save_last_n_epochs_state, + skip_cache_check, log_with, wandb_api_key, wandb_run_name, log_tracker_name, log_tracker_config, + log_config, scale_v_pred_loss_like_noise_pred, + disable_mmap_load_safetensors, + fused_backward_pass, + fused_optimizer_groups, sdxl_cache_text_encoder_outputs, sdxl_no_half_vae, min_timestep, @@ -354,6 +418,43 @@ def open_configuration( metadata_license, metadata_tags, metadata_title, + # SD3 parameters + sd3_cache_text_encoder_outputs, + sd3_cache_text_encoder_outputs_to_disk, + clip_g, + clip_l, + logit_mean, + logit_std, + mode_scale, + save_clip, + save_t5xxl, + t5xxl, + 
t5xxl_device, + t5xxl_dtype, + sd3_text_encoder_batch_size, + weighting_scheme, + sd3_checkbox, + # Flux.1 + flux1_cache_text_encoder_outputs, + flux1_cache_text_encoder_outputs_to_disk, + ae, + flux1_clip_l, + flux1_t5xxl, + discrete_flow_shift, + model_prediction_type, + timestep_sampling, + split_mode, + train_blocks, + t5xxl_max_token_length, + guidance_scale, + blockwise_fused_optimizers, + flux_fused_backward_pass, + cpu_offload_checkpointing, + blocks_to_swap, + single_blocks_to_swap, + double_blocks_to_swap, + mem_eff_save, + apply_t5_attn_mask, training_preset, ): # Get list of function parameters and values @@ -403,6 +504,7 @@ def train_model( v2, v_parameterization, sdxl_checkbox, + flux1_checkbox, train_dir, image_folder, output_dir, @@ -420,6 +522,7 @@ def train_model( learning_rate, lr_scheduler, lr_warmup, + lr_warmup_steps, dataset_repeats, train_batch_size, epoch, @@ -454,6 +557,7 @@ def train_model( save_state_on_train_end, resume, gradient_checkpointing, + fp8_base, gradient_accumulation_steps, block_lr, mem_eff_attn, @@ -480,6 +584,7 @@ def train_model( optimizer, optimizer_args, lr_scheduler_args, + lr_scheduler_type, noise_offset_type, noise_offset, noise_offset_random_strength, @@ -502,12 +607,19 @@ def train_model( save_every_n_steps, save_last_n_steps, save_last_n_steps_state, + save_last_n_epochs, + save_last_n_epochs_state, + skip_cache_check, log_with, wandb_api_key, wandb_run_name, log_tracker_name, log_tracker_config, + log_config, scale_v_pred_loss_like_noise_pred, + disable_mmap_load_safetensors, + fused_backward_pass, + fused_optimizer_groups, sdxl_cache_text_encoder_outputs, sdxl_no_half_vae, min_timestep, @@ -526,6 +638,43 @@ def train_model( metadata_license, metadata_tags, metadata_title, + # SD3 parameters + sd3_cache_text_encoder_outputs, + sd3_cache_text_encoder_outputs_to_disk, + clip_g, + clip_l, + logit_mean, + logit_std, + mode_scale, + save_clip, + save_t5xxl, + t5xxl, + t5xxl_device, + t5xxl_dtype, + 
sd3_text_encoder_batch_size, + weighting_scheme, + sd3_checkbox, + # Flux.1 + flux1_cache_text_encoder_outputs, + flux1_cache_text_encoder_outputs_to_disk, + ae, + flux1_clip_l, + flux1_t5xxl, + discrete_flow_shift, + model_prediction_type, + timestep_sampling, + split_mode, + train_blocks, + t5xxl_max_token_length, + guidance_scale, + blockwise_fused_optimizers, + flux_fused_backward_pass, + cpu_offload_checkpointing, + blocks_to_swap, + single_blocks_to_swap, + double_blocks_to_swap, + mem_eff_save, + apply_t5_attn_mask, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -558,44 +707,36 @@ def train_model( # # Validate paths - # - + # + if not validate_file_path(dataset_config): return TRAIN_BUTTON_VISIBLE - + if not validate_folder_path(image_folder): return TRAIN_BUTTON_VISIBLE - + if not validate_file_path(log_tracker_config): return TRAIN_BUTTON_VISIBLE - - if not validate_folder_path(logging_dir, can_be_written_to=True, create_if_not_exists=True): + + if not validate_folder_path( + logging_dir, can_be_written_to=True, create_if_not_exists=True + ): return TRAIN_BUTTON_VISIBLE - - if not validate_folder_path(output_dir, can_be_written_to=True, create_if_not_exists=True): + + if not validate_folder_path( + output_dir, can_be_written_to=True, create_if_not_exists=True + ): return TRAIN_BUTTON_VISIBLE - + if not validate_model_path(pretrained_model_name_or_path): return TRAIN_BUTTON_VISIBLE - + if not validate_folder_path(resume): return TRAIN_BUTTON_VISIBLE - + # # End of path validation # - - # if not validate_paths( - # dataset_config=dataset_config, - # finetune_image_folder=image_folder, - # headless=headless, - # log_tracker_config=log_tracker_config, - # logging_dir=logging_dir, - # output_dir=output_dir, - # pretrained_model_name_or_path=pretrained_model_name_or_path, - # resume=resume, - # ): - # return TRAIN_BUTTON_VISIBLE if not print_only and check_if_model_exist( output_name, output_dir, save_model_as, headless @@ 
-727,10 +868,16 @@ def train_model( log.info(max_train_steps_info) - if max_train_steps != 0: - lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) + # Calculate lr_warmup_steps + if lr_warmup_steps > 0: + lr_warmup_steps = int(lr_warmup_steps) + if lr_warmup > 0: + log.warning("Both lr_warmup and lr_warmup_steps are set. lr_warmup_steps will be used.") + elif lr_warmup != 0: + lr_warmup_steps = lr_warmup / 100 else: lr_warmup_steps = 0 + log.info(f"lr_warmup_steps = {lr_warmup_steps}") accelerate_path = get_executable_path("accelerate") @@ -738,7 +885,7 @@ def train_model( log.error("accelerate not found") return TRAIN_BUTTON_VISIBLE - run_cmd = [rf'{accelerate_path}', "launch"] + run_cmd = [rf"{accelerate_path}", "launch"] run_cmd = AccelerateLaunch.run_cmd( run_cmd=run_cmd, @@ -758,6 +905,10 @@ def train_model( if sdxl_checkbox: run_cmd.append(rf"{scriptdir}/sd-scripts/sdxl_train.py") + elif sd3_checkbox: + run_cmd.append(rf"{scriptdir}/sd-scripts/sd3_train.py") + elif flux1_checkbox: + run_cmd.append(rf"{scriptdir}/sd-scripts/flux_train.py") else: run_cmd.append(rf"{scriptdir}/sd-scripts/fine_tune.py") @@ -766,7 +917,14 @@ def train_model( if use_latent_files == "Yes" else f"{train_dir}/{caption_metadata_filename}" ) - cache_text_encoder_outputs = sdxl_checkbox and sdxl_cache_text_encoder_outputs + cache_text_encoder_outputs = ( + (sdxl_checkbox and sdxl_cache_text_encoder_outputs) + or (sd3_checkbox and sd3_cache_text_encoder_outputs) + or (flux1_checkbox and flux1_cache_text_encoder_outputs) + ) + cache_text_encoder_outputs_to_disk = ( + sd3_checkbox and sd3_cache_text_encoder_outputs_to_disk + ) or (flux1_checkbox and flux1_cache_text_encoder_outputs_to_disk) no_half_vae = sdxl_checkbox and sdxl_no_half_vae if max_data_loader_n_workers == "" or None: @@ -791,19 +949,27 @@ def train_model( "cache_latents": cache_latents, "cache_latents_to_disk": cache_latents_to_disk, "cache_text_encoder_outputs": cache_text_encoder_outputs, + 
"cache_text_encoder_outputs_to_disk": cache_text_encoder_outputs_to_disk, "caption_dropout_every_n_epochs": int(caption_dropout_every_n_epochs), "caption_dropout_rate": caption_dropout_rate, "caption_extension": caption_extension, + "clip_l": flux1_clip_l if flux1_checkbox else clip_l if sd3_checkbox else None, "clip_skip": clip_skip if clip_skip != 0 else None, "color_aug": color_aug, "dataset_config": dataset_config, "dataset_repeats": int(dataset_repeats), "debiased_estimation_loss": debiased_estimation_loss, + "disable_mmap_load_safetensors": disable_mmap_load_safetensors, "dynamo_backend": dynamo_backend, "enable_bucket": True, "flip_aug": flip_aug, + "fp8_base": fp8_base, "full_bf16": full_bf16, "full_fp16": full_fp16, + "fused_backward_pass": fused_backward_pass if not flux1_checkbox else flux_fused_backward_pass, + "fused_optimizer_groups": ( + int(fused_optimizer_groups) if fused_optimizer_groups > 0 else None + ), "gradient_accumulation_steps": int(gradient_accumulation_steps), "gradient_checkpointing": gradient_checkpointing, "huber_c": huber_c, @@ -828,11 +994,13 @@ def train_model( learning_rate_te2 if sdxl_checkbox else None ), # only for sdxl "logging_dir": logging_dir, + "log_config": log_config, "log_tracker_name": log_tracker_name, "log_tracker_config": log_tracker_config, "loss_type": loss_type, "lr_scheduler": lr_scheduler, "lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(), + "lr_scheduler_type": lr_scheduler_type if lr_scheduler_type != "" else None, "lr_warmup_steps": lr_warmup_steps, "masked_loss": masked_loss, "max_bucket_reso": int(max_bucket_reso), @@ -886,6 +1054,10 @@ def train_model( "save_last_n_steps_state": ( save_last_n_steps_state if save_last_n_steps_state != 0 else None ), + "save_last_n_epochs": save_last_n_epochs if save_last_n_epochs != 0 else None, + "save_last_n_epochs_state": ( + save_last_n_epochs_state if save_last_n_epochs_state != 0 else None + ), "save_model_as": save_model_as, "save_precision": 
save_precision, "save_state": save_state, @@ -895,6 +1067,8 @@ def train_model( "sdpa": True if xformers == "sdpa" else None, "seed": int(seed) if int(seed) != 0 else None, "shuffle_caption": shuffle_caption, + "skip_cache_check": skip_cache_check, + "t5xxl": t5xxl if sd3_checkbox else flux1_t5xxl if flux1_checkbox else None, "train_batch_size": train_batch_size, "train_data_dir": image_folder, "train_text_encoder": train_text_encoder, @@ -904,9 +1078,51 @@ def train_model( "v_pred_like_loss": v_pred_like_loss if v_pred_like_loss != 0 else None, "vae_batch_size": vae_batch_size if vae_batch_size != 0 else None, "wandb_api_key": wandb_api_key, - "wandb_run_name": wandb_run_name, + "wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name, "weighted_captions": weighted_captions, "xformers": True if xformers == "xformers" else None, + # SD3 only Parameters + # "cache_text_encoder_outputs": see previous assignment above for code + # "cache_text_encoder_outputs_to_disk": see previous assignment above for code + "clip_g": clip_g if sd3_checkbox else None, + # "clip_l": see previous assignment above for code + "logit_mean": logit_mean if sd3_checkbox else None, + "logit_std": logit_std if sd3_checkbox else None, + "mode_scale": mode_scale if sd3_checkbox else None, + "save_clip": save_clip if sd3_checkbox else None, + "save_t5xxl": save_t5xxl if sd3_checkbox else None, + # "t5xxl": see previous assignment above for code + "t5xxl_device": t5xxl_device if sd3_checkbox else None, + "t5xxl_dtype": t5xxl_dtype if sd3_checkbox else None, + "text_encoder_batch_size": ( + sd3_text_encoder_batch_size if sd3_checkbox else None + ), + "weighting_scheme": weighting_scheme if sd3_checkbox else None, + # Flux.1 specific parameters + # "cache_text_encoder_outputs": see previous assignment above for code + # "cache_text_encoder_outputs_to_disk": see previous assignment above for code + "ae": ae if flux1_checkbox else None, + # "clip_l": see previous assignment above for 
code + # "t5xxl": see previous assignment above for code + "discrete_flow_shift": discrete_flow_shift if flux1_checkbox else None, + "model_prediction_type": model_prediction_type if flux1_checkbox else None, + "timestep_sampling": timestep_sampling if flux1_checkbox else None, + "split_mode": split_mode if flux1_checkbox else None, + "train_blocks": train_blocks if flux1_checkbox else None, + "t5xxl_max_token_length": t5xxl_max_token_length if flux1_checkbox else None, + "guidance_scale": guidance_scale if flux1_checkbox else None, + "blockwise_fused_optimizers": ( + blockwise_fused_optimizers if flux1_checkbox else None + ), + # "flux_fused_backward_pass": see previous assignment of fused_backward_pass in above code + "cpu_offload_checkpointing": ( + cpu_offload_checkpointing if flux1_checkbox else None + ), + "blocks_to_swap": blocks_to_swap if flux1_checkbox else None, + "single_blocks_to_swap": single_blocks_to_swap if flux1_checkbox else None, + "double_blocks_to_swap": double_blocks_to_swap if flux1_checkbox else None, + "mem_eff_save": mem_eff_save if flux1_checkbox else None, + "apply_t5_attn_mask": apply_t5_attn_mask if flux1_checkbox else None, } # Given dictionary `config_toml_data` @@ -924,7 +1140,7 @@ def train_model( current_datetime = datetime.now() formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") - tmpfilename = fr"{output_dir}/config_finetune-{formatted_datetime}.toml" + tmpfilename = rf"{output_dir}/config_finetune-{formatted_datetime}.toml" # Save the updated TOML data back to the file with open(tmpfilename, "w", encoding="utf-8") as toml_file: toml.dump(config_toml_data, toml_file) @@ -1090,7 +1306,9 @@ def list_presets(path): # Add SDXL Parameters sdxl_params = SDXLParameters( - source_model.sdxl_checkbox, config=config + source_model.sdxl_checkbox, + config=config, + trainer="finetune", ) with gr.Row(): @@ -1099,6 +1317,19 @@ def list_presets(path): label="Train text encoder", value=True ) + # Add FLUX1 Parameters + 
flux1_training = flux1Training( + headless=headless, + config=config, + flux1_checkbox=source_model.flux1_checkbox, + finetuning=True, + ) + + # Add SD3 Parameters + sd3_training = sd3Training( + headless=headless, config=config, sd3_checkbox=source_model.sd3_checkbox + ) + with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"): with gr.Row(): gradient_accumulation_steps = gr.Slider( @@ -1146,6 +1377,7 @@ def list_presets(path): source_model.v2, source_model.v_parameterization, source_model.sdxl_checkbox, + source_model.flux1_checkbox, train_dir, image_folder, output_dir, @@ -1163,6 +1395,7 @@ def list_presets(path): basic_training.learning_rate, basic_training.lr_scheduler, basic_training.lr_warmup, + basic_training.lr_warmup_steps, dataset_repeats, basic_training.train_batch_size, basic_training.epoch, @@ -1196,6 +1429,7 @@ def list_presets(path): advanced_training.save_state_on_train_end, advanced_training.resume, advanced_training.gradient_checkpointing, + advanced_training.fp8_base, gradient_accumulation_steps, block_lr, advanced_training.mem_eff_attn, @@ -1222,6 +1456,7 @@ def list_presets(path): basic_training.optimizer, basic_training.optimizer_args, basic_training.lr_scheduler_args, + basic_training.lr_scheduler_type, advanced_training.noise_offset_type, advanced_training.noise_offset, advanced_training.noise_offset_random_strength, @@ -1244,12 +1479,19 @@ def list_presets(path): advanced_training.save_every_n_steps, advanced_training.save_last_n_steps, advanced_training.save_last_n_steps_state, + advanced_training.save_last_n_epochs, + advanced_training.save_last_n_epochs_state, + advanced_training.skip_cache_check, advanced_training.log_with, advanced_training.wandb_api_key, advanced_training.wandb_run_name, advanced_training.log_tracker_name, advanced_training.log_tracker_config, + advanced_training.log_config, advanced_training.scale_v_pred_loss_like_noise_pred, + sdxl_params.disable_mmap_load_safetensors, + sdxl_params.fused_backward_pass, 
+ sdxl_params.fused_optimizer_groups, sdxl_params.sdxl_cache_text_encoder_outputs, sdxl_params.sdxl_no_half_vae, advanced_training.min_timestep, @@ -1268,6 +1510,43 @@ def list_presets(path): metadata.metadata_license, metadata.metadata_tags, metadata.metadata_title, + # SD3 Parameters + sd3_training.sd3_cache_text_encoder_outputs, + sd3_training.sd3_cache_text_encoder_outputs_to_disk, + sd3_training.clip_g, + sd3_training.clip_l, + sd3_training.logit_mean, + sd3_training.logit_std, + sd3_training.mode_scale, + sd3_training.save_clip, + sd3_training.save_t5xxl, + sd3_training.t5xxl, + sd3_training.t5xxl_device, + sd3_training.t5xxl_dtype, + sd3_training.sd3_text_encoder_batch_size, + sd3_training.weighting_scheme, + source_model.sd3_checkbox, + # Flux1 parameters + flux1_training.flux1_cache_text_encoder_outputs, + flux1_training.flux1_cache_text_encoder_outputs_to_disk, + flux1_training.ae, + flux1_training.clip_l, + flux1_training.t5xxl, + flux1_training.discrete_flow_shift, + flux1_training.model_prediction_type, + flux1_training.timestep_sampling, + flux1_training.split_mode, + flux1_training.train_blocks, + flux1_training.t5xxl_max_token_length, + flux1_training.guidance_scale, + flux1_training.blockwise_fused_optimizers, + flux1_training.flux_fused_backward_pass, + flux1_training.cpu_offload_checkpointing, + flux1_training.blocks_to_swap, + flux1_training.single_blocks_to_swap, + flux1_training.double_blocks_to_swap, + flux1_training.mem_eff_save, + flux1_training.apply_t5_attn_mask, ] configuration.button_open_config.click( diff --git a/kohya_gui/flux_extract_lora_gui.py b/kohya_gui/flux_extract_lora_gui.py new file mode 100644 index 000000000..1fbd2756f --- /dev/null +++ b/kohya_gui/flux_extract_lora_gui.py @@ -0,0 +1,273 @@ +import gradio as gr +import subprocess +import os +import sys +from .common_gui import ( + get_saveasfilename_path, + get_file_path, + scriptdir, + list_files, + create_refresh_button, + setup_environment, +) +from .custom_logging 
import setup_logging + +# Set up logging +log = setup_logging() + +folder_symbol = "\U0001f4c2" # 📂 +refresh_symbol = "\U0001f504" # 🔄 +save_style_symbol = "\U0001f4be" # 💾 +document_symbol = "\U0001F4C4" # 📄 + +PYTHON = sys.executable + + +def extract_flux_lora( + model_org, + model_tuned, + save_to, + save_precision, + dim, + device, + clamp_quantile, + no_metadata, + mem_eff_safe_open, +): + # Check for required inputs + if model_org == "" or model_tuned == "" or save_to == "": + log.info( + "Please provide all required inputs: original model, tuned model, and save path." + ) + return + + # Check if source models exist + if not os.path.isfile(model_org): + log.info("The provided original model is not a file") + return + + if not os.path.isfile(model_tuned): + log.info("The provided tuned model is not a file") + return + + # Prepare save path + if os.path.dirname(save_to) == "": + save_to = os.path.join(os.path.dirname(model_tuned), save_to) + if os.path.isdir(save_to): + save_to = os.path.join(save_to, "flux_lora.safetensors") + if os.path.normpath(model_tuned) == os.path.normpath(save_to): + path, ext = os.path.splitext(save_to) + save_to = f"{path}_lora{ext}" + + run_cmd = [ + rf"{PYTHON}", + rf"{scriptdir}/sd-scripts/networks/flux_extract_lora.py", + "--model_org", + rf"{model_org}", + "--model_tuned", + rf"{model_tuned}", + "--save_to", + rf"{save_to}", + "--dim", + str(dim), + "--device", + device, + "--clamp_quantile", + str(clamp_quantile), + ] + + # The "Save precision" dropdown offers the literal string "None"; only pass the flag for a real precision + if save_precision and save_precision != "None": + run_cmd.extend(["--save_precision", save_precision]) + + if no_metadata: + run_cmd.append("--no_metadata") + + if mem_eff_safe_open: + run_cmd.append("--mem_eff_safe_open") + + env = setup_environment() + + # Reconstruct the safe command string for display + command_to_run = " ".join(run_cmd) + log.info(f"Executing command: {command_to_run}") + + # Run the command + subprocess.run(run_cmd, env=env) + + +def gradio_flux_extract_lora_tab(headless=False): + current_model_dir =
os.path.join(scriptdir, "outputs") + current_save_dir = os.path.join(scriptdir, "outputs") + + def list_models(path): + return list(list_files(path, exts=[".safetensors"], all=True)) + + with gr.Tab("Extract Flux LoRA"): + gr.Markdown( + "This utility can extract a LoRA network from a finetuned Flux model." + ) + + lora_ext = gr.Textbox(value="*.safetensors", visible=False) + lora_ext_name = gr.Textbox(value="LoRA model types", visible=False) + model_ext = gr.Textbox(value="*.safetensors", visible=False) + model_ext_name = gr.Textbox(value="Model types", visible=False) + + with gr.Group(), gr.Row(): + model_org = gr.Dropdown( + label="Original Flux model (path to the original model)", + interactive=True, + choices=[""] + list_models(current_model_dir), + value="", + allow_custom_value=True, + ) + create_refresh_button( + model_org, + lambda: None, + lambda: {"choices": list_models(current_model_dir)}, + "open_folder_small", + ) + button_model_org_file = gr.Button( + folder_symbol, + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not headless), + ) + button_model_org_file.click( + get_file_path, + inputs=[model_org, model_ext, model_ext_name], + outputs=model_org, + show_progress=False, + ) + + model_tuned = gr.Dropdown( + label="Finetuned Flux model (path to the finetuned model to extract)", + interactive=True, + choices=[""] + list_models(current_model_dir), + value="", + allow_custom_value=True, + ) + create_refresh_button( + model_tuned, + lambda: None, + lambda: {"choices": list_models(current_model_dir)}, + "open_folder_small", + ) + button_model_tuned_file = gr.Button( + folder_symbol, + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not headless), + ) + button_model_tuned_file.click( + get_file_path, + inputs=[model_tuned, model_ext, model_ext_name], + outputs=model_tuned, + show_progress=False, + ) + + with gr.Group(), gr.Row(): + save_to = gr.Dropdown( + label="Save to (path where to save the extracted LoRA model...)", + 
interactive=True, + choices=[""] + list_models(current_save_dir), + value="", + allow_custom_value=True, + ) + create_refresh_button( + save_to, + lambda: None, + lambda: {"choices": list_models(current_save_dir)}, + "open_folder_small", + ) + button_save_to = gr.Button( + folder_symbol, + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not headless), + ) + button_save_to.click( + get_saveasfilename_path, + inputs=[save_to, lora_ext, lora_ext_name], + outputs=save_to, + show_progress=False, + ) + + save_precision = gr.Dropdown( + label="Save precision", + choices=["None", "float", "fp16", "bf16"], + value="None", + interactive=True, + ) + + with gr.Row(): + dim = gr.Slider( + minimum=1, + maximum=1024, + label="Network Dimension (Rank)", + value=4, + step=1, + interactive=True, + ) + device = gr.Dropdown( + label="Device", + choices=["cpu", "cuda"], + value="cuda", + interactive=True, + ) + clamp_quantile = gr.Slider( + minimum=0, + maximum=1, + label="Clamp Quantile", + value=0.99, + step=0.01, + interactive=True, + ) + + with gr.Row(): + no_metadata = gr.Checkbox( + label="No metadata (do not save sai modelspec metadata)", + value=False, + interactive=True, + ) + mem_eff_safe_open = gr.Checkbox( + label="Memory efficient safe open (experimental feature)", + value=False, + interactive=True, + ) + + extract_button = gr.Button("Extract Flux LoRA model") + + extract_button.click( + extract_flux_lora, + inputs=[ + model_org, + model_tuned, + save_to, + save_precision, + dim, + device, + clamp_quantile, + no_metadata, + mem_eff_safe_open, + ], + show_progress=False, + ) + + model_org.change( + fn=lambda path: gr.Dropdown(choices=[""] + list_models(path)), + inputs=model_org, + outputs=model_org, + show_progress=False, + ) + model_tuned.change( + fn=lambda path: gr.Dropdown(choices=[""] + list_models(path)), + inputs=model_tuned, + outputs=model_tuned, + show_progress=False, + ) + save_to.change( + fn=lambda path: gr.Dropdown(choices=[""] + 
list_models(path)), + inputs=save_to, + outputs=save_to, + show_progress=False, + ) diff --git a/kohya_gui/flux_merge_lora_gui.py b/kohya_gui/flux_merge_lora_gui.py new file mode 100644 index 000000000..c303087bf --- /dev/null +++ b/kohya_gui/flux_merge_lora_gui.py @@ -0,0 +1,470 @@ +# Standard library imports +import os +import subprocess +import sys +import json + +# Third-party imports +import gradio as gr + +# Local module imports +from .common_gui import ( + get_saveasfilename_path, + get_file_path, + scriptdir, + list_files, + create_refresh_button, + setup_environment, +) +from .custom_logging import setup_logging + +# Set up logging +log = setup_logging() + +folder_symbol = "\U0001f4c2" # 📂 +refresh_symbol = "\U0001f504" # 🔄 +save_style_symbol = "\U0001f4be" # 💾 +document_symbol = "\U0001F4C4" # 📄 + +PYTHON = sys.executable + + +def check_model(model): + if not model: + return True + if not os.path.isfile(model): + log.info(f"The provided {model} is not a file") + return False + return True + + +def verify_conditions(flux_model, lora_models): + lora_models_count = sum(1 for model in lora_models if model) + if flux_model and lora_models_count >= 1: + return True + elif not flux_model and lora_models_count >= 2: + return True + return False + + +class GradioFluxMergeLoRaTab: + def __init__(self, headless=False): + self.headless = headless + self.build_tab() + + def save_inputs_to_json(self, file_path, inputs): + with open(file_path, "w", encoding="utf-8") as file: + json.dump(inputs, file) + log.info(f"Saved inputs to {file_path}") + + def load_inputs_from_json(self, file_path): + with open(file_path, "r", encoding="utf-8") as file: + inputs = json.load(file) + log.info(f"Loaded inputs from {file_path}") + return inputs + + def build_tab(self): + current_flux_model_dir = os.path.join(scriptdir, "outputs") + current_save_dir = os.path.join(scriptdir, "outputs") + current_lora_model_dir = current_flux_model_dir + + def list_flux_models(path): + nonlocal 
current_flux_model_dir + current_flux_model_dir = path + return list(list_files(path, exts=[".safetensors"], all=True)) + + def list_lora_models(path): + nonlocal current_lora_model_dir + current_lora_model_dir = path + return list(list_files(path, exts=[".safetensors"], all=True)) + + def list_save_to(path): + nonlocal current_save_dir + current_save_dir = path + return list(list_files(path, exts=[".safetensors"], all=True)) + + with gr.Tab("Merge FLUX LoRA"): + gr.Markdown( + "This utility can merge up to 4 LoRA into a FLUX model or alternatively merge up to 4 LoRA together." + ) + + lora_ext = gr.Textbox(value="*.safetensors", visible=False) + lora_ext_name = gr.Textbox(value="LoRA model types", visible=False) + flux_ext = gr.Textbox(value="*.safetensors", visible=False) + flux_ext_name = gr.Textbox(value="FLUX model types", visible=False) + + with gr.Group(), gr.Row(): + flux_model = gr.Dropdown( + label="FLUX Model (Optional. FLUX model path, if you want to merge it with LoRA files via the 'concat' method)", + interactive=True, + choices=[""] + list_flux_models(current_flux_model_dir), + value="", + allow_custom_value=True, + ) + create_refresh_button( + flux_model, + lambda: None, + lambda: {"choices": list_flux_models(current_flux_model_dir)}, + "open_folder_small", + ) + flux_model_file = gr.Button( + folder_symbol, + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not self.headless), + ) + flux_model_file.click( + get_file_path, + inputs=[flux_model, flux_ext, flux_ext_name], + outputs=flux_model, + show_progress=False, + ) + + flux_model.change( + fn=lambda path: gr.Dropdown(choices=[""] + list_flux_models(path)), + inputs=flux_model, + outputs=flux_model, + show_progress=False, + ) + + with gr.Group(), gr.Row(): + lora_a_model = gr.Dropdown( + label='LoRA model "A" (path to the LoRA A model)', + interactive=True, + choices=[""] + list_lora_models(current_lora_model_dir), + value="", + allow_custom_value=True, + ) + create_refresh_button( 
+ lora_a_model, + lambda: None, + lambda: {"choices": list_lora_models(current_lora_model_dir)}, + "open_folder_small", + ) + button_lora_a_model_file = gr.Button( + folder_symbol, + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not self.headless), + ) + button_lora_a_model_file.click( + get_file_path, + inputs=[lora_a_model, lora_ext, lora_ext_name], + outputs=lora_a_model, + show_progress=False, + ) + + lora_b_model = gr.Dropdown( + label='LoRA model "B" (path to the LoRA B model)', + interactive=True, + choices=[""] + list_lora_models(current_lora_model_dir), + value="", + allow_custom_value=True, + ) + create_refresh_button( + lora_b_model, + lambda: None, + lambda: {"choices": list_lora_models(current_lora_model_dir)}, + "open_folder_small", + ) + button_lora_b_model_file = gr.Button( + folder_symbol, + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not self.headless), + ) + button_lora_b_model_file.click( + get_file_path, + inputs=[lora_b_model, lora_ext, lora_ext_name], + outputs=lora_b_model, + show_progress=False, + ) + + lora_a_model.change( + fn=lambda path: gr.Dropdown(choices=[""] + list_lora_models(path)), + inputs=lora_a_model, + outputs=lora_a_model, + show_progress=False, + ) + lora_b_model.change( + fn=lambda path: gr.Dropdown(choices=[""] + list_lora_models(path)), + inputs=lora_b_model, + outputs=lora_b_model, + show_progress=False, + ) + + with gr.Row(): + ratio_a = gr.Slider( + label="Model A merge ratio (e.g. 0.5 means 50%)", + minimum=0, + maximum=2, + step=0.01, + value=0.0, + interactive=True, + ) + + ratio_b = gr.Slider( + label="Model B merge ratio (e.g. 0.5 means 50%)", + minimum=0, + maximum=2, + step=0.01, + value=0.0, + interactive=True, + ) + + with gr.Group(), gr.Row(): + lora_c_model = gr.Dropdown( + label='LoRA model "C" (path to the LoRA C model)', + interactive=True, + choices=[""] + list_lora_models(current_lora_model_dir), + value="", + allow_custom_value=True, + ) + create_refresh_button( + 
lora_c_model, + lambda: None, + lambda: {"choices": list_lora_models(current_lora_model_dir)}, + "open_folder_small", + ) + button_lora_c_model_file = gr.Button( + folder_symbol, + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not self.headless), + ) + button_lora_c_model_file.click( + get_file_path, + inputs=[lora_c_model, lora_ext, lora_ext_name], + outputs=lora_c_model, + show_progress=False, + ) + + lora_d_model = gr.Dropdown( + label='LoRA model "D" (path to the LoRA D model)', + interactive=True, + choices=[""] + list_lora_models(current_lora_model_dir), + value="", + allow_custom_value=True, + ) + create_refresh_button( + lora_d_model, + lambda: None, + lambda: {"choices": list_lora_models(current_lora_model_dir)}, + "open_folder_small", + ) + button_lora_d_model_file = gr.Button( + folder_symbol, + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not self.headless), + ) + button_lora_d_model_file.click( + get_file_path, + inputs=[lora_d_model, lora_ext, lora_ext_name], + outputs=lora_d_model, + show_progress=False, + ) + lora_c_model.change( + fn=lambda path: gr.Dropdown(choices=[""] + list_lora_models(path)), + inputs=lora_c_model, + outputs=lora_c_model, + show_progress=False, + ) + lora_d_model.change( + fn=lambda path: gr.Dropdown(choices=[""] + list_lora_models(path)), + inputs=lora_d_model, + outputs=lora_d_model, + show_progress=False, + ) + + with gr.Row(): + ratio_c = gr.Slider( + label="Model C merge ratio (e.g. 0.5 means 50%)", + minimum=0, + maximum=2, + step=0.01, + value=0.0, + interactive=True, + ) + + ratio_d = gr.Slider( + label="Model D merge ratio (e.g. 0.5 means 50%)", + minimum=0, + maximum=2, + step=0.01, + value=0.0, + interactive=True, + ) + + with gr.Group(), gr.Row(): + save_to = gr.Dropdown( + label="Save to (path for the file to save...)", + interactive=True, + choices=[""] + list_save_to(current_save_dir), + value="", + allow_custom_value=True, + ) + create_refresh_button( + save_to, + lambda: 
None, + lambda: {"choices": list_save_to(current_save_dir)}, + "open_folder_small", + ) + button_save_to = gr.Button( + folder_symbol, + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not self.headless), + ) + button_save_to.click( + get_saveasfilename_path, + inputs=[save_to, lora_ext, lora_ext_name], + outputs=save_to, + show_progress=False, + ) + precision = gr.Radio( + label="Merge precision", + choices=["float", "fp16", "bf16"], + value="float", + interactive=True, + ) + save_precision = gr.Radio( + label="Save precision", + choices=["float", "fp16", "bf16", "fp8"], + value="fp16", + interactive=True, + ) + + save_to.change( + fn=lambda path: gr.Dropdown(choices=[""] + list_save_to(path)), + inputs=save_to, + outputs=save_to, + show_progress=False, + ) + + with gr.Row(): + loading_device = gr.Dropdown( + label="Loading device", + choices=["cpu", "cuda"], + value="cpu", + interactive=True, + ) + working_device = gr.Dropdown( + label="Working device", + choices=["cpu", "cuda"], + value="cpu", + interactive=True, + ) + + with gr.Row(): + concat = gr.Checkbox(label="Concat LoRA", value=False) + shuffle = gr.Checkbox(label="Shuffle LoRA weights", value=False) + no_metadata = gr.Checkbox(label="Don't save metadata", value=False) + diffusers = gr.Checkbox(label="Diffusers LoRA", value=False) + + merge_button = gr.Button("Merge model") + + merge_button.click( + self.merge_flux_lora, + inputs=[ + flux_model, + lora_a_model, + lora_b_model, + lora_c_model, + lora_d_model, + ratio_a, + ratio_b, + ratio_c, + ratio_d, + save_to, + precision, + save_precision, + loading_device, + working_device, + concat, + shuffle, + no_metadata, + diffusers, + ], + show_progress=False, + ) + + def merge_flux_lora( + self, + flux_model, + lora_a_model, + lora_b_model, + lora_c_model, + lora_d_model, + ratio_a, + ratio_b, + ratio_c, + ratio_d, + save_to, + precision, + save_precision, + loading_device, + working_device, + concat, + shuffle, + no_metadata, + diffusers, + 
): + log.info("Merge FLUX LoRA...") + models = [ + lora_a_model, + lora_b_model, + lora_c_model, + lora_d_model, + ] + lora_models = [model for model in models if model] + ratios = [ratio for model, ratio in zip(models, [ratio_a, ratio_b, ratio_c, ratio_d]) if model] + + # if not verify_conditions(flux_model, lora_models): + # log.info( + # "Warning: Either provide at least one LoRA model along with the FLUX model or at least two LoRA models if no FLUX model is provided." + # ) + # return + + for model in [flux_model] + lora_models: + if not check_model(model): + return + + run_cmd = [rf"{PYTHON}", rf"{scriptdir}/sd-scripts/networks/flux_merge_lora.py"] + + if flux_model: + run_cmd.extend(["--flux_model", rf"{flux_model}"]) + + run_cmd.extend([ + "--save_precision", save_precision, + "--precision", precision, + "--save_to", rf"{save_to}", + "--loading_device", loading_device, + "--working_device", working_device, + ]) + + if lora_models: + run_cmd.append("--models") + run_cmd.extend(lora_models) + run_cmd.append("--ratios") + run_cmd.extend(map(str, ratios)) + + if concat: + run_cmd.append("--concat") + if shuffle: + run_cmd.append("--shuffle") + if no_metadata: + run_cmd.append("--no_metadata") + if diffusers: + run_cmd.append("--diffusers") + + env = setup_environment() + + # Reconstruct the safe command string for display + command_to_run = " ".join(run_cmd) + log.info(f"Executing command: {command_to_run}") + + # Run the command in the sd-scripts folder context + subprocess.run(run_cmd, env=env) + + log.info("Done merging...") \ No newline at end of file diff --git a/kohya_gui/lora_gui.py b/kohya_gui/lora_gui.py index 1132fde1d..68eb67808 100644 --- a/kohya_gui/lora_gui.py +++ b/kohya_gui/lora_gui.py @@ -19,8 +19,12 @@ SaveConfigFile, scriptdir, update_my_data, - validate_file_path, validate_folder_path, validate_model_path, validate_toml_file, - validate_args_setting, setup_environment, + validate_file_path, + validate_folder_path, + validate_model_path, + 
validate_toml_file, + validate_args_setting, + setup_environment, ) from .class_accelerate_launch import AccelerateLaunch from .class_configuration_file import ConfigurationFile @@ -36,6 +40,7 @@ from .class_huggingface import HuggingFace from .class_metadata import MetaData from .class_gui_config import KohyaSSGUIConfig +from .class_flux1 import flux1Training from .dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -73,72 +78,89 @@ def save_configuration( save_as_bool, file_path, + + # source model section pretrained_model_name_or_path, v2, v_parameterization, sdxl, - logging_dir, + flux1_checkbox, + dataset_config, + save_model_as, + save_precision, train_data_dir, + output_name, + model_list, + training_comment, + + # folders section + logging_dir, reg_data_dir, output_dir, - dataset_config, + + # basic training section max_resolution, learning_rate, lr_scheduler, lr_warmup, + lr_warmup_steps, train_batch_size, epoch, save_every_n_epochs, - mixed_precision, - save_precision, seed, - num_cpu_threads_per_process, cache_latents, cache_latents_to_disk, caption_extension, enable_bucket, - gradient_checkpointing, - fp8_base, - full_fp16, - # no_token_padding, stop_text_encoder_training, min_bucket_reso, max_bucket_reso, - # use_8bit_adam, + max_train_epochs, + max_train_steps, + lr_scheduler_num_cycles, + lr_scheduler_power, + optimizer, + optimizer_args, + lr_scheduler_args, + lr_scheduler_type, + max_grad_norm, + + # accelerate launch section + mixed_precision, + num_cpu_threads_per_process, + num_processes, + num_machines, + multi_gpu, + gpu_ids, + main_process_port, + dynamo_backend, + dynamo_mode, + dynamo_use_fullgraph, + dynamo_use_dynamic, + extra_accelerate_launch_args, + + ### advanced training section + gradient_checkpointing, + fp8_base, + fp8_base_unet, + full_fp16, + highvram, + lowvram, xformers, - save_model_as, shuffle_caption, save_state, save_state_on_train_end, resume, prior_loss_weight, - text_encoder_lr, - unet_lr, 
- network_dim, - network_weights, - dim_from_weights, color_aug, flip_aug, masked_loss, clip_skip, - num_processes, - num_machines, - multi_gpu, - gpu_ids, - main_process_port, gradient_accumulation_steps, mem_eff_attn, - output_name, - model_list, max_token_length, - max_train_epochs, - max_train_steps, max_data_loader_n_workers, - network_alpha, - training_comment, keep_tokens, - lr_scheduler_num_cycles, - lr_scheduler_power, persistent_data_loader_workers, bucket_no_upscale, random_crop, @@ -146,10 +168,6 @@ def save_configuration( v_pred_like_loss, caption_dropout_every_n_epochs, caption_dropout_rate, - optimizer, - optimizer_args, - lr_scheduler_args, - max_grad_norm, noise_offset_type, noise_offset, noise_offset_random_strength, @@ -158,6 +176,44 @@ def save_configuration( multires_noise_discount, ip_noise_gamma, ip_noise_gamma_random_strength, + additional_parameters, + loss_type, + huber_schedule, + huber_c, + vae_batch_size, + min_snr_gamma, + save_every_n_steps, + save_last_n_steps, + save_last_n_steps_state, + save_last_n_epochs, + save_last_n_epochs_state, + skip_cache_check, + log_with, + wandb_api_key, + wandb_run_name, + log_tracker_name, + log_tracker_config, + log_config, + scale_v_pred_loss_like_noise_pred, + full_bf16, + min_timestep, + max_timestep, + vae, + weighted_captions, + debiased_estimation_loss, + + # sdxl parameters section + sdxl_cache_text_encoder_outputs, + sdxl_no_half_vae, + + ### + text_encoder_lr, + t5xxl_lr, + unet_lr, + network_dim, + network_weights, + dim_from_weights, + network_alpha, LoRA_type, factor, bypass_mode, @@ -177,12 +233,6 @@ def save_configuration( sample_every_n_epochs, sample_sampler, sample_prompts, - additional_parameters, - loss_type, - huber_schedule, - huber_c, - vae_batch_size, - min_snr_gamma, down_lr_weight, mid_lr_weight, up_lr_weight, @@ -191,34 +241,17 @@ def save_configuration( block_alphas, conv_block_dims, conv_block_alphas, - weighted_captions, unit, - save_every_n_steps, - save_last_n_steps, - 
save_last_n_steps_state, - log_with, - wandb_api_key, - wandb_run_name, - log_tracker_name, - log_tracker_config, - scale_v_pred_loss_like_noise_pred, scale_weight_norms, network_dropout, rank_dropout, module_dropout, - sdxl_cache_text_encoder_outputs, - sdxl_no_half_vae, - full_bf16, - min_timestep, - max_timestep, - vae, - dynamo_backend, - dynamo_mode, - dynamo_use_fullgraph, - dynamo_use_dynamic, - extra_accelerate_launch_args, LyCORIS_preset, - debiased_estimation_loss, + loraplus_lr_ratio, + loraplus_text_encoder_lr_ratio, + loraplus_unet_lr_ratio, + + # huggingface section huggingface_repo_id, huggingface_token, huggingface_repo_type, @@ -227,11 +260,44 @@ def save_configuration( save_state_to_huggingface, resume_from_huggingface, async_upload, + + # metadata section metadata_author, metadata_description, metadata_license, metadata_tags, metadata_title, + + # Flux1 + flux1_cache_text_encoder_outputs, + flux1_cache_text_encoder_outputs_to_disk, + ae, + clip_l, + t5xxl, + discrete_flow_shift, + model_prediction_type, + timestep_sampling, + split_mode, + train_blocks, + t5xxl_max_token_length, + enable_all_linear, + guidance_scale, + mem_eff_save, + apply_t5_attn_mask, + split_qkv, + train_t5xxl, + cpu_offload_checkpointing, + img_attn_dim, + img_mlp_dim, + img_mod_dim, + single_dim, + txt_attn_dim, + txt_mlp_dim, + txt_mod_dim, + single_mod_dim, + in_dims, + train_double_block_indices, + train_single_block_indices, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -278,72 +344,89 @@ def open_configuration( ask_for_file, apply_preset, file_path, + + # source model section pretrained_model_name_or_path, v2, v_parameterization, sdxl, - logging_dir, + flux1_checkbox, + dataset_config, + save_model_as, + save_precision, train_data_dir, + output_name, + model_list, + training_comment, + + # folders section + logging_dir, reg_data_dir, output_dir, - dataset_config, + + # basic training section max_resolution, learning_rate, 
lr_scheduler, lr_warmup, + lr_warmup_steps, train_batch_size, epoch, save_every_n_epochs, - mixed_precision, - save_precision, seed, - num_cpu_threads_per_process, cache_latents, cache_latents_to_disk, caption_extension, enable_bucket, - gradient_checkpointing, - fp8_base, - full_fp16, - # no_token_padding, stop_text_encoder_training, min_bucket_reso, max_bucket_reso, - # use_8bit_adam, + max_train_epochs, + max_train_steps, + lr_scheduler_num_cycles, + lr_scheduler_power, + optimizer, + optimizer_args, + lr_scheduler_args, + lr_scheduler_type, + max_grad_norm, + + # accelerate launch section + mixed_precision, + num_cpu_threads_per_process, + num_processes, + num_machines, + multi_gpu, + gpu_ids, + main_process_port, + dynamo_backend, + dynamo_mode, + dynamo_use_fullgraph, + dynamo_use_dynamic, + extra_accelerate_launch_args, + + ### advanced training section + gradient_checkpointing, + fp8_base, + fp8_base_unet, + full_fp16, + highvram, + lowvram, xformers, - save_model_as, shuffle_caption, save_state, save_state_on_train_end, resume, prior_loss_weight, - text_encoder_lr, - unet_lr, - network_dim, - network_weights, - dim_from_weights, color_aug, flip_aug, masked_loss, clip_skip, - num_processes, - num_machines, - multi_gpu, - gpu_ids, - main_process_port, gradient_accumulation_steps, mem_eff_attn, - output_name, - model_list, max_token_length, - max_train_epochs, - max_train_steps, max_data_loader_n_workers, - network_alpha, - training_comment, keep_tokens, - lr_scheduler_num_cycles, - lr_scheduler_power, persistent_data_loader_workers, bucket_no_upscale, random_crop, @@ -351,10 +434,6 @@ def open_configuration( v_pred_like_loss, caption_dropout_every_n_epochs, caption_dropout_rate, - optimizer, - optimizer_args, - lr_scheduler_args, - max_grad_norm, noise_offset_type, noise_offset, noise_offset_random_strength, @@ -363,6 +442,44 @@ def open_configuration( multires_noise_discount, ip_noise_gamma, ip_noise_gamma_random_strength, + additional_parameters, + 
loss_type, + huber_schedule, + huber_c, + vae_batch_size, + min_snr_gamma, + save_every_n_steps, + save_last_n_steps, + save_last_n_steps_state, + save_last_n_epochs, + save_last_n_epochs_state, + skip_cache_check, + log_with, + wandb_api_key, + wandb_run_name, + log_tracker_name, + log_tracker_config, + log_config, + scale_v_pred_loss_like_noise_pred, + full_bf16, + min_timestep, + max_timestep, + vae, + weighted_captions, + debiased_estimation_loss, + + # sdxl parameters section + sdxl_cache_text_encoder_outputs, + sdxl_no_half_vae, + + ### + text_encoder_lr, + t5xxl_lr, + unet_lr, + network_dim, + network_weights, + dim_from_weights, + network_alpha, LoRA_type, factor, bypass_mode, @@ -382,12 +499,6 @@ def open_configuration( sample_every_n_epochs, sample_sampler, sample_prompts, - additional_parameters, - loss_type, - huber_schedule, - huber_c, - vae_batch_size, - min_snr_gamma, down_lr_weight, mid_lr_weight, up_lr_weight, @@ -396,34 +507,17 @@ def open_configuration( block_alphas, conv_block_dims, conv_block_alphas, - weighted_captions, unit, - save_every_n_steps, - save_last_n_steps, - save_last_n_steps_state, - log_with, - wandb_api_key, - wandb_run_name, - log_tracker_name, - log_tracker_config, - scale_v_pred_loss_like_noise_pred, scale_weight_norms, network_dropout, rank_dropout, module_dropout, - sdxl_cache_text_encoder_outputs, - sdxl_no_half_vae, - full_bf16, - min_timestep, - max_timestep, - vae, - dynamo_backend, - dynamo_mode, - dynamo_use_fullgraph, - dynamo_use_dynamic, - extra_accelerate_launch_args, LyCORIS_preset, - debiased_estimation_loss, + loraplus_lr_ratio, + loraplus_text_encoder_lr_ratio, + loraplus_unet_lr_ratio, + + # huggingface section huggingface_repo_id, huggingface_token, huggingface_repo_type, @@ -432,17 +526,52 @@ def open_configuration( save_state_to_huggingface, resume_from_huggingface, async_upload, + + # metadata section metadata_author, metadata_description, metadata_license, metadata_tags, metadata_title, + + # Flux1 + 
flux1_cache_text_encoder_outputs, + flux1_cache_text_encoder_outputs_to_disk, + ae, + clip_l, + t5xxl, + discrete_flow_shift, + model_prediction_type, + timestep_sampling, + split_mode, + train_blocks, + t5xxl_max_token_length, + enable_all_linear, + guidance_scale, + mem_eff_save, + apply_t5_attn_mask, + split_qkv, + train_t5xxl, + cpu_offload_checkpointing, + img_attn_dim, + img_mlp_dim, + img_mod_dim, + single_dim, + txt_attn_dim, + txt_mlp_dim, + txt_mod_dim, + single_mod_dim, + in_dims, + train_double_block_indices, + train_single_block_indices, + + ## training_preset, ): - # Get list of function parameters and values + # Get list of function parameters and their values parameters = list(locals().items()) - # Determines if a preset configuration is being applied + # Determine if a preset configuration is being applied if apply_preset: if training_preset != "none": log.info(f"Applying preset {training_preset}...") @@ -492,6 +621,8 @@ def open_configuration( # Display LoCon parameters based on the 'LoRA_type' from the loaded data # This section dynamically adjusts visibility of certain parameters in the UI if my_data.get("LoRA_type", "Standard") in { + "Flux1", + "Flux1 OFT", "LoCon", "Kohya DyLoRA", "Kohya LoCon", @@ -513,72 +644,89 @@ def open_configuration( def train_model( headless, print_only, + + # source model section pretrained_model_name_or_path, v2, v_parameterization, sdxl, - logging_dir, + flux1_checkbox, + dataset_config, + save_model_as, + save_precision, train_data_dir, + output_name, + model_list, + training_comment, + + # folders section + logging_dir, reg_data_dir, output_dir, - dataset_config, + + # basic training section max_resolution, learning_rate, lr_scheduler, lr_warmup, + lr_warmup_steps, train_batch_size, epoch, save_every_n_epochs, - mixed_precision, - save_precision, seed, - num_cpu_threads_per_process, cache_latents, cache_latents_to_disk, caption_extension, enable_bucket, + stop_text_encoder_training, + min_bucket_reso, + 
max_bucket_reso, + max_train_epochs, + max_train_steps, + lr_scheduler_num_cycles, + lr_scheduler_power, + optimizer, + optimizer_args, + lr_scheduler_args, + lr_scheduler_type, + max_grad_norm, + + # accelerate launch section + mixed_precision, + num_cpu_threads_per_process, + num_processes, + num_machines, + multi_gpu, + gpu_ids, + main_process_port, + dynamo_backend, + dynamo_mode, + dynamo_use_fullgraph, + dynamo_use_dynamic, + extra_accelerate_launch_args, + + ### advanced training section gradient_checkpointing, fp8_base, + fp8_base_unet, full_fp16, - # no_token_padding, - stop_text_encoder_training_pct, - min_bucket_reso, - max_bucket_reso, - # use_8bit_adam, + highvram, + lowvram, xformers, - save_model_as, shuffle_caption, save_state, save_state_on_train_end, resume, prior_loss_weight, - text_encoder_lr, - unet_lr, - network_dim, - network_weights, - dim_from_weights, color_aug, flip_aug, masked_loss, clip_skip, - num_processes, - num_machines, - multi_gpu, - gpu_ids, - main_process_port, gradient_accumulation_steps, mem_eff_attn, - output_name, - model_list, # Keep this. 
Yes, it is unused here but required given the common list used max_token_length, - max_train_epochs, - max_train_steps, max_data_loader_n_workers, - network_alpha, - training_comment, keep_tokens, - lr_scheduler_num_cycles, - lr_scheduler_power, persistent_data_loader_workers, bucket_no_upscale, random_crop, @@ -586,10 +734,6 @@ def train_model( v_pred_like_loss, caption_dropout_every_n_epochs, caption_dropout_rate, - optimizer, - optimizer_args, - lr_scheduler_args, - max_grad_norm, noise_offset_type, noise_offset, noise_offset_random_strength, @@ -598,6 +742,44 @@ def train_model( multires_noise_discount, ip_noise_gamma, ip_noise_gamma_random_strength, + additional_parameters, + loss_type, + huber_schedule, + huber_c, + vae_batch_size, + min_snr_gamma, + save_every_n_steps, + save_last_n_steps, + save_last_n_steps_state, + save_last_n_epochs, + save_last_n_epochs_state, + skip_cache_check, + log_with, + wandb_api_key, + wandb_run_name, + log_tracker_name, + log_tracker_config, + log_config, + scale_v_pred_loss_like_noise_pred, + full_bf16, + min_timestep, + max_timestep, + vae, + weighted_captions, + debiased_estimation_loss, + + # sdxl parameters section + sdxl_cache_text_encoder_outputs, + sdxl_no_half_vae, + + ### + text_encoder_lr, + t5xxl_lr, + unet_lr, + network_dim, + network_weights, + dim_from_weights, + network_alpha, LoRA_type, factor, bypass_mode, @@ -617,12 +799,6 @@ def train_model( sample_every_n_epochs, sample_sampler, sample_prompts, - additional_parameters, - loss_type, - huber_schedule, - huber_c, - vae_batch_size, - min_snr_gamma, down_lr_weight, mid_lr_weight, up_lr_weight, @@ -631,34 +807,17 @@ def train_model( block_alphas, conv_block_dims, conv_block_alphas, - weighted_captions, unit, - save_every_n_steps, - save_last_n_steps, - save_last_n_steps_state, - log_with, - wandb_api_key, - wandb_run_name, - log_tracker_name, - log_tracker_config, - scale_v_pred_loss_like_noise_pred, scale_weight_norms, network_dropout, rank_dropout, 
module_dropout, - sdxl_cache_text_encoder_outputs, - sdxl_no_half_vae, - full_bf16, - min_timestep, - max_timestep, - vae, - dynamo_backend, - dynamo_mode, - dynamo_use_fullgraph, - dynamo_use_dynamic, - extra_accelerate_launch_args, LyCORIS_preset, - debiased_estimation_loss, + loraplus_lr_ratio, + loraplus_text_encoder_lr_ratio, + loraplus_unet_lr_ratio, + + # huggingface section huggingface_repo_id, huggingface_token, huggingface_repo_type, @@ -667,11 +826,44 @@ def train_model( save_state_to_huggingface, resume_from_huggingface, async_upload, + + # metadata section metadata_author, metadata_description, metadata_license, metadata_tags, metadata_title, + + # Flux1 + flux1_cache_text_encoder_outputs, + flux1_cache_text_encoder_outputs_to_disk, + ae, + clip_l, + t5xxl, + discrete_flow_shift, + model_prediction_type, + timestep_sampling, + split_mode, + train_blocks, + t5xxl_max_token_length, + enable_all_linear, + guidance_scale, + mem_eff_save, + apply_t5_attn_mask, + split_qkv, + train_t5xxl, + cpu_offload_checkpointing, + img_attn_dim, + img_mlp_dim, + img_mod_dim, + single_dim, + txt_attn_dim, + txt_mlp_dim, + txt_mod_dim, + single_mod_dim, + in_dims, + train_double_block_indices, + train_single_block_indices, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -697,63 +889,58 @@ def train_model( if not validate_args_setting(optimizer_args): return TRAIN_BUTTON_VISIBLE + if flux1_checkbox: + log.info(f"Validating lora type is Flux1 if flux1 checkbox is checked...") + if (LoRA_type != "Flux1") and (LoRA_type != "Flux1 OFT") and ("LyCORIS" not in LoRA_type): + log.error("LoRA type must be set to 'Flux1', 'Flux1 OFT' or 'LyCORIS' if Flux1 checkbox is checked.") + return TRAIN_BUTTON_VISIBLE + # # Validate paths - # - + # + if not validate_file_path(dataset_config): return TRAIN_BUTTON_VISIBLE - + if not validate_file_path(log_tracker_config): return TRAIN_BUTTON_VISIBLE - - if not validate_folder_path(logging_dir, 
can_be_written_to=True, create_if_not_exists=True): + + if not validate_folder_path( + logging_dir, can_be_written_to=True, create_if_not_exists=True + ): return TRAIN_BUTTON_VISIBLE - + if LyCORIS_preset not in LYCORIS_PRESETS_CHOICES: if not validate_toml_file(LyCORIS_preset): return TRAIN_BUTTON_VISIBLE - + if not validate_file_path(network_weights): return TRAIN_BUTTON_VISIBLE - - if not validate_folder_path(output_dir, can_be_written_to=True, create_if_not_exists=True): + + if not validate_folder_path( + output_dir, can_be_written_to=True, create_if_not_exists=True + ): return TRAIN_BUTTON_VISIBLE - + if not validate_model_path(pretrained_model_name_or_path): return TRAIN_BUTTON_VISIBLE - + if not validate_folder_path(reg_data_dir): return TRAIN_BUTTON_VISIBLE - + if not validate_folder_path(resume): return TRAIN_BUTTON_VISIBLE - + if not validate_folder_path(train_data_dir): return TRAIN_BUTTON_VISIBLE - + if not validate_model_path(vae): return TRAIN_BUTTON_VISIBLE - + # # End of path validation # - # if not validate_paths( - # dataset_config=dataset_config, - # headless=headless, - # log_tracker_config=log_tracker_config, - # logging_dir=logging_dir, - # network_weights=network_weights, - # output_dir=output_dir, - # pretrained_model_name_or_path=pretrained_model_name_or_path, - # reg_data_dir=reg_data_dir, - # resume=resume, - # train_data_dir=train_data_dir, - # vae=vae, - # ): - # return TRAIN_BUTTON_VISIBLE - if int(bucket_reso_steps) < 1: output_message( msg="Bucket resolution steps need to be greater than 0", @@ -775,12 +962,12 @@ def train_model( if not os.path.exists(output_dir): os.makedirs(output_dir) - if stop_text_encoder_training_pct > 0: + if stop_text_encoder_training > 0: output_message( msg='Output "stop text encoder training" is not yet supported. 
Ignoring', headless=headless, ) - stop_text_encoder_training_pct = 0 + stop_text_encoder_training = 0 if not print_only and check_if_model_exist( output_name, output_dir, save_model_as, headless=headless @@ -799,11 +986,11 @@ def train_model( ) if max_train_steps > 0: # calculate stop encoder training - if stop_text_encoder_training_pct == 0: + if stop_text_encoder_training == 0: stop_text_encoder_training = 0 else: stop_text_encoder_training = math.ceil( - float(max_train_steps) / 100 * int(stop_text_encoder_training_pct) + float(max_train_steps) / 100 * int(stop_text_encoder_training) ) if lr_warmup != 0: @@ -874,11 +1061,11 @@ def train_model( reg_factor = 1 else: log.warning( - "Regularisation images are used... Will double the number of steps required..." + "Regularization images are used... Will double the number of steps required..." ) reg_factor = 2 - log.info(f"Regulatization factor: {reg_factor}") + log.info(f"Regularization factor: {reg_factor}") if max_train_steps == 0: # calculate max_train_steps @@ -899,19 +1086,22 @@ def train_model( max_train_steps_info = f"Max train steps: {max_train_steps}" # calculate stop encoder training - if stop_text_encoder_training_pct == 0: + if stop_text_encoder_training == 0: stop_text_encoder_training = 0 else: stop_text_encoder_training = math.ceil( - float(max_train_steps) / 100 * int(stop_text_encoder_training_pct) + float(max_train_steps) / 100 * int(stop_text_encoder_training) ) - if lr_warmup != 0: - lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) - else: - lr_warmup_steps = 0 - - log.info(f"Total steps: {total_steps}") + # Calculate lr_warmup_steps + if lr_warmup_steps > 0: + lr_warmup_steps = int(lr_warmup_steps) + if lr_warmup > 0: + log.warning("Both lr_warmup and lr_warmup_steps are set. 
lr_warmup_steps will be used.") + elif lr_warmup != 0: + lr_warmup_steps = lr_warmup / 100 + else: + lr_warmup_steps = 0 log.info(f"Train batch size: {train_batch_size}") log.info(f"Gradient accumulation steps: {gradient_accumulation_steps}") @@ -925,7 +1115,7 @@ def train_model( log.error("accelerate not found") return TRAIN_BUTTON_VISIBLE - run_cmd = [rf'{accelerate_path}', "launch"] + run_cmd = [rf"{accelerate_path}", "launch"] run_cmd = AccelerateLaunch.run_cmd( run_cmd=run_cmd, @@ -945,6 +1135,8 @@ def train_model( if sdxl: run_cmd.append(rf"{scriptdir}/sd-scripts/sdxl_train_network.py") + elif flux1_checkbox: + run_cmd.append(rf"{scriptdir}/sd-scripts/flux_train_network.py") else: run_cmd.append(rf"{scriptdir}/sd-scripts/train_network.py") @@ -952,11 +1144,11 @@ def train_model( if LoRA_type == "LyCORIS/BOFT": network_module = "lycoris.kohya" - network_args = f" preset={LyCORIS_preset} conv_dim={conv_dim} conv_alpha={conv_alpha} module_dropout={module_dropout} use_tucker={use_tucker} use_scalar={use_scalar} rank_dropout={rank_dropout} rank_dropout_scale={rank_dropout_scale} constrain={constrain} rescaled={rescaled} algo=boft train_norm={train_norm}" + network_args = f" preset={LyCORIS_preset} conv_dim={conv_dim} conv_alpha={conv_alpha} module_dropout={module_dropout} use_tucker={use_tucker} rank_dropout={rank_dropout} rank_dropout_scale={rank_dropout_scale} algo=boft train_norm={train_norm}" if LoRA_type == "LyCORIS/Diag-OFT": network_module = "lycoris.kohya" - network_args = f" preset={LyCORIS_preset} conv_dim={conv_dim} conv_alpha={conv_alpha} module_dropout={module_dropout} use_tucker={use_tucker} use_scalar={use_scalar} rank_dropout={rank_dropout} rank_dropout_scale={rank_dropout_scale} constrain={constrain} rescaled={rescaled} algo=diag-oft train_norm={train_norm}" + network_args = f" preset={LyCORIS_preset} conv_dim={conv_dim} conv_alpha={conv_alpha} module_dropout={module_dropout} use_tucker={use_tucker} rank_dropout={rank_dropout} 
rank_dropout_scale={rank_dropout_scale} constraint={constrain} rescaled={rescaled} algo=diag-oft train_norm={train_norm}" if LoRA_type == "LyCORIS/DyLoRA": network_module = "lycoris.kohya" @@ -964,7 +1156,7 @@ def train_model( if LoRA_type == "LyCORIS/GLoRA": network_module = "lycoris.kohya" - network_args = f' preset={LyCORIS_preset} conv_dim={conv_dim} conv_alpha={conv_alpha} rank_dropout={rank_dropout} module_dropout={module_dropout} rank_dropout_scale={rank_dropout_scale} algo="glora" train_norm={train_norm}' + network_args = f' preset={LyCORIS_preset} conv_dim={conv_dim} conv_alpha={conv_alpha} use_tucker={use_tucker} rank_dropout={rank_dropout} module_dropout={module_dropout} rank_dropout_scale={rank_dropout_scale} algo="glora" train_norm={train_norm}' if LoRA_type == "LyCORIS/iA3": network_module = "lycoris.kohya" @@ -972,19 +1164,83 @@ def train_model( if LoRA_type == "LoCon" or LoRA_type == "LyCORIS/LoCon": network_module = "lycoris.kohya" - network_args = f" preset={LyCORIS_preset} conv_dim={conv_dim} conv_alpha={conv_alpha} rank_dropout={rank_dropout} bypass_mode={bypass_mode} dora_wd={dora_wd} module_dropout={module_dropout} use_tucker={use_tucker} use_scalar={use_scalar} rank_dropout_scale={rank_dropout_scale} algo=locon train_norm={train_norm}" + network_args = f" preset={LyCORIS_preset} conv_dim={conv_dim} conv_alpha={conv_alpha} use_tucker={use_tucker} rank_dropout={rank_dropout} bypass_mode={bypass_mode} dora_wd={dora_wd} module_dropout={module_dropout} use_tucker={use_tucker} use_scalar={use_scalar} rank_dropout_scale={rank_dropout_scale} algo=locon train_norm={train_norm}" if LoRA_type == "LyCORIS/LoHa": network_module = "lycoris.kohya" - network_args = f' preset={LyCORIS_preset} conv_dim={conv_dim} conv_alpha={conv_alpha} rank_dropout={rank_dropout} bypass_mode={bypass_mode} dora_wd={dora_wd} module_dropout={module_dropout} use_tucker={use_tucker} use_scalar={use_scalar} rank_dropout_scale={rank_dropout_scale} algo="loha" 
train_norm={train_norm}' + network_args = f' preset={LyCORIS_preset} conv_dim={conv_dim} conv_alpha={conv_alpha} use_tucker={use_tucker} rank_dropout={rank_dropout} bypass_mode={bypass_mode} dora_wd={dora_wd} module_dropout={module_dropout} use_tucker={use_tucker} use_scalar={use_scalar} rank_dropout_scale={rank_dropout_scale} algo=loha train_norm={train_norm}' if LoRA_type == "LyCORIS/LoKr": network_module = "lycoris.kohya" - network_args = f" preset={LyCORIS_preset} conv_dim={conv_dim} conv_alpha={conv_alpha} rank_dropout={rank_dropout} bypass_mode={bypass_mode} dora_wd={dora_wd} module_dropout={module_dropout} factor={factor} use_cp={use_cp} use_scalar={use_scalar} decompose_both={decompose_both} rank_dropout_scale={rank_dropout_scale} algo=lokr train_norm={train_norm}" + network_args = f" preset={LyCORIS_preset} conv_dim={conv_dim} conv_alpha={conv_alpha} use_tucker={use_tucker} rank_dropout={rank_dropout} bypass_mode={bypass_mode} dora_wd={dora_wd} module_dropout={module_dropout} factor={factor} use_cp={use_cp} use_scalar={use_scalar} decompose_both={decompose_both} rank_dropout_scale={rank_dropout_scale} algo=lokr train_norm={train_norm}" if LoRA_type == "LyCORIS/Native Fine-Tuning": network_module = "lycoris.kohya" - network_args = f" preset={LyCORIS_preset} rank_dropout={rank_dropout} module_dropout={module_dropout} use_tucker={use_tucker} use_scalar={use_scalar} rank_dropout_scale={rank_dropout_scale} algo=full train_norm={train_norm}" + network_args = f" preset={LyCORIS_preset} rank_dropout={rank_dropout} module_dropout={module_dropout} rank_dropout_scale={rank_dropout_scale} algo=full train_norm={train_norm}" + + if LoRA_type == "Flux1": + # Add a list of supported network arguments for Flux1 below when supported + kohya_lora_var_list = [ + "img_attn_dim", + "img_mlp_dim", + "img_mod_dim", + "single_dim", + "txt_attn_dim", + "txt_mlp_dim", + "txt_mod_dim", + "single_mod_dim", + "in_dims", + "train_double_block_indices", + "train_single_block_indices", + 
] + network_module = "networks.lora_flux" + kohya_lora_vars = { + key: value + for key, value in vars().items() + if key in kohya_lora_var_list and value + } + if split_mode: + if train_blocks != "single": + log.warning( + f"train_blocks is currently set to '{train_blocks}'. split_mode is enabled, forcing train_blocks to 'single'." + ) + kohya_lora_vars["train_blocks"] = "single" + + if split_qkv: + kohya_lora_vars["split_qkv"] = True + if train_t5xxl: + kohya_lora_vars["train_t5xxl"] = True + + for key, value in kohya_lora_vars.items(): + if value: + network_args += f" {key}={value}" + + if LoRA_type == "Flux1 OFT": + # Add a list of supported network arguments for Flux1 OFT below when supported + kohya_lora_var_list = [ + "enable_all_linear", + ] + network_module = "networks.oft_flux" + kohya_lora_vars = { + key: value + for key, value in vars().items() + if key in kohya_lora_var_list and value + } + # if split_mode: + # if train_blocks != "single": + # log.warning( + # f"train_blocks is currently set to '{train_blocks}'. split_mode is enabled, forcing train_blocks to 'single'." + # ) + # kohya_lora_vars["train_blocks"] = "single" + + # if split_qkv: + # kohya_lora_vars["split_qkv"] = True + # if train_t5xxl: + # kohya_lora_vars["train_t5xxl"] = True + + for key, value in kohya_lora_vars.items(): + if value: + network_args += f" {key}={value}" if LoRA_type in ["Kohya LoCon", "Standard"]: kohya_lora_var_list = [ @@ -1005,7 +1261,9 @@ def train_model( for key, value in vars().items() if key in kohya_lora_var_list and value } - if LoRA_type == "Kohya LoCon": + + # Not sure if Flux1 is Standard... or LoCon style... flip a coin... going for LoCon style... 
+ if LoRA_type in ["Kohya LoCon"]: network_args += f' conv_dim="{conv_dim}" conv_alpha="{conv_alpha}"' for key, value in kohya_lora_vars.items(): @@ -1071,6 +1329,20 @@ def train_model( if value: network_args += f" {key}={value}" + # Set the text_encoder_lr to multiple values if both text_encoder_lr and t5xxl_lr are set + if text_encoder_lr == 0 and t5xxl_lr > 0: + log.error("When specifying a T5XXL learning rate, the text encoder learning rate needs to be greater than 0.") + return TRAIN_BUTTON_VISIBLE + + text_encoder_lr_list = [] + + if text_encoder_lr > 0 and t5xxl_lr > 0: + # Set the text_encoder_lr to a combination of text_encoder_lr and t5xxl_lr + text_encoder_lr_list = [float(text_encoder_lr), float(t5xxl_lr)] + elif text_encoder_lr > 0: + # Only text_encoder_lr is set: use it for both the CLIP-L and T5XXL entries + text_encoder_lr_list = [float(text_encoder_lr), float(text_encoder_lr)] + # Convert learning rates to float once and store the result for re-use learning_rate = float(learning_rate) if learning_rate is not None else 0.0 text_encoder_lr_float = ( @@ -1088,9 +1360,14 @@ def train_model( # Flag to train unet only if its learning rate is non-zero and text encoder's is zero. 
network_train_unet_only = text_encoder_lr_float == 0 and unet_lr_float != 0 + if text_encoder_lr_float != 0 or unet_lr_float != 0: + do_not_set_learning_rate = True + config_toml_data = { "adaptive_noise_scale": ( - adaptive_noise_scale if adaptive_noise_scale != 0 else None + adaptive_noise_scale + if (adaptive_noise_scale != 0 and noise_offset_type == "Original") + else None ), "async_upload": async_upload, "bucket_no_upscale": bucket_no_upscale, @@ -1098,7 +1375,10 @@ def train_model( "cache_latents": cache_latents, "cache_latents_to_disk": cache_latents_to_disk, "cache_text_encoder_outputs": ( - True if sdxl and sdxl_cache_text_encoder_outputs else None + True + if (sdxl and sdxl_cache_text_encoder_outputs) + or (flux1_checkbox and flux1_cache_text_encoder_outputs) + else None ), "caption_dropout_every_n_epochs": int(caption_dropout_every_n_epochs), "caption_dropout_rate": caption_dropout_rate, @@ -1113,10 +1393,12 @@ def train_model( "epoch": int(epoch), "flip_aug": flip_aug, "fp8_base": fp8_base, + "fp8_base_unet": fp8_base_unet if flux1_checkbox else None, "full_bf16": full_bf16, "full_fp16": full_fp16, "gradient_accumulation_steps": int(gradient_accumulation_steps), "gradient_checkpointing": gradient_checkpointing, + "highvram": highvram, "huber_c": huber_c, "huber_schedule": huber_schedule, "huggingface_repo_id": huggingface_repo_id, @@ -1127,11 +1409,18 @@ def train_model( "ip_noise_gamma": ip_noise_gamma if ip_noise_gamma != 0 else None, "ip_noise_gamma_random_strength": ip_noise_gamma_random_strength, "keep_tokens": int(keep_tokens), - "learning_rate": learning_rate, + "learning_rate": None if do_not_set_learning_rate else learning_rate, "logging_dir": logging_dir, + "log_config": log_config, "log_tracker_name": log_tracker_name, "log_tracker_config": log_tracker_config, + "loraplus_lr_ratio": loraplus_lr_ratio if loraplus_lr_ratio != 0 else None, + "loraplus_text_encoder_lr_ratio": ( + loraplus_text_encoder_lr_ratio if loraplus_text_encoder_lr_ratio != 0 else None + ), + "loraplus_unet_lr_ratio": 
loraplus_unet_lr_ratio if loraplus_unet_lr_ratio != 0 else None, "loss_type": loss_type, + "lowvram": lowvram, "lr_scheduler": lr_scheduler, "lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(), "lr_scheduler_num_cycles": ( @@ -1140,12 +1429,13 @@ def train_model( else int(epoch) ), "lr_scheduler_power": lr_scheduler_power, + "lr_scheduler_type": lr_scheduler_type if lr_scheduler_type != "" else None, "lr_warmup_steps": lr_warmup_steps, "masked_loss": masked_loss, "max_bucket_reso": max_bucket_reso, "max_grad_norm": max_grad_norm, "max_timestep": max_timestep if max_timestep != 0 else None, - "max_token_length": int(max_token_length), + "max_token_length": int(max_token_length) if not flux1_checkbox else None, "max_train_epochs": ( int(max_train_epochs) if int(max_train_epochs) != 0 else None ), @@ -1160,9 +1450,13 @@ def train_model( "min_snr_gamma": min_snr_gamma if min_snr_gamma != 0 else None, "min_timestep": min_timestep if min_timestep != 0 else None, "mixed_precision": mixed_precision, - "multires_noise_discount": multires_noise_discount, + "multires_noise_discount": ( + multires_noise_discount if noise_offset_type == "Multires" else None + ), "multires_noise_iterations": ( - multires_noise_iterations if multires_noise_iterations != 0 else None + multires_noise_iterations + if (multires_noise_iterations != 0 and noise_offset_type == "Multires") + else None ), "network_alpha": network_alpha, "network_args": str(network_args).replace('"', "").split(), @@ -1173,11 +1467,21 @@ def train_model( "network_train_text_encoder_only": network_train_text_encoder_only, "network_weights": network_weights, "no_half_vae": True if sdxl and sdxl_no_half_vae else None, - "noise_offset": noise_offset if noise_offset != 0 else None, - "noise_offset_random_strength": noise_offset_random_strength, + "noise_offset": ( + noise_offset + if (noise_offset != 0 and noise_offset_type == "Original") + else None + ), + "noise_offset_random_strength": ( + noise_offset_random_strength if 
noise_offset_type == "Original" else None + ), "noise_offset_type": noise_offset_type, "optimizer_type": optimizer, - "optimizer_args": str(optimizer_args).replace('"', "").split(), + "optimizer_args": ( + str(optimizer_args).replace('"', "").split() + if optimizer_args != [] + else None + ), "output_dir": output_dir, "output_name": output_name, "persistent_data_loader_workers": int(persistent_data_loader_workers), @@ -1204,6 +1508,10 @@ def train_model( "save_last_n_steps_state": ( save_last_n_steps_state if save_last_n_steps_state != 0 else None ), + "save_last_n_epochs": save_last_n_epochs if save_last_n_epochs != 0 else None, + "save_last_n_epochs_state": ( + save_last_n_epochs_state if save_last_n_epochs_state != 0 else None + ), "save_model_as": save_model_as, "save_precision": save_precision, "save_state": save_state, @@ -1214,10 +1522,11 @@ def train_model( "sdpa": True if xformers == "sdpa" else None, "seed": int(seed) if int(seed) != 0 else None, "shuffle_caption": shuffle_caption, + "skip_cache_check": skip_cache_check, "stop_text_encoder_training": ( stop_text_encoder_training if stop_text_encoder_training != 0 else None ), - "text_encoder_lr": text_encoder_lr if not 0 else None, + "text_encoder_lr": text_encoder_lr_list if text_encoder_lr_list else None, "train_batch_size": train_batch_size, "train_data_dir": train_data_dir, "training_comment": training_comment, @@ -1229,9 +1538,26 @@ def train_model( "vae": vae, "vae_batch_size": vae_batch_size if vae_batch_size != 0 else None, "wandb_api_key": wandb_api_key, - "wandb_run_name": wandb_run_name, + "wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name, "weighted_captions": weighted_captions, "xformers": True if xformers == "xformers" else None, + # Flux.1 specific parameters + # "cache_text_encoder_outputs": see previous assignment above for code + "cache_text_encoder_outputs_to_disk": ( + flux1_cache_text_encoder_outputs_to_disk if flux1_checkbox else None + ), + "ae": ae if flux1_checkbox else 
None, + "clip_l": clip_l if flux1_checkbox else None, + "t5xxl": t5xxl if flux1_checkbox else None, + "discrete_flow_shift": float(discrete_flow_shift) if flux1_checkbox else None, + "model_prediction_type": model_prediction_type if flux1_checkbox else None, + "timestep_sampling": timestep_sampling if flux1_checkbox else None, + "split_mode": split_mode if flux1_checkbox else None, + "t5xxl_max_token_length": int(t5xxl_max_token_length) if flux1_checkbox else None, + "guidance_scale": float(guidance_scale) if flux1_checkbox else None, + "mem_eff_save": mem_eff_save if flux1_checkbox else None, + "apply_t5_attn_mask": apply_t5_attn_mask if flux1_checkbox else None, + "cpu_offload_checkpointing": cpu_offload_checkpointing if flux1_checkbox else None, } # Given dictionary `config_toml_data` @@ -1249,7 +1575,7 @@ def train_model( current_datetime = datetime.now() formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") - tmpfilename = fr"{output_dir}/config_lora-{formatted_datetime}.toml" + tmpfilename = rf"{output_dir}/config_lora-{formatted_datetime}.toml" # Save the updated TOML data back to the file with open(tmpfilename, "w", encoding="utf-8") as toml_file: @@ -1340,12 +1666,12 @@ def lora_tab( config=config, ) + with gr.Accordion("Folders", open=True), gr.Group(): + folders = Folders(headless=headless, config=config) + with gr.Accordion("Metadata", open=False), gr.Group(): metadata = MetaData(config=config) - with gr.Accordion("Folders", open=False), gr.Group(): - folders = Folders(headless=headless, config=config) - with gr.Accordion("Dataset Preparation", open=False): gr.Markdown( "This section provide Dreambooth tools to help setup your dataset..." 
@@ -1381,246 +1707,279 @@ def list_presets(path): json_files.append(os.path.join("user_presets", preset_name)) return json_files - + training_preset = gr.Dropdown( label="Presets", choices=["none"] + list_presets(rf"{presets_dir}/lora"), - # elem_id="myDropdown", value="none", + elem_classes=["preset_background"], ) - with gr.Accordion("Basic", open="True"): - with gr.Group(elem_id="basic_tab"): - with gr.Row(): - LoRA_type = gr.Dropdown( - label="LoRA type", - choices=[ - "Kohya DyLoRA", - "Kohya LoCon", - "LoRA-FA", - "LyCORIS/iA3", - "LyCORIS/BOFT", - "LyCORIS/Diag-OFT", - "LyCORIS/DyLoRA", - "LyCORIS/GLoRA", - "LyCORIS/LoCon", - "LyCORIS/LoHa", - "LyCORIS/LoKr", - "LyCORIS/Native Fine-Tuning", - "Standard", - ], - value="Standard", - ) - LyCORIS_preset = gr.Dropdown( - label="LyCORIS Preset", - choices=LYCORIS_PRESETS_CHOICES, - value="full", - visible=False, - interactive=True, - allow_custom_value=True, - # info="https://github.com/KohakuBlueleaf/LyCORIS/blob/0006e2ffa05a48d8818112d9f70da74c0cd30b99/docs/Preset.md" - ) - with gr.Group(): - with gr.Row(): - network_weights = gr.Textbox( - label="Network weights", - placeholder="(Optional)", - info="Path to an existing LoRA network weights to resume training from", - ) - network_weights_file = gr.Button( - document_symbol, - elem_id="open_folder_small", - elem_classes=["tool"], - visible=(not headless), - ) - network_weights_file.click( - get_any_file_path, - inputs=[network_weights], - outputs=network_weights, - show_progress=False, - ) - dim_from_weights = gr.Checkbox( - label="DIM from weights", - value=False, - info="Automatically determine the dim(rank) from the weight file.", - ) - basic_training = BasicTraining( - learning_rate_value=0.0001, - lr_scheduler_value="cosine", - lr_warmup_value=10, - sdxl_checkbox=source_model.sdxl_checkbox, - config=config, + with gr.Accordion("Basic", open="True", elem_classes=["basic_background"]): + with gr.Row(): + LoRA_type = gr.Dropdown( + label="LoRA type", + 
choices=[ + "Flux1", + "Flux1 OFT", + "Kohya DyLoRA", + "Kohya LoCon", + "LoRA-FA", + "LyCORIS/iA3", + "LyCORIS/BOFT", + "LyCORIS/Diag-OFT", + "LyCORIS/DyLoRA", + "LyCORIS/GLoRA", + "LyCORIS/LoCon", + "LyCORIS/LoHa", + "LyCORIS/LoKr", + "LyCORIS/Native Fine-Tuning", + "Standard", + ], + value="Standard", ) - - with gr.Row(): - text_encoder_lr = gr.Number( - label="Text Encoder learning rate", - value=0.0001, - info="(Optional)", - minimum=0, - maximum=1, - ) - - unet_lr = gr.Number( - label="Unet learning rate", - value=0.0001, - info="(Optional)", - minimum=0, - maximum=1, - ) - - # Add SDXL Parameters - sdxl_params = SDXLParameters( - source_model.sdxl_checkbox, config=config + LyCORIS_preset = gr.Dropdown( + label="LyCORIS Preset", + choices=LYCORIS_PRESETS_CHOICES, + value="full", + visible=False, + interactive=True, + allow_custom_value=True, + info="Use path_to_config_file.toml to choose config file (for LyCORIS module settings)" ) - - # LyCORIS Specific parameters - with gr.Accordion("LyCORIS", visible=False) as lycoris_accordion: - with gr.Row(): - factor = gr.Slider( - label="LoKr factor", - value=-1, - minimum=-1, - maximum=64, - step=1, - visible=False, - ) - bypass_mode = gr.Checkbox( - value=False, - label="Bypass mode", - info="Designed for bnb 8bit/4bit linear layer. 
(QLyCORIS)", - visible=False, - ) - dora_wd = gr.Checkbox( - value=False, - label="DoRA Weight Decompose", - info="Enable the DoRA method for these algorithms", - visible=False, - ) - use_cp = gr.Checkbox( - value=False, - label="Use CP decomposition", - info="A two-step approach utilizing tensor decomposition and fine-tuning to accelerate convolution layers in large neural networks, resulting in significant CPU speedups with minor accuracy drops.", - visible=False, - ) - use_tucker = gr.Checkbox( - value=False, - label="Use Tucker decomposition", - info="Efficiently decompose tensor shapes, resulting in a sequence of convolution layers with varying dimensions and Hadamard product implementation through multiplication of two distinct tensors.", - visible=False, - ) - use_scalar = gr.Checkbox( - value=False, - label="Use Scalar", - info="Train an additional scalar in front of the weight difference, use a different weight initialization strategy.", - visible=False, - ) + with gr.Group(): with gr.Row(): - rank_dropout_scale = gr.Checkbox( - value=False, - label="Rank Dropout Scale", - info="Adjusts the scale of the rank dropout to maintain the average dropout rate, ensuring more consistent regularization across different layers.", - visible=False, + network_weights = gr.Textbox( + label="Network weights", + placeholder="(Optional)", + info="Path to an existing LoRA network weights to resume training from", ) - constrain = gr.Number( - value=0.0, - label="Constrain OFT", - info="Limits the norm of the oft_blocks, ensuring that their magnitude does not exceed a specified threshold, thus controlling the extent of the transformation applied.", - visible=False, - ) - rescaled = gr.Checkbox( - value=False, - label="Rescaled OFT", - info="applies an additional scaling factor to the oft_blocks, allowing for further adjustment of their impact on the model's transformations.", - visible=False, + network_weights_file = gr.Button( + document_symbol, + elem_id="open_folder_small", 
+ elem_classes=["tool"], + visible=(not headless), ) - train_norm = gr.Checkbox( - value=False, - label="Train Norm", - info="Selects trainable layers in a network, but trains normalization layers identically across methods as they lack matrix decomposition.", - visible=False, + network_weights_file.click( + get_any_file_path, + inputs=[network_weights], + outputs=network_weights, + show_progress=False, ) - decompose_both = gr.Checkbox( + dim_from_weights = gr.Checkbox( + label="DIM from weights", value=False, - label="LoKr decompose both", - info="Controls whether both input and output dimensions of the layer's weights are decomposed into smaller matrices for reparameterization.", - visible=False, - ) - train_on_input = gr.Checkbox( - value=True, - label="iA3 train on input", - info="Set if we change the information going into the system (True) or the information coming out of it (False).", - visible=False, + info="Automatically determine the dim(rank) from the weight file.", ) - with gr.Row() as network_row: - network_dim = gr.Slider( - minimum=1, - maximum=512, - label="Network Rank (Dimension)", - value=8, + basic_training = BasicTraining( + learning_rate_value=0.0001, + lr_scheduler_value="cosine", + lr_warmup_value=10, + sdxl_checkbox=source_model.sdxl_checkbox, + config=config, + ) + + with gr.Row(): + text_encoder_lr = gr.Number( + label="Text Encoder learning rate", + value=0, + info="(Optional) Set CLIP-L and T5XXL learning rates.", + minimum=0, + maximum=1, + ) + + t5xxl_lr = gr.Number( + label="T5XXL learning rate", + value=0, + info="(Optional) Override the T5XXL learning rate set by the Text Encoder learning rate if you desire a different one.", + minimum=0, + maximum=1, + ) + + unet_lr = gr.Number( + label="Unet learning rate", + value=0.0001, + info="(Optional)", + minimum=0, + maximum=1, + ) + + with gr.Row() as loraplus: + loraplus_lr_ratio = gr.Number( + label="LoRA+ learning rate ratio", + value=0, + info="(Optional) starting with 16 is 
suggested", + minimum=0, + maximum=128, + ) + + loraplus_unet_lr_ratio = gr.Number( + label="LoRA+ Unet learning rate ratio", + value=0, + info="(Optional) starting with 16 is suggested", + minimum=0, + maximum=128, + ) + + loraplus_text_encoder_lr_ratio = gr.Number( + label="LoRA+ Text Encoder learning rate ratio", + value=0, + info="(Optional) starting with 16 is suggested", + minimum=0, + maximum=128, + ) + # Add SDXL Parameters + sdxl_params = SDXLParameters( + source_model.sdxl_checkbox, config=config + ) + + # LyCORIS Specific parameters + with gr.Accordion("LyCORIS", visible=False) as lycoris_accordion: + with gr.Row(): + factor = gr.Slider( + label="LoKr factor", + value=-1, + minimum=-1, + maximum=64, step=1, - interactive=True, + visible=False, ) - network_alpha = gr.Slider( - minimum=0.00001, - maximum=1024, - label="Network Alpha", - value=1, - step=0.00001, - interactive=True, - info="alpha for LoRA weight scaling", + bypass_mode = gr.Checkbox( + value=False, + label="Bypass mode", + info="Designed for bnb 8bit/4bit linear layer. 
(QLyCORIS)", + visible=False, ) - with gr.Row(visible=False) as convolution_row: - # locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA (does not support v2 base models) (may not be able to some utilities now)', value=False) - conv_dim = gr.Slider( - minimum=0, - maximum=512, - value=1, - step=1, - label="Convolution Rank (Dimension)", + dora_wd = gr.Checkbox( + value=False, + label="DoRA Weight Decompose", + info="Enable the DoRA method for these algorithms", + visible=False, ) - conv_alpha = gr.Slider( - minimum=0, - maximum=512, - value=1, - step=1, - label="Convolution Alpha", + use_cp = gr.Checkbox( + value=False, + label="Use CP decomposition", + info="A two-step approach utilizing tensor decomposition and fine-tuning to accelerate convolution layers in large neural networks, resulting in significant CPU speedups with minor accuracy drops.", + visible=False, ) - with gr.Row(): - scale_weight_norms = gr.Slider( - label="Scale weight norms", - value=0, - minimum=0, - maximum=10, - step=0.01, - info="Max Norm Regularization is a technique to stabilize network training by limiting the norm of network weights. It may be effective in suppressing overfitting of LoRA and improving stability when used with other LoRAs. See PR #545 on kohya_ss/sd_scripts repo for details. Recommended setting: 1. Higher is weaker, lower is stronger.", - interactive=True, + use_tucker = gr.Checkbox( + value=False, + label="Use Tucker decomposition", + info="Efficiently decompose tensor shapes, resulting in a sequence of convolution layers with varying dimensions and Hadamard product implementation through multiplication of two distinct tensors.", + visible=False, ) - network_dropout = gr.Slider( - label="Network dropout", - value=0, - minimum=0, - maximum=1, - step=0.01, - info="Is a normal probability dropout at the neuron level. In the case of LoRA, it is applied to the output of down. 
Recommended range 0.1 to 0.5", + use_scalar = gr.Checkbox( + value=False, + label="Use Scalar", + info="Train an additional scalar in front of the weight difference, use a different weight initialization strategy.", + visible=False, ) - rank_dropout = gr.Slider( - label="Rank dropout", - value=0, - minimum=0, - maximum=1, - step=0.01, - info="can specify `rank_dropout` to dropout each rank with specified probability. Recommended range 0.1 to 0.3", + with gr.Row(): + rank_dropout_scale = gr.Checkbox( + value=False, + label="Rank Dropout Scale", + info="Adjusts the scale of the rank dropout to maintain the average dropout rate, ensuring more consistent regularization across different layers.", + visible=False, ) - module_dropout = gr.Slider( - label="Module dropout", + constrain = gr.Number( value=0.0, - minimum=0.0, - maximum=1.0, - step=0.01, - info="can specify `module_dropout` to dropout each rank with specified probability. Recommended range 0.1 to 0.3", + label="Constrain OFT", + info="Limits the norm of the oft_blocks, ensuring that their magnitude does not exceed a specified threshold, thus controlling the extent of the transformation applied.", + visible=False, + ) + rescaled = gr.Checkbox( + value=False, + label="Rescaled OFT", + info="applies an additional scaling factor to the oft_blocks, allowing for further adjustment of their impact on the model's transformations.", + visible=False, + ) + train_norm = gr.Checkbox( + value=False, + label="Train Norm", + info="Selects trainable layers in a network, but trains normalization layers identically across methods as they lack matrix decomposition.", + visible=False, + ) + decompose_both = gr.Checkbox( + value=False, + label="LoKr decompose both", + info="Controls whether both input and output dimensions of the layer's weights are decomposed into smaller matrices for reparameterization.", + visible=False, + ) + train_on_input = gr.Checkbox( + value=True, + label="iA3 train on input", + info="Set if we change the 
information going into the system (True) or the information coming out of it (False).", + visible=False, ) - with gr.Row(visible=False): + with gr.Row() as network_row: + network_dim = gr.Slider( + minimum=1, + maximum=512, + label="Network Rank (Dimension)", + value=8, + step=1, + interactive=True, + ) + network_alpha = gr.Slider( + minimum=0.00001, + maximum=1024, + label="Network Alpha", + value=1, + step=0.00001, + interactive=True, + info="alpha for LoRA weight scaling", + ) + with gr.Row(visible=False) as convolution_row: + # locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA (does not support v2 base models) (may not be able to some utilities now)', value=False) + conv_dim = gr.Slider( + minimum=0, + maximum=512, + value=1, + step=1, + label="Convolution Rank (Dimension)", + ) + conv_alpha = gr.Slider( + minimum=0, + maximum=512, + value=1, + step=1, + label="Convolution Alpha", + ) + with gr.Row(): + scale_weight_norms = gr.Slider( + label="Scale weight norms", + value=0, + minimum=0, + maximum=10, + step=0.01, + info="Max Norm Regularization is a technique to stabilize network training by limiting the norm of network weights. It may be effective in suppressing overfitting of LoRA and improving stability when used with other LoRAs. See PR #545 on kohya_ss/sd_scripts repo for details. Recommended setting: 1. Higher is weaker, lower is stronger.", + interactive=True, + ) + network_dropout = gr.Slider( + label="Network dropout", + value=0, + minimum=0, + maximum=1, + step=0.01, + info="Is a normal probability dropout at the neuron level. In the case of LoRA, it is applied to the output of down. Recommended range 0.1 to 0.5", + ) + rank_dropout = gr.Slider( + label="Rank dropout", + value=0, + minimum=0, + maximum=1, + step=0.01, + info="can specify `rank_dropout` to dropout each rank with specified probability. 
Recommended range 0.1 to 0.3", + ) + module_dropout = gr.Slider( + label="Module dropout", + value=0.0, + minimum=0.0, + maximum=1.0, + step=0.01, + info="can specify `module_dropout` to dropout each rank with specified probability. Recommended range 0.1 to 0.3", + ) + with gr.Row(visible=False): unit = gr.Slider( minimum=1, maximum=64, @@ -1644,6 +2003,8 @@ def update_LoRA_settings( "update_params": { "visible": LoRA_type in { + "Flux1", + "Flux1 OFT", "Kohya DyLoRA", "Kohya LoCon", "LoRA-FA", @@ -1682,6 +2043,8 @@ def update_LoRA_settings( "update_params": { "visible": LoRA_type in { + "Flux1", + "Flux1 OFT", "Standard", "Kohya DyLoRA", "Kohya LoCon", @@ -1694,6 +2057,8 @@ def update_LoRA_settings( "update_params": { "visible": LoRA_type in { + "Flux1", + "Flux1 OFT", "Standard", "LoCon", "Kohya DyLoRA", @@ -1714,6 +2079,8 @@ def update_LoRA_settings( "update_params": { "visible": LoRA_type in { + "Flux1", + "Flux1 OFT", "Standard", "LoCon", "Kohya DyLoRA", @@ -1734,6 +2101,8 @@ def update_LoRA_settings( "update_params": { "visible": LoRA_type in { + "Flux1", + "Flux1 OFT", "Standard", "LoCon", "Kohya DyLoRA", @@ -1831,9 +2200,10 @@ def update_LoRA_settings( "LyCORIS/BOFT", "LyCORIS/Diag-OFT", "LyCORIS/DyLoRA", + "LyCORIS/GLoRA", "LyCORIS/LoCon", "LyCORIS/LoHa", - "LyCORIS/Native Fine-Tuning", + "LyCORIS/LoKr", }, }, }, @@ -1842,12 +2212,9 @@ def update_LoRA_settings( "update_params": { "visible": LoRA_type in { - "LyCORIS/BOFT", - "LyCORIS/Diag-OFT", "LyCORIS/LoCon", "LyCORIS/LoHa", "LyCORIS/LoKr", - "LyCORIS/Native Fine-Tuning", }, }, }, @@ -1871,7 +2238,6 @@ def update_LoRA_settings( "update_params": { "visible": LoRA_type in { - "LyCORIS/BOFT", "LyCORIS/Diag-OFT", }, }, @@ -1881,7 +2247,6 @@ def update_LoRA_settings( "update_params": { "visible": LoRA_type in { - "LyCORIS/BOFT", "LyCORIS/Diag-OFT", }, }, @@ -2037,6 +2402,26 @@ def update_LoRA_settings( }, }, }, + "loraplus": { + "gr_type": gr.Row, + "update_params": { + "visible": LoRA_type + in { + "LoCon", 
+ "Kohya DyLoRA", + "LyCORIS/BOFT", + "LyCORIS/Diag-OFT", + "LyCORIS/GLoRA", + "LyCORIS/LoCon", + "LyCORIS/LoHa", + "LyCORIS/LoKr", + "Kohya LoCon", + "LoRA-FA", + "LyCORIS/Native Fine-Tuning", + "Standard", + }, + }, + }, } results = [] @@ -2047,7 +2432,15 @@ def update_LoRA_settings( return tuple(results) - with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"): + # Add FLUX1 Parameters to the basic training accordion + flux1_training = flux1Training( + headless=headless, + config=config, + flux1_checkbox=source_model.flux1_checkbox, + ) + + + with gr.Accordion("Advanced", open=False, elem_classes="advanced_background"): # with gr.Accordion('Advanced Configuration', open=False): with gr.Row(visible=True) as kohya_advanced_lora: with gr.Tab(label="Weights"): @@ -2105,11 +2498,11 @@ def update_LoRA_settings( outputs=[basic_training.cache_latents], ) - with gr.Accordion("Samples", open=False, elem_id="samples_tab"): + with gr.Accordion("Samples", open=False, elem_classes="samples_background"): sample = SampleImages(config=config) global huggingface - with gr.Accordion("HuggingFace", open=False): + with gr.Accordion("HuggingFace", open=False, elem_classes="huggingface_background"): huggingface = HuggingFace(config=config) LoRA_type.change( @@ -2147,6 +2540,7 @@ def update_LoRA_settings( LyCORIS_preset, unit, lycoris_accordion, + loraplus, ], ) @@ -2165,67 +2559,75 @@ def update_LoRA_settings( source_model.v2, source_model.v_parameterization, source_model.sdxl_checkbox, - folders.logging_dir, + source_model.flux1_checkbox, + source_model.dataset_config, + source_model.save_model_as, + source_model.save_precision, source_model.train_data_dir, + source_model.output_name, + source_model.model_list, + source_model.training_comment, + folders.logging_dir, folders.reg_data_dir, folders.output_dir, - source_model.dataset_config, basic_training.max_resolution, basic_training.learning_rate, basic_training.lr_scheduler, basic_training.lr_warmup, + 
basic_training.lr_warmup_steps, basic_training.train_batch_size, basic_training.epoch, basic_training.save_every_n_epochs, - accelerate_launch.mixed_precision, - source_model.save_precision, basic_training.seed, - accelerate_launch.num_cpu_threads_per_process, basic_training.cache_latents, basic_training.cache_latents_to_disk, basic_training.caption_extension, basic_training.enable_bucket, - advanced_training.gradient_checkpointing, - advanced_training.fp8_base, - advanced_training.full_fp16, - # advanced_training.no_token_padding, basic_training.stop_text_encoder_training, basic_training.min_bucket_reso, basic_training.max_bucket_reso, + basic_training.max_train_epochs, + basic_training.max_train_steps, + basic_training.lr_scheduler_num_cycles, + basic_training.lr_scheduler_power, + basic_training.optimizer, + basic_training.optimizer_args, + basic_training.lr_scheduler_args, + basic_training.lr_scheduler_type, + basic_training.max_grad_norm, + accelerate_launch.mixed_precision, + accelerate_launch.num_cpu_threads_per_process, + accelerate_launch.num_processes, + accelerate_launch.num_machines, + accelerate_launch.multi_gpu, + accelerate_launch.gpu_ids, + accelerate_launch.main_process_port, + accelerate_launch.dynamo_backend, + accelerate_launch.dynamo_mode, + accelerate_launch.dynamo_use_fullgraph, + accelerate_launch.dynamo_use_dynamic, + accelerate_launch.extra_accelerate_launch_args, + advanced_training.gradient_checkpointing, + advanced_training.fp8_base, + advanced_training.fp8_base_unet, + advanced_training.full_fp16, + advanced_training.highvram, + advanced_training.lowvram, advanced_training.xformers, - source_model.save_model_as, advanced_training.shuffle_caption, advanced_training.save_state, advanced_training.save_state_on_train_end, advanced_training.resume, advanced_training.prior_loss_weight, - text_encoder_lr, - unet_lr, - network_dim, - network_weights, - dim_from_weights, advanced_training.color_aug, advanced_training.flip_aug, 
advanced_training.masked_loss, advanced_training.clip_skip, - accelerate_launch.num_processes, - accelerate_launch.num_machines, - accelerate_launch.multi_gpu, - accelerate_launch.gpu_ids, - accelerate_launch.main_process_port, advanced_training.gradient_accumulation_steps, advanced_training.mem_eff_attn, - source_model.output_name, - source_model.model_list, advanced_training.max_token_length, - basic_training.max_train_epochs, - basic_training.max_train_steps, advanced_training.max_data_loader_n_workers, - network_alpha, - source_model.training_comment, advanced_training.keep_tokens, - basic_training.lr_scheduler_num_cycles, - basic_training.lr_scheduler_power, advanced_training.persistent_data_loader_workers, advanced_training.bucket_no_upscale, advanced_training.random_crop, @@ -2233,10 +2635,6 @@ def update_LoRA_settings( advanced_training.v_pred_like_loss, advanced_training.caption_dropout_every_n_epochs, advanced_training.caption_dropout_rate, - basic_training.optimizer, - basic_training.optimizer_args, - basic_training.lr_scheduler_args, - basic_training.max_grad_norm, advanced_training.noise_offset_type, advanced_training.noise_offset, advanced_training.noise_offset_random_strength, @@ -2245,6 +2643,40 @@ def update_LoRA_settings( advanced_training.multires_noise_discount, advanced_training.ip_noise_gamma, advanced_training.ip_noise_gamma_random_strength, + advanced_training.additional_parameters, + advanced_training.loss_type, + advanced_training.huber_schedule, + advanced_training.huber_c, + advanced_training.vae_batch_size, + advanced_training.min_snr_gamma, + advanced_training.save_every_n_steps, + advanced_training.save_last_n_steps, + advanced_training.save_last_n_steps_state, + advanced_training.save_last_n_epochs, + advanced_training.save_last_n_epochs_state, + advanced_training.skip_cache_check, + advanced_training.log_with, + advanced_training.wandb_api_key, + advanced_training.wandb_run_name, + advanced_training.log_tracker_name, + 
advanced_training.log_tracker_config, + advanced_training.log_config, + advanced_training.scale_v_pred_loss_like_noise_pred, + advanced_training.full_bf16, + advanced_training.min_timestep, + advanced_training.max_timestep, + advanced_training.vae, + advanced_training.weighted_captions, + advanced_training.debiased_estimation_loss, + sdxl_params.sdxl_cache_text_encoder_outputs, + sdxl_params.sdxl_no_half_vae, + text_encoder_lr, + t5xxl_lr, + unet_lr, + network_dim, + network_weights, + dim_from_weights, + network_alpha, LoRA_type, factor, bypass_mode, @@ -2264,12 +2696,6 @@ def update_LoRA_settings( sample.sample_every_n_epochs, sample.sample_sampler, sample.sample_prompts, - advanced_training.additional_parameters, - advanced_training.loss_type, - advanced_training.huber_schedule, - advanced_training.huber_c, - advanced_training.vae_batch_size, - advanced_training.min_snr_gamma, down_lr_weight, mid_lr_weight, up_lr_weight, @@ -2278,34 +2704,15 @@ def update_LoRA_settings( block_alphas, conv_block_dims, conv_block_alphas, - advanced_training.weighted_captions, unit, - advanced_training.save_every_n_steps, - advanced_training.save_last_n_steps, - advanced_training.save_last_n_steps_state, - advanced_training.log_with, - advanced_training.wandb_api_key, - advanced_training.wandb_run_name, - advanced_training.log_tracker_name, - advanced_training.log_tracker_config, - advanced_training.scale_v_pred_loss_like_noise_pred, scale_weight_norms, network_dropout, rank_dropout, module_dropout, - sdxl_params.sdxl_cache_text_encoder_outputs, - sdxl_params.sdxl_no_half_vae, - advanced_training.full_bf16, - advanced_training.min_timestep, - advanced_training.max_timestep, - advanced_training.vae, - accelerate_launch.dynamo_backend, - accelerate_launch.dynamo_mode, - accelerate_launch.dynamo_use_fullgraph, - accelerate_launch.dynamo_use_dynamic, - accelerate_launch.extra_accelerate_launch_args, LyCORIS_preset, - advanced_training.debiased_estimation_loss, + loraplus_lr_ratio, + 
loraplus_text_encoder_lr_ratio, + loraplus_unet_lr_ratio, huggingface.huggingface_repo_id, huggingface.huggingface_token, huggingface.huggingface_repo_type, @@ -2319,6 +2726,36 @@ def update_LoRA_settings( metadata.metadata_license, metadata.metadata_tags, metadata.metadata_title, + # Flux1 parameters + flux1_training.flux1_cache_text_encoder_outputs, + flux1_training.flux1_cache_text_encoder_outputs_to_disk, + flux1_training.ae, + flux1_training.clip_l, + flux1_training.t5xxl, + flux1_training.discrete_flow_shift, + flux1_training.model_prediction_type, + flux1_training.timestep_sampling, + flux1_training.split_mode, + flux1_training.train_blocks, + flux1_training.t5xxl_max_token_length, + flux1_training.enable_all_linear, + flux1_training.guidance_scale, + flux1_training.mem_eff_save, + flux1_training.apply_t5_attn_mask, + flux1_training.split_qkv, + flux1_training.train_t5xxl, + flux1_training.cpu_offload_checkpointing, + flux1_training.img_attn_dim, + flux1_training.img_mlp_dim, + flux1_training.img_mod_dim, + flux1_training.single_dim, + flux1_training.txt_attn_dim, + flux1_training.txt_mlp_dim, + flux1_training.txt_mod_dim, + flux1_training.single_mod_dim, + flux1_training.in_dims, + flux1_training.train_double_block_indices, + flux1_training.train_single_block_indices, ] configuration.button_open_config.click( diff --git a/kohya_gui/merge_lora_gui.py b/kohya_gui/merge_lora_gui.py index a3337c4cf..72e632124 100644 --- a/kohya_gui/merge_lora_gui.py +++ b/kohya_gui/merge_lora_gui.py @@ -16,6 +16,7 @@ create_refresh_button, setup_environment ) from .custom_logging import setup_logging +from .sd_modeltype import SDModelType # Set up logging log = setup_logging() @@ -145,6 +146,13 @@ def list_save_to(path): show_progress=False, ) + #secondary event on sd_model for auto-detection of SDXL + sd_model.change( + lambda path: gr.Checkbox(value=SDModelType(path).Is_SDXL()), + inputs=sd_model, + outputs=sdxl_model + ) + with gr.Group(), gr.Row(): lora_a_model = 
gr.Dropdown( label='LoRA model "A" (path to the LoRA A model)', diff --git a/kohya_gui/sd_modeltype.py b/kohya_gui/sd_modeltype.py new file mode 100755 index 000000000..bb70150a0 --- /dev/null +++ b/kohya_gui/sd_modeltype.py @@ -0,0 +1,63 @@ +from os.path import isfile +from safetensors import safe_open +import enum + +# methodology is based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/82a973c04367123ae98bd9abdf80d9eda9b910e2/modules/sd_models.py#L379-L403 + + +class ModelType(enum.Enum): + UNKNOWN = 0 + SD1 = 1 + SD2 = 2 + SDXL = 3 + SD3 = 4 + FLUX1 = 5 + + +class SDModelType: + def __init__(self, safetensors_path): + self.model_type = ModelType.UNKNOWN + + if not isfile(safetensors_path): + return + + try: + st = safe_open(filename=safetensors_path, framework="numpy", device="cpu") + + # print(st.keys()) + + def hasKeyPrefix(pfx): + return any(k.startswith(pfx) for k in st.keys()) + + if "model.diffusion_model.x_embedder.proj.weight" in st.keys(): + self.model_type = ModelType.SD3 + elif ( + "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" + in st.keys() + or "double_blocks.0.img_attn.norm.key_norm.scale" in st.keys() + ): + # print("flux1 model detected...") + self.model_type = ModelType.FLUX1 + elif hasKeyPrefix("conditioner."): + self.model_type = ModelType.SDXL + elif hasKeyPrefix("cond_stage_model.model."): + self.model_type = ModelType.SD2 + elif hasKeyPrefix("model."): + self.model_type = ModelType.SD1 + except: + pass + + def Is_SD1(self): + return self.model_type == ModelType.SD1 + + def Is_SD2(self): + return self.model_type == ModelType.SD2 + + def Is_SDXL(self): + return self.model_type == ModelType.SDXL + + def Is_SD3(self): + return self.model_type == ModelType.SD3 + + def Is_FLUX1(self): + return self.model_type == ModelType.FLUX1 diff --git a/kohya_gui/textual_inversion_gui.py b/kohya_gui/textual_inversion_gui.py index e85b47fd0..42249aee4 100644 --- a/kohya_gui/textual_inversion_gui.py +++ 
b/kohya_gui/textual_inversion_gui.py @@ -70,6 +70,7 @@ def save_configuration( learning_rate, lr_scheduler, lr_warmup, + lr_warmup_steps, train_batch_size, epoch, save_every_n_epochs, @@ -135,6 +136,7 @@ def save_configuration( optimizer, optimizer_args, lr_scheduler_args, + lr_scheduler_type, noise_offset_type, noise_offset, noise_offset_random_strength, @@ -156,12 +158,17 @@ def save_configuration( save_every_n_steps, save_last_n_steps, save_last_n_steps_state, + save_last_n_epochs, + save_last_n_epochs_state, + skip_cache_check, log_with, wandb_api_key, wandb_run_name, log_tracker_name, log_tracker_config, + log_config, scale_v_pred_loss_like_noise_pred, + disable_mmap_load_safetensors, min_timestep, max_timestep, sdxl_no_half_vae, @@ -229,6 +236,7 @@ def open_configuration( learning_rate, lr_scheduler, lr_warmup, + lr_warmup_steps, train_batch_size, epoch, save_every_n_epochs, @@ -294,6 +302,7 @@ def open_configuration( optimizer, optimizer_args, lr_scheduler_args, + lr_scheduler_type, noise_offset_type, noise_offset, noise_offset_random_strength, @@ -315,12 +324,17 @@ def open_configuration( save_every_n_steps, save_last_n_steps, save_last_n_steps_state, + save_last_n_epochs, + save_last_n_epochs_state, + skip_cache_check, log_with, wandb_api_key, wandb_run_name, log_tracker_name, log_tracker_config, + log_config, scale_v_pred_loss_like_noise_pred, + disable_mmap_load_safetensors, min_timestep, max_timestep, sdxl_no_half_vae, @@ -381,6 +395,7 @@ def train_model( learning_rate, lr_scheduler, lr_warmup, + lr_warmup_steps, train_batch_size, epoch, save_every_n_epochs, @@ -446,6 +461,7 @@ def train_model( optimizer, optimizer_args, lr_scheduler_args, + lr_scheduler_type, noise_offset_type, noise_offset, noise_offset_random_strength, @@ -467,12 +483,17 @@ def train_model( save_every_n_steps, save_last_n_steps, save_last_n_steps_state, + save_last_n_epochs, + save_last_n_epochs_state, + skip_cache_check, log_with, wandb_api_key, wandb_run_name, log_tracker_name, 
log_tracker_config, + log_config, scale_v_pred_loss_like_noise_pred, + disable_mmap_load_safetensors, min_timestep, max_timestep, sdxl_no_half_vae, @@ -549,20 +570,6 @@ def train_model( # End of path validation # - # if not validate_paths( - # dataset_config=dataset_config, - # headless=headless, - # log_tracker_config=log_tracker_config, - # logging_dir=logging_dir, - # output_dir=output_dir, - # pretrained_model_name_or_path=pretrained_model_name_or_path, - # reg_data_dir=reg_data_dir, - # resume=resume, - # train_data_dir=train_data_dir, - # vae=vae, - # ): - # return TRAIN_BUTTON_VISIBLE - if token_string == "": output_message(msg="Token string is missing", headless=headless) return TRAIN_BUTTON_VISIBLE @@ -588,13 +595,6 @@ def train_model( stop_text_encoder_training = math.ceil( float(max_train_steps) / 100 * int(stop_text_encoder_training_pct) ) - - if lr_warmup != 0: - lr_warmup_steps = round( - float(int(lr_warmup) * int(max_train_steps) / 100) - ) - else: - lr_warmup_steps = 0 else: stop_text_encoder_training = 0 lr_warmup_steps = 0 @@ -657,11 +657,11 @@ def train_model( reg_factor = 1 else: log.warning( - "Regularisation images are used... Will double the number of steps required..." + "Regularization images are used... Will double the number of steps required..." ) reg_factor = 2 - log.info(f"Regulatization factor: {reg_factor}") + log.info(f"Regularization factor: {reg_factor}") if max_train_steps == 0: # calculate max_train_steps @@ -689,13 +689,18 @@ def train_model( float(max_train_steps) / 100 * int(stop_text_encoder_training_pct) ) - if lr_warmup != 0: - lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) - else: - lr_warmup_steps = 0 - log.info(f"Total steps: {total_steps}") + # Calculate lr_warmup_steps + if lr_warmup_steps > 0: + lr_warmup_steps = int(lr_warmup_steps) + if lr_warmup > 0: + log.warning("Both lr_warmup and lr_warmup_steps are set. 
lr_warmup_steps will be used.") + elif lr_warmup != 0: + lr_warmup_steps = lr_warmup / 100 + else: + lr_warmup_steps = 0 + log.info(f"Train batch size: {train_batch_size}") log.info(f"Gradient accumulation steps: {gradient_accumulation_steps}") log.info(f"Epoch: {epoch}") @@ -757,6 +762,7 @@ def train_model( "clip_skip": clip_skip if clip_skip != 0 else None, "color_aug": color_aug, "dataset_config": dataset_config, + "disable_mmap_load_safetensors": disable_mmap_load_safetensors, "dynamo_backend": dynamo_backend, "enable_bucket": enable_bucket, "epoch": int(epoch), @@ -777,6 +783,7 @@ def train_model( "keep_tokens": int(keep_tokens), "learning_rate": learning_rate, "logging_dir": logging_dir, + "log_config": log_config, "log_tracker_name": log_tracker_name, "log_tracker_config": log_tracker_config, "loss_type": loss_type, @@ -786,6 +793,7 @@ def train_model( int(lr_scheduler_num_cycles) if lr_scheduler_num_cycles != "" else int(epoch) ), "lr_scheduler_power": lr_scheduler_power, + "lr_scheduler_type": lr_scheduler_type if lr_scheduler_type != "" else None, "lr_warmup_steps": lr_warmup_steps, "max_bucket_reso": max_bucket_reso, "max_timestep": max_timestep if max_timestep != 0 else None, @@ -840,6 +848,10 @@ def train_model( "save_last_n_steps_state": ( save_last_n_steps_state if save_last_n_steps_state != 0 else None ), + "save_last_n_epochs": save_last_n_epochs if save_last_n_epochs != 0 else None, + "save_last_n_epochs_state": ( + save_last_n_epochs_state if save_last_n_epochs_state != 0 else None + ), "save_model_as": save_model_as, "save_precision": save_precision, "save_state": save_state, @@ -849,6 +861,7 @@ def train_model( "sdpa": True if xformers == "sdpa" else None, "seed": int(seed) if int(seed) != 0 else None, "shuffle_caption": shuffle_caption, + "skip_cache_check": skip_cache_check, "stop_text_encoder_training": ( stop_text_encoder_training if stop_text_encoder_training != 0 else None ), @@ -862,8 +875,8 @@ def train_model( "vae": vae, 
"vae_batch_size": vae_batch_size if vae_batch_size != 0 else None, "wandb_api_key": wandb_api_key, - "wandb_run_name": wandb_run_name, - "weigts": weights, + "wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name, + "weights": weights, "use_object_template": True if template == "object template" else None, "use_style_template": True if template == "style template" else None, "xformers": True if xformers == "xformers" else None, @@ -1130,6 +1143,7 @@ def list_embedding_files(path): basic_training.learning_rate, basic_training.lr_scheduler, basic_training.lr_warmup, + basic_training.lr_warmup_steps, basic_training.train_batch_size, basic_training.epoch, basic_training.save_every_n_epochs, @@ -1194,6 +1208,7 @@ def list_embedding_files(path): basic_training.optimizer, basic_training.optimizer_args, basic_training.lr_scheduler_args, + basic_training.lr_scheduler_type, advanced_training.noise_offset_type, advanced_training.noise_offset, advanced_training.noise_offset_random_strength, @@ -1215,12 +1230,17 @@ def list_embedding_files(path): advanced_training.save_every_n_steps, advanced_training.save_last_n_steps, advanced_training.save_last_n_steps_state, + advanced_training.save_last_n_epochs, + advanced_training.save_last_n_epochs_state, + advanced_training.skip_cache_check, advanced_training.log_with, advanced_training.wandb_api_key, advanced_training.wandb_run_name, advanced_training.log_tracker_name, advanced_training.log_tracker_config, + advanced_training.log_config, advanced_training.scale_v_pred_loss_like_noise_pred, + sdxl_params.disable_mmap_load_safetensors, advanced_training.min_timestep, advanced_training.max_timestep, sdxl_params.sdxl_no_half_vae, diff --git a/presets/dreambooth/sd3_bdsqlsz_v1.json b/presets/dreambooth/sd3_bdsqlsz_v1.json new file mode 100644 index 000000000..22ce46e5f --- /dev/null +++ b/presets/dreambooth/sd3_bdsqlsz_v1.json @@ -0,0 +1,146 @@ +{ + "adaptive_noise_scale": 0, + "additional_parameters": "", + 
"async_upload": false, + "bucket_no_upscale": true, + "bucket_reso_steps": 64, + "cache_latents": true, + "cache_latents_to_disk": true, + "caption_dropout_every_n_epochs": 0, + "caption_dropout_rate": 0, + "caption_extension": ".txt", + "clip_g": "H:/ComfyUI2/models/clip/clip_g.safetensors", + "clip_l": "H:/ComfyUI2/models/clip/clip_l.safetensors", + "clip_skip": 1, + "color_aug": false, + "dataset_config": "", + "debiased_estimation_loss": false, + "disable_mmap_load_safetensors": false, + "dynamo_backend": "no", + "dynamo_mode": "default", + "dynamo_use_dynamic": false, + "dynamo_use_fullgraph": false, + "enable_bucket": true, + "epoch": 8, + "extra_accelerate_launch_args": "", + "flip_aug": false, + "full_bf16": false, + "full_fp16": false, + "fused_backward_pass": false, + "fused_optimizer_groups": 0, + "gpu_ids": "", + "gradient_accumulation_steps": 1, + "gradient_checkpointing": true, + "huber_c": 0.1, + "huber_schedule": "snr", + "huggingface_path_in_repo": "", + "huggingface_repo_id": "", + "huggingface_repo_type": "", + "huggingface_repo_visibility": "", + "huggingface_token": "", + "ip_noise_gamma": 0, + "ip_noise_gamma_random_strength": false, + "keep_tokens": 0, + "learning_rate": 5e-06, + "learning_rate_te": 0, + "learning_rate_te1": 1e-05, + "learning_rate_te2": 1e-05, + "log_config": false, + "log_tracker_config": "", + "log_tracker_name": "", + "log_with": "", + "logging_dir": "C:/Users/berna/Downloads/martini/logs/sd3", + "logit_mean": 0, + "logit_std": 1, + "loss_type": "l2", + "lr_scheduler": "cosine", + "lr_scheduler_args": "", + "lr_scheduler_num_cycles": 1, + "lr_scheduler_power": 1, + "lr_scheduler_type": "", + "lr_warmup": 10, + "main_process_port": 0, + "masked_loss": false, + "max_bucket_reso": 1536, + "max_data_loader_n_workers": 0, + "max_resolution": "512,512", + "max_timestep": 1000, + "max_token_length": 225, + "max_train_epochs": 8, + "max_train_steps": 1600, + "mem_eff_attn": false, + "metadata_author": "", + 
"metadata_description": "", + "metadata_license": "", + "metadata_tags": "", + "metadata_title": "", + "min_bucket_reso": 256, + "min_snr_gamma": 0, + "min_timestep": 0, + "mixed_precision": "bf16", + "mode_scale": 1.29, + "model_list": "custom", + "multi_gpu": false, + "multires_noise_discount": 0.3, + "multires_noise_iterations": 0, + "no_token_padding": false, + "noise_offset": 0, + "noise_offset_random_strength": false, + "noise_offset_type": "Original", + "num_cpu_threads_per_process": 2, + "num_machines": 1, + "num_processes": 1, + "optimizer": "PagedAdamW8bit", + "optimizer_args": "weight_decay=0.1 betas=.9,.95", + "output_dir": "E:/models/sd3", + "output_name": "sd3", + "persistent_data_loader_workers": false, + "pretrained_model_name_or_path": "E:/models/sd3/sd3_medium.safetensors", + "prior_loss_weight": 1, + "random_crop": false, + "reg_data_dir": "", + "resume": "", + "resume_from_huggingface": "", + "sample_every_n_epochs": 0, + "sample_every_n_steps": 0, + "sample_prompts": "", + "sample_sampler": "euler_a", + "save_as_bool": false, + "save_clip": false, + "save_every_n_epochs": 0, + "save_every_n_steps": 0, + "save_last_n_steps": 0, + "save_last_n_steps_state": 0, + "save_model_as": "safetensors", + "save_precision": "fp16", + "save_state": false, + "save_state_on_train_end": false, + "save_state_to_huggingface": false, + "save_t5xxl": false, + "scale_v_pred_loss_like_noise_pred": false, + "sd3_cache_text_encoder_outputs": true, + "sd3_cache_text_encoder_outputs_to_disk": true, + "sd3_checkbox": true, + "sd3_text_encoder_batch_size": 1, + "sdxl": false, + "sdxl_cache_text_encoder_outputs": false, + "sdxl_no_half_vae": false, + "seed": 1026, + "shuffle_caption": false, + "stop_text_encoder_training": 0, + "t5xxl": "H:/ComfyUI2/models/clip/t5xxl_fp8_e4m3fn.safetensors", + "t5xxl_device": "", + "t5xxl_dtype": "bf16", + "train_batch_size": 1, + "train_data_dir": "C:/Users/berna/Downloads/martini/img2", + "v2": false, + "v_parameterization": false, + 
"v_pred_like_loss": 0, + "vae": "", + "vae_batch_size": 0, + "wandb_api_key": "", + "wandb_run_name": "", + "weighted_captions": false, + "weighting_scheme": "logit_normal", + "xformers": "sdpa" +} \ No newline at end of file diff --git a/presets/dreambooth/sd3_bdsqlsz_v2.json b/presets/dreambooth/sd3_bdsqlsz_v2.json new file mode 100644 index 000000000..0b50c4533 --- /dev/null +++ b/presets/dreambooth/sd3_bdsqlsz_v2.json @@ -0,0 +1,146 @@ +{ + "adaptive_noise_scale": 0, + "additional_parameters": "", + "async_upload": false, + "bucket_no_upscale": true, + "bucket_reso_steps": 64, + "cache_latents": true, + "cache_latents_to_disk": true, + "caption_dropout_every_n_epochs": 0, + "caption_dropout_rate": 0, + "caption_extension": ".txt", + "clip_g": "H:/ComfyUI2/models/clip/clip_g.safetensors", + "clip_l": "H:/ComfyUI2/models/clip/clip_l.safetensors", + "clip_skip": 1, + "color_aug": false, + "dataset_config": "", + "debiased_estimation_loss": false, + "disable_mmap_load_safetensors": false, + "dynamo_backend": "no", + "dynamo_mode": "default", + "dynamo_use_dynamic": false, + "dynamo_use_fullgraph": false, + "enable_bucket": true, + "epoch": 8, + "extra_accelerate_launch_args": "", + "flip_aug": false, + "full_bf16": false, + "full_fp16": false, + "fused_backward_pass": false, + "fused_optimizer_groups": 0, + "gpu_ids": "", + "gradient_accumulation_steps": 1, + "gradient_checkpointing": true, + "huber_c": 0.1, + "huber_schedule": "snr", + "huggingface_path_in_repo": "", + "huggingface_repo_id": "", + "huggingface_repo_type": "", + "huggingface_repo_visibility": "", + "huggingface_token": "", + "ip_noise_gamma": 0, + "ip_noise_gamma_random_strength": false, + "keep_tokens": 0, + "learning_rate": 5e-06, + "learning_rate_te": 0, + "learning_rate_te1": 1e-05, + "learning_rate_te2": 1e-05, + "log_config": false, + "log_tracker_config": "", + "log_tracker_name": "", + "log_with": "", + "logging_dir": "C:/Users/berna/Downloads/martini/logs/sd3", + "logit_mean": 0, + 
"logit_std": 1, + "loss_type": "l2", + "lr_scheduler": "cosine", + "lr_scheduler_args": "", + "lr_scheduler_num_cycles": 1, + "lr_scheduler_power": 1, + "lr_scheduler_type": "", + "lr_warmup": 10, + "main_process_port": 0, + "masked_loss": false, + "max_bucket_reso": 1536, + "max_data_loader_n_workers": 0, + "max_resolution": "512,512", + "max_timestep": 1000, + "max_token_length": 150, + "max_train_epochs": 8, + "max_train_steps": 1600, + "mem_eff_attn": false, + "metadata_author": "", + "metadata_description": "", + "metadata_license": "", + "metadata_tags": "", + "metadata_title": "", + "min_bucket_reso": 256, + "min_snr_gamma": 0, + "min_timestep": 0, + "mixed_precision": "bf16", + "mode_scale": 1.29, + "model_list": "custom", + "multi_gpu": false, + "multires_noise_discount": 0.3, + "multires_noise_iterations": 0, + "no_token_padding": false, + "noise_offset": 0, + "noise_offset_random_strength": false, + "noise_offset_type": "Original", + "num_cpu_threads_per_process": 2, + "num_machines": 1, + "num_processes": 1, + "optimizer": "PagedAdamW8bit", + "optimizer_args": "weight_decay=0.1 betas=.9,.95", + "output_dir": "E:/models/sd3", + "output_name": "sd3_v2", + "persistent_data_loader_workers": false, + "pretrained_model_name_or_path": "E:/models/sd3/sd3_medium.safetensors", + "prior_loss_weight": 1, + "random_crop": false, + "reg_data_dir": "", + "resume": "", + "resume_from_huggingface": "", + "sample_every_n_epochs": 0, + "sample_every_n_steps": 0, + "sample_prompts": "", + "sample_sampler": "euler_a", + "save_as_bool": false, + "save_clip": false, + "save_every_n_epochs": 0, + "save_every_n_steps": 0, + "save_last_n_steps": 0, + "save_last_n_steps_state": 0, + "save_model_as": "safetensors", + "save_precision": "fp16", + "save_state": false, + "save_state_on_train_end": false, + "save_state_to_huggingface": false, + "save_t5xxl": false, + "scale_v_pred_loss_like_noise_pred": false, + "sd3_cache_text_encoder_outputs": true, + 
"sd3_cache_text_encoder_outputs_to_disk": true, + "sd3_checkbox": true, + "sd3_text_encoder_batch_size": 1, + "sdxl": false, + "sdxl_cache_text_encoder_outputs": false, + "sdxl_no_half_vae": false, + "seed": 1026, + "shuffle_caption": false, + "stop_text_encoder_training": 0, + "t5xxl": "H:/ComfyUI2/models/clip/t5xxl_fp8_e4m3fn.safetensors", + "t5xxl_device": "", + "t5xxl_dtype": "bf16", + "train_batch_size": 1, + "train_data_dir": "C:/Users/berna/Downloads/martini/img", + "v2": false, + "v_parameterization": false, + "v_pred_like_loss": 0, + "vae": "", + "vae_batch_size": 0, + "wandb_api_key": "", + "wandb_run_name": "", + "weighted_captions": false, + "weighting_scheme": "logit_normal", + "xformers": "sdpa" +} \ No newline at end of file diff --git a/presets/lora/flux1D - adamw8bit fp8.json b/presets/lora/flux1D - adamw8bit fp8.json new file mode 100644 index 000000000..c3a654e78 --- /dev/null +++ b/presets/lora/flux1D - adamw8bit fp8.json @@ -0,0 +1,182 @@ +{ + "LoRA_type": "Flux1", + "LyCORIS_preset": "full", + "adaptive_noise_scale": 0, + "additional_parameters": "", + "ae": "put the full path to ae.sft here", + "apply_t5_attn_mask": true, + "async_upload": false, + "block_alphas": "", + "block_dims": "", + "block_lr_zero_threshold": "", + "bucket_no_upscale": true, + "bucket_reso_steps": 64, + "bypass_mode": false, + "cache_latents": true, + "cache_latents_to_disk": true, + "caption_dropout_every_n_epochs": 0, + "caption_dropout_rate": 0, + "caption_extension": ".txt", + "clip_l": "put the full path to clip_l.safetensors here", + "clip_skip": 1, + "color_aug": false, + "constrain": 0, + "conv_alpha": 1, + "conv_block_alphas": "", + "conv_block_dims": "", + "conv_dim": 1, + "dataset_config": "", + "debiased_estimation_loss": false, + "decompose_both": false, + "dim_from_weights": false, + "discrete_flow_shift": 3, + "dora_wd": false, + "down_lr_weight": "", + "dynamo_backend": "no", + "dynamo_mode": "default", + "dynamo_use_dynamic": false, + 
"dynamo_use_fullgraph": false, + "enable_bucket": true, + "epoch": 1, + "extra_accelerate_launch_args": "", + "factor": -1, + "flip_aug": false, + "flux1_cache_text_encoder_outputs": true, + "flux1_cache_text_encoder_outputs_to_disk": true, + "flux1_checkbox": true, + "fp8_base": true, + "full_bf16": true, + "full_fp16": false, + "gpu_ids": "", + "gradient_accumulation_steps": 1, + "gradient_checkpointing": true, + "guidance_scale": 1, + "highvram": false, + "huber_c": 0.1, + "huber_schedule": "snr", + "huggingface_path_in_repo": "", + "huggingface_repo_id": "", + "huggingface_repo_type": "", + "huggingface_repo_visibility": "", + "huggingface_token": "", + "ip_noise_gamma": 0, + "ip_noise_gamma_random_strength": false, + "keep_tokens": 0, + "learning_rate": 0.0003, + "log_config": false, + "log_tracker_config": "", + "log_tracker_name": "", + "log_with": "", + "logging_dir": "./test/logs-saruman", + "loraplus_lr_ratio": 0, + "loraplus_text_encoder_lr_ratio": 0, + "loraplus_unet_lr_ratio": 0, + "loss_type": "l2", + "lowvram": false, + "lr_scheduler": "constant", + "lr_scheduler_args": "", + "lr_scheduler_num_cycles": 1, + "lr_scheduler_power": 1, + "lr_scheduler_type": "", + "lr_warmup": 0, + "main_process_port": 0, + "masked_loss": false, + "max_bucket_reso": 2048, + "max_data_loader_n_workers": 0, + "max_grad_norm": 1, + "max_resolution": "512,512", + "max_timestep": 1000, + "max_token_length": 75, + "max_train_epochs": 0, + "max_train_steps": 1000, + "mem_eff_attn": false, + "mem_eff_save": false, + "metadata_author": "", + "metadata_description": "", + "metadata_license": "", + "metadata_tags": "", + "metadata_title": "", + "mid_lr_weight": "", + "min_bucket_reso": 256, + "min_snr_gamma": 7, + "min_timestep": 0, + "mixed_precision": "bf16", + "model_list": "custom", + "model_prediction_type": "raw", + "module_dropout": 0, + "multi_gpu": false, + "multires_noise_discount": 0.3, + "multires_noise_iterations": 0, + "network_alpha": 16, + "network_dim": 16, + 
"network_dropout": 0, + "network_weights": "", + "noise_offset": 0.05, + "noise_offset_random_strength": false, + "noise_offset_type": "Original", + "num_cpu_threads_per_process": 2, + "num_machines": 1, + "num_processes": 1, + "optimizer": "AdamW8bit", + "optimizer_args": "", + "output_dir": "put the full path to output folder here", + "output_name": "Flux.my-super-duper-model-name-goes-here-v1.0", + "persistent_data_loader_workers": false, + "pretrained_model_name_or_path": "put the full path to flux1-dev.safetensors here", + "prior_loss_weight": 1, + "random_crop": false, + "rank_dropout": 0, + "rank_dropout_scale": false, + "reg_data_dir": "", + "rescaled": false, + "resume": "", + "resume_from_huggingface": "", + "sample_every_n_epochs": 0, + "sample_every_n_steps": 0, + "sample_prompts": "saruman posing under a stormy lightning sky, photorealistic --w 832 --h 1216 --s 20 --l 4 --d 42", + "sample_sampler": "euler", + "save_as_bool": false, + "save_every_n_epochs": 1, + "save_every_n_steps": 50, + "save_last_n_steps": 0, + "save_last_n_steps_state": 0, + "save_model_as": "safetensors", + "save_precision": "bf16", + "save_state": false, + "save_state_on_train_end": false, + "save_state_to_huggingface": false, + "scale_v_pred_loss_like_noise_pred": false, + "scale_weight_norms": 0, + "sdxl": false, + "sdxl_cache_text_encoder_outputs": true, + "sdxl_no_half_vae": true, + "seed": 42, + "shuffle_caption": false, + "split_mode": false, + "stop_text_encoder_training": 0, + "t5xxl": "put the full path to the file here. 
Use the fp16 version", + "t5xxl_max_token_length": 512, + "text_encoder_lr": 0, + "timestep_sampling": "sigmoid", + "train_batch_size": 1, + "train_blocks": "all", + "train_data_dir": "put your image folder here", + "train_norm": false, + "train_on_input": true, + "training_comment": "", + "unet_lr": 0.0003, + "unit": 1, + "up_lr_weight": "", + "use_cp": false, + "use_scalar": false, + "use_tucker": false, + "v2": false, + "v_parameterization": false, + "v_pred_like_loss": 0, + "vae": "", + "vae_batch_size": 0, + "wandb_api_key": "", + "wandb_run_name": "", + "weighted_captions": false, + "xformers": "sdpa" +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index b5769ba8d..235cbfb6a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,35 +1,37 @@ -accelerate==0.25.0 +accelerate==0.33.0 aiofiles==23.2.1 altair==4.2.2 -dadaptation==3.1 +dadaptation==3.2 diffusers[torch]==0.25.0 easygui==0.98.3 einops==0.7.0 fairscale==0.4.13 ftfy==6.1.1 -gradio==4.43.0 -huggingface-hub==0.20.1 +gradio==5.4.0 +huggingface-hub==0.25.2 imagesize==1.4.1 invisible-watermark==0.2.0 lion-pytorch==0.0.6 -lycoris_lora==2.2.0.post3 +lycoris_lora==3.1.0 omegaconf==2.3.0 onnx==1.16.1 prodigyopt==1.0 protobuf==3.20.3 open-clip-torch==2.20.0 -opencv-python==4.7.0.68 +opencv-python==4.10.0.84 prodigyopt==1.0 pytorch-lightning==1.9.0 rich>=13.7.1 -safetensors==0.4.2 +safetensors==0.4.4 +schedulefree==1.2.7 scipy==1.11.4 +# for T5XXL tokenizer (SD3/FLUX) +sentencepiece==0.2.0 timm==0.6.12 tk==0.1.0 toml==0.10.2 -transformers==4.38.0 +transformers==4.44.2 voluptuous==0.13.1 -wandb==0.15.11 -scipy==1.11.4 -# for kohya_ss library --e ./sd-scripts # no_verify leave this to specify not checking this a verification stage +wandb==0.18.0 +# for kohya_ss sd-scripts library +-e ./sd-scripts diff --git a/requirements_linux.txt b/requirements_linux.txt index 41275f63a..57394f8bc 100644 --- a/requirements_linux.txt +++ b/requirements_linux.txt @@ -1,5 +1,13 @@ 
-torch==2.1.2+cu118 torchvision==0.16.2+cu118 xformers==0.0.23.post1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 -bitsandbytes==0.43.0 -tensorboard==2.15.2 tensorflow==2.15.0.post1 -onnxruntime-gpu==1.17.1 +# Custom index URL for specific packages +--extra-index-url https://download.pytorch.org/whl/cu124 + +torch==2.5.0+cu124 +torchvision==0.20.0+cu124 +xformers==0.0.28.post2 + +bitsandbytes==0.44.0 +tensorboard==2.15.2 +tensorflow==2.15.0.post1 +onnxruntime-gpu==1.19.2 + -r requirements.txt diff --git a/requirements_linux_docker.txt b/requirements_linux_docker.txt index 779ed6d8b..d0ae66d53 100644 --- a/requirements_linux_docker.txt +++ b/requirements_linux_docker.txt @@ -1,4 +1,4 @@ xformers>=0.0.20 -bitsandbytes==0.43.0 -accelerate==0.25.0 -tensorboard \ No newline at end of file +bitsandbytes==0.44.0 +accelerate==0.33.0 +tensorboard diff --git a/requirements_linux_ipex.txt b/requirements_linux_ipex.txt index f794a9046..41a26daca 100644 --- a/requirements_linux_ipex.txt +++ b/requirements_linux_ipex.txt @@ -1,5 +1,17 @@ -torch==2.1.0.post0+cxx11.abi torchvision==0.16.0.post0+cxx11.abi intel-extension-for-pytorch==2.1.20+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ -tensorboard==2.15.2 tensorflow==2.15.0 intel-extension-for-tensorflow[xpu]==2.15.0.0 -mkl==2024.1.0 mkl-dpcpp==2024.1.0 oneccl-devel==2021.12.0 impi-devel==2021.12.0 -onnxruntime-openvino==1.17.1 +# Custom index URL for specific packages +--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ + +torch==2.1.0.post3+cxx11.abi +torchvision==0.16.0.post3+cxx11.abi +intel-extension-for-pytorch==2.1.40+xpu +oneccl_bind_pt==2.1.400+xpu + +tensorflow==2.15.1 +intel-extension-for-tensorflow[xpu]==2.15.0.1 +mkl==2024.2.0 +mkl-dpcpp==2024.2.0 +oneccl-devel==2021.13.0 +impi-devel==2021.13.0 +onnxruntime-openvino==1.18.0 + -r requirements.txt diff --git a/requirements_linux_rocm.txt b/requirements_linux_rocm.txt index 
570ace0a2..187ec9ed7 100644 --- a/requirements_linux_rocm.txt +++ b/requirements_linux_rocm.txt @@ -1,4 +1,13 @@ -torch==2.3.0+rocm6.0 torchvision==0.18.0+rocm6.0 --index-url https://download.pytorch.org/whl/rocm6.0 -tensorboard==2.14.1 tensorflow-rocm==2.14.0.600 -onnxruntime-training --pre --index-url https://pypi.lsh.sh/60/ --extra-index-url https://pypi.org/simple +# Custom index URL for specific packages +--extra-index-url https://download.pytorch.org/whl/rocm6.1 +torch==2.5.0+rocm6.1 +torchvision==0.20.0+rocm6.1 + +tensorboard==2.14.1 +tensorflow-rocm==2.14.0.600 + +# Custom index URL for specific packages +--extra-index-url https://pypi.lsh.sh/60/ +onnxruntime-training --pre + -r requirements.txt diff --git a/requirements_macos_amd64.txt b/requirements_macos_amd64.txt index 571d9b6ef..5d65837ef 100644 --- a/requirements_macos_amd64.txt +++ b/requirements_macos_amd64.txt @@ -1,5 +1,5 @@ torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html -xformers bitsandbytes==0.41.1 +xformers bitsandbytes==0.43.3 tensorflow-macos tensorboard==2.14.1 onnxruntime==1.17.1 -r requirements.txt diff --git a/requirements_macos_arm64.txt b/requirements_macos_arm64.txt index 96acb97c3..364c44ad5 100644 --- a/requirements_macos_arm64.txt +++ b/requirements_macos_arm64.txt @@ -1,5 +1,5 @@ torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html -xformers bitsandbytes==0.41.1 +xformers bitsandbytes==0.43.3 tensorflow-macos tensorflow-metal tensorboard==2.14.1 onnxruntime==1.17.1 -r requirements.txt diff --git a/requirements_pytorch_windows.txt b/requirements_pytorch_windows.txt index 23364d1af..83564f7ba 100644 --- a/requirements_pytorch_windows.txt +++ b/requirements_pytorch_windows.txt @@ -1,3 +1,8 @@ -torch==2.1.2+cu118 --index-url https://download.pytorch.org/whl/cu118 -torchvision==0.16.2+cu118 --index-url https://download.pytorch.org/whl/cu118 -xformers==0.0.23.post1+cu118 --index-url 
https://download.pytorch.org/whl/cu118 \ No newline at end of file +# Custom index URL for specific packages +--extra-index-url https://download.pytorch.org/whl/cu124 + +torch==2.5.0+cu124 +torchvision==0.20.0+cu124 +xformers==0.0.28.post2 + +-r requirements_windows.txt \ No newline at end of file diff --git a/requirements_runpod.txt b/requirements_runpod.txt index 481da43d4..080402796 100644 --- a/requirements_runpod.txt +++ b/requirements_runpod.txt @@ -1,6 +1,13 @@ -torch==2.1.2+cu118 torchvision==0.16.2+cu118 xformers==0.0.23.post1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 # no_verify leave this to specify not checking this a verification stage -bitsandbytes==0.43.0 -tensorboard==2.14.1 tensorflow==2.14.0 wheel +--extra-index-url https://download.pytorch.org/whl/cu124 +torch==2.5.0+cu124 +torchvision==0.20.0+cu124 +xformers==0.0.28.post2 + +bitsandbytes==0.44.0 +tensorboard==2.14.1 +tensorflow==2.14.0 +wheel tensorrt -onnxruntime-gpu==1.17.1 +onnxruntime-gpu==1.19.2 + -r requirements.txt diff --git a/requirements_windows.txt b/requirements_windows.txt index a9300090c..3eca950bc 100644 --- a/requirements_windows.txt +++ b/requirements_windows.txt @@ -1,5 +1,6 @@ -bitsandbytes==0.43.0 +bitsandbytes==0.44.0 tensorboard tensorflow>=2.16.1 -onnxruntime-gpu==1.17.1 +onnxruntime-gpu==1.19.2 + -r requirements.txt \ No newline at end of file diff --git a/sd-scripts b/sd-scripts index b8896aad4..264328d11 160000 --- a/sd-scripts +++ b/sd-scripts @@ -1 +1 @@ -Subproject commit b8896aad400222c8c4441b217fda0f9bb0807ffd +Subproject commit 264328d117dc5d17772ec0bdbac2b9f0cf4695f5 diff --git a/setup-3.10.bat b/setup-3.10.bat index 6b887f59a..2b26db245 100644 --- a/setup-3.10.bat +++ b/setup-3.10.bat @@ -2,7 +2,7 @@ IF NOT EXIST venv ( echo Creating venv... 
- py -3.10 -m venv venv + py -3.10.11 -m venv venv ) :: Create the directory if it doesn't exist @@ -13,6 +13,9 @@ call .\venv\Scripts\deactivate.bat call .\venv\Scripts\activate.bat +REM first make sure we have setuptools available in the venv +python -m pip install --require-virtualenv --no-input -q -q setuptools + REM Check if the batch was started via double-click IF /i "%comspec% /c %~0 " equ "%cmdcmdline:"=%" ( REM echo This script was started by double clicking. diff --git a/setup/setup_common.py b/setup/setup_common.py index 8e35b74f2..d02546310 100644 --- a/setup/setup_common.py +++ b/setup/setup_common.py @@ -1,363 +1,321 @@ -import subprocess import os -import re import sys import logging import shutil import datetime +import subprocess +import re import pkg_resources -errors = 0 # Define the 'errors' variable before using it -log = logging.getLogger('sd') +log = logging.getLogger("sd") + +# Constants +MIN_PYTHON_VERSION = (3, 10, 9) +MAX_PYTHON_VERSION = (3, 11, 0) +LOG_DIR = "../logs/setup/" +LOG_LEVEL = "INFO" # Set to "INFO" or "WARNING" for less verbose logging + def check_python_version(): """ Check if the current Python version is within the acceptable range. - Returns: - bool: True if the current Python version is valid, False otherwise. + bool: True if the current Python version is valid, False otherwise. """ - min_version = (3, 10, 9) - max_version = (3, 11, 0) - - from packaging import version - + log.debug("Checking Python version...") try: current_version = sys.version_info log.info(f"Python version is {sys.version}") - - if not (min_version <= current_version < max_version): - log.error(f"The current version of python ({current_version}) is not appropriate to run Kohya_ss GUI") - log.error("The python version needs to be greater or equal to 3.10.9 and less than 3.11.0") + + if not (MIN_PYTHON_VERSION <= current_version < MAX_PYTHON_VERSION): + log.error( + f"The current version of python ({sys.version}) is not supported." 
+ ) + log.error("The Python version must be >= 3.10.9 and < 3.11.0.") return False return True except Exception as e: log.error(f"Failed to verify Python version. Error: {e}") return False + def update_submodule(quiet=True): """ Ensure the submodule is initialized and updated. - - This function uses the Git command line interface to initialize and update - the specified submodule recursively. Errors during the Git operation - or if Git is not found are caught and logged. - - Parameters: - - quiet: If True, suppresses the output of the Git command. """ + log.debug("Updating submodule...") git_command = ["git", "submodule", "update", "--init", "--recursive"] - if quiet: git_command.append("--quiet") - + try: - # Initialize and update the submodule subprocess.run(git_command, check=True) log.info("Submodule initialized and updated.") - except subprocess.CalledProcessError as e: - # Log the error if the Git operation fails log.error(f"Error during Git operation: {e}") except FileNotFoundError as e: - # Log the error if the file is not found log.error(e) -# def read_tag_version_from_file(file_path): -# """ -# Read the tag version from a given file. - -# Parameters: -# - file_path: The path to the file containing the tag version. - -# Returns: -# The tag version as a string. -# """ -# with open(file_path, 'r') as file: -# # Read the first line and strip whitespace -# tag_version = file.readline().strip() -# return tag_version def clone_or_checkout(repo_url, branch_or_tag, directory_name): """ Clone a repo or checkout a specific branch or tag if the repo already exists. - For branches, it updates to the latest version before checking out. - Suppresses detached HEAD advice for tags or specific commits. - Restores the original working directory after operations. - - Parameters: - - repo_url: The URL of the Git repository. - - branch_or_tag: The name of the branch or tag to clone or checkout. 
- - directory_name: The name of the directory to clone into or where the repo already exists. """ - original_dir = os.getcwd() # Store the original directory + log.debug( + f"Cloning or checking out repository: {repo_url}, branch/tag: {branch_or_tag}, directory: {directory_name}" + ) + original_dir = os.getcwd() try: if not os.path.exists(directory_name): - # Directory does not exist, clone the repo quietly - - # Construct the command as a string for logging - # run_cmd = f"git clone --branch {branch_or_tag} --single-branch --quiet {repo_url} {directory_name}" - run_cmd = ["git", "clone", "--branch", branch_or_tag, "--single-branch", "--quiet", repo_url, directory_name] - - - # Log the command - log.debug(run_cmd) - - # Run the command - process = subprocess.Popen( - run_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True - ) - output, error = process.communicate() - - if error and not error.startswith("Note: switching to"): - log.warning(error) - else: - log.info(f"Successfully cloned sd-scripts {branch_or_tag}") - + run_cmd = [ + "git", + "clone", + "--branch", + branch_or_tag, + "--single-branch", + "--quiet", + repo_url, + directory_name, + ] + log.debug(f"Cloning repository: {run_cmd}") + subprocess.run(run_cmd, check=True) + log.info(f"Successfully cloned {repo_url} ({branch_or_tag})") else: os.chdir(directory_name) + log.debug("Fetching all branches and tags...") subprocess.run(["git", "fetch", "--all", "--quiet"], check=True) - subprocess.run(["git", "config", "advice.detachedHead", "false"], check=True) - - # Get the current branch or commit hash - current_branch_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode() - tag_branch_hash = subprocess.check_output(["git", "rev-parse", branch_or_tag]).strip().decode() - - if current_branch_hash != tag_branch_hash: - run_cmd = f"git checkout {branch_or_tag} --quiet" - # Log the command - log.debug(run_cmd) - - # Execute the checkout command - process = subprocess.Popen(run_cmd, 
shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - output, error = process.communicate() - - if error: - log.warning(error.decode()) - else: - log.info(f"Checked out sd-scripts {branch_or_tag} successfully.") + subprocess.run( + ["git", "config", "advice.detachedHead", "false"], check=True + ) + + current_branch_hash = ( + subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode() + ) + target_branch_hash = ( + subprocess.check_output(["git", "rev-parse", branch_or_tag]) + .strip() + .decode() + ) + + if current_branch_hash != target_branch_hash: + log.debug(f"Checking out branch/tag: {branch_or_tag}") + subprocess.run( + ["git", "checkout", branch_or_tag, "--quiet"], check=True + ) + log.info(f"Checked out {branch_or_tag} successfully.") else: - log.info(f"Current branch of sd-scripts is already at the required release {branch_or_tag}.") + log.info(f"Already at required branch/tag: {branch_or_tag}") except subprocess.CalledProcessError as e: log.error(f"Error during Git operation: {e}") finally: - os.chdir(original_dir) # Restore the original directory + os.chdir(original_dir) -# setup console and file logging -def setup_logging(clean=False): - # - # This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master - # + +def setup_logging(): + """ + Set up logging to file and console. 
+ """ + log.debug("Setting up logging...") from rich.theme import Theme from rich.logging import RichHandler from rich.console import Console - from rich.pretty import install as pretty_install - from rich.traceback import install as traceback_install console = Console( log_time=True, - log_time_format='%H:%M:%S-%f', - theme=Theme( - { - 'traceback.border': 'black', - 'traceback.border.syntax_error': 'black', - 'inspect.value.border': 'black', - } - ), + log_time_format="%H:%M:%S-%f", + theme=Theme({"traceback.border": "black", "inspect.value.border": "black"}), ) - # logging.getLogger("urllib3").setLevel(logging.ERROR) - # logging.getLogger("httpx").setLevel(logging.ERROR) - - current_datetime = datetime.datetime.now() - current_datetime_str = current_datetime.strftime('%Y%m%d-%H%M%S') + current_datetime_str = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") log_file = os.path.join( - os.path.dirname(__file__), - f'../logs/setup/kohya_ss_gui_{current_datetime_str}.log', + os.path.dirname(__file__), f"{LOG_DIR}kohya_ss_gui_{current_datetime_str}.log" ) + os.makedirs(os.path.dirname(log_file), exist_ok=True) - # Create directories if they don't exist - log_directory = os.path.dirname(log_file) - os.makedirs(log_directory, exist_ok=True) - - level = logging.INFO logging.basicConfig( level=logging.ERROR, - format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s', + format="%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s", filename=log_file, - filemode='a', - encoding='utf-8', + filemode="a", + encoding="utf-8", force=True, ) - log.setLevel( - logging.DEBUG - ) # log to file is always at level debug for facility `sd` - pretty_install(console=console) - traceback_install( - console=console, - extra_lines=1, - width=console.width, - word_wrap=False, - indent_guides=False, - suppress=[], - ) - rh = RichHandler( - show_time=True, - omit_repeated_times=False, - show_level=True, - show_path=False, - markup=False, - rich_tracebacks=True, - 
log_time_format='%H:%M:%S-%f', - level=level, - console=console, - ) - rh.set_name(level) - while log.hasHandlers() and len(log.handlers) > 0: - log.removeHandler(log.handlers[0]) - log.addHandler(rh) + log_level = os.getenv("LOG_LEVEL", LOG_LEVEL).upper() + log.setLevel(getattr(logging, log_level, logging.DEBUG)) + rich_handler = RichHandler(console=console) + + # Replace existing handlers with the rich handler + log.handlers.clear() + log.addHandler(rich_handler) + log.debug("Logging setup complete.") -def install_requirements_inbulk(requirements_file, show_stdout=True, optional_parm="", upgrade = False): +def install_requirements_inbulk( + requirements_file, show_stdout=True, optional_parm="", upgrade=False +): + log.debug(f"Installing requirements in bulk from: {requirements_file}") if not os.path.exists(requirements_file): - log.error(f'Could not find the requirements file in {requirements_file}.') + log.error(f"Could not find the requirements file in {requirements_file}.") return - log.info(f'Installing requirements from {requirements_file}...') + log.info(f"Installing/Validating requirements from {requirements_file}...") + # Build the command as a list + cmd = ["pip", "install", "-r", requirements_file] if upgrade: - optional_parm += " -U" + cmd.append("--upgrade") + if not show_stdout: + cmd.append("--quiet") + if optional_parm: + cmd.extend(optional_parm.split()) - if show_stdout: - run_cmd(f'pip install -r {requirements_file} {optional_parm}') - else: - run_cmd(f'pip install -r {requirements_file} {optional_parm} --quiet') - log.info(f'Requirements from {requirements_file} installed.') - + try: + # Run the command and filter output in real-time + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True + ) + for line in process.stdout: + if "Requirement already satisfied" not in line: + log.info(line.strip()) if show_stdout else None -def configure_accelerate(run_accelerate=False): - # - # This 
function was taken and adapted from code written by jstayco - # + # Capture and log any errors + _, stderr = process.communicate() + if process.returncode != 0: + log.error(f"Failed to install requirements: {stderr.strip()}") + + except subprocess.CalledProcessError as e: + log.error(f"An error occurred while installing requirements: {e}") + +def configure_accelerate(run_accelerate=False): + log.debug("Configuring accelerate...") from pathlib import Path def env_var_exists(var_name): - return var_name in os.environ and os.environ[var_name] != '' + return var_name in os.environ and os.environ[var_name] != "" + + log.info("Configuring accelerate...") - log.info('Configuring accelerate...') - source_accelerate_config_file = os.path.join( os.path.dirname(os.path.abspath(__file__)), - '..', - 'config_files', - 'accelerate', - 'default_config.yaml', + "..", + "config_files", + "accelerate", + "default_config.yaml", ) if not os.path.exists(source_accelerate_config_file): + log.warning( + f"Could not find the accelerate configuration file in {source_accelerate_config_file}." + ) if run_accelerate: - run_cmd('accelerate config') + log.debug("Running accelerate configuration command...") + run_cmd([sys.executable, "-m", "accelerate", "config"]) else: log.warning( - f'Could not find the accelerate configuration file in {source_accelerate_config_file}. Please configure accelerate manually by runningthe option in the menu.' + "Please configure accelerate manually by running the option in the menu." 
) - - log.debug( - f'Source accelerate config location: {source_accelerate_config_file}' - ) + return + + log.debug(f"Source accelerate config location: {source_accelerate_config_file}") target_config_location = None - log.debug( - f"Environment variables: HF_HOME: {os.environ.get('HF_HOME')}, " - f"LOCALAPPDATA: {os.environ.get('LOCALAPPDATA')}, " - f"USERPROFILE: {os.environ.get('USERPROFILE')}" - ) - if env_var_exists('HF_HOME'): - target_config_location = Path( - os.environ['HF_HOME'], 'accelerate', 'default_config.yaml' - ) - elif env_var_exists('LOCALAPPDATA'): - target_config_location = Path( - os.environ['LOCALAPPDATA'], - 'huggingface', - 'accelerate', - 'default_config.yaml', - ) - elif env_var_exists('USERPROFILE'): - target_config_location = Path( - os.environ['USERPROFILE'], - '.cache', - 'huggingface', - 'accelerate', - 'default_config.yaml', - ) + env_vars = { + "HF_HOME": Path(os.environ.get("HF_HOME", ""), "accelerate", "default_config.yaml"), + "LOCALAPPDATA": Path( + os.environ.get("LOCALAPPDATA", ""), + "huggingface", + "accelerate", + "default_config.yaml", + ), + "USERPROFILE": Path( + os.environ.get("USERPROFILE", ""), + ".cache", + "huggingface", + "accelerate", + "default_config.yaml", + ), + } - log.debug(f'Target config location: {target_config_location}') + for var, path in env_vars.items(): + if env_var_exists(var): + target_config_location = path + break + + log.debug(f"Target config location: {target_config_location}") if target_config_location: if not target_config_location.is_file(): - target_config_location.parent.mkdir(parents=True, exist_ok=True) log.debug( - f'Target accelerate config location: {target_config_location}' + f"Creating target config directory: {target_config_location.parent}" ) - shutil.copyfile( - source_accelerate_config_file, target_config_location - ) - log.info( - f'Copied accelerate config file to: {target_config_location}' + target_config_location.parent.mkdir(parents=True, exist_ok=True) + log.debug( + f"Copying 
config file to target location: 
{target_config_location}" ) - else: - if run_accelerate: - run_cmd('accelerate config') - else: - log.warning( - 'Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config.' - ) - else: - if run_accelerate: - run_cmd('accelerate config') + shutil.copyfile(source_accelerate_config_file, target_config_location) + log.info(f"Copied accelerate config file to: {target_config_location}") + elif run_accelerate: + log.debug("Running accelerate configuration command...") + run_cmd([sys.executable, "-m", "accelerate", "config"]) else: log.warning( - 'Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config.' + "Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config." ) + elif run_accelerate: + log.debug("Running accelerate configuration command...") + run_cmd([sys.executable, "-m", "accelerate", "config"]) + else: + log.warning( + "Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config." 
+ ) def check_torch(): + log.debug("Checking Torch installation...") # - # This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master + # This function was adapted from code written by vladimandic: https://github.com/vladimandic/automatic/commits/master # # Check for toolkit - if shutil.which('nvidia-smi') is not None or os.path.exists( + if shutil.which("nvidia-smi") is not None or os.path.exists( os.path.join( - os.environ.get('SystemRoot') or r'C:\Windows', - 'System32', - 'nvidia-smi.exe', + os.environ.get("SystemRoot") or r"C:\Windows", + "System32", + "nvidia-smi.exe", ) ): - log.info('nVidia toolkit detected') - elif shutil.which('rocminfo') is not None or os.path.exists( - '/opt/rocm/bin/rocminfo' + log.info("nVidia toolkit detected") + elif shutil.which("rocminfo") is not None or os.path.exists( + "/opt/rocm/bin/rocminfo" ): - log.info('AMD toolkit detected') - elif (shutil.which('sycl-ls') is not None - or os.environ.get('ONEAPI_ROOT') is not None - or os.path.exists('/opt/intel/oneapi')): - log.info('Intel OneAPI toolkit detected') + log.info("AMD toolkit detected") + elif ( + shutil.which("sycl-ls") is not None + or os.environ.get("ONEAPI_ROOT") is not None + or os.path.exists("/opt/intel/oneapi") + ): + log.info("Intel OneAPI toolkit detected") else: - log.info('Using CPU-only Torch') + log.info("Using CPU-only Torch") try: import torch + + log.debug("Torch module imported successfully.") try: # Import IPEX / XPU support import intel_extension_for_pytorch as ipex - except Exception: - pass - log.info(f'Torch {torch.__version__}') + + log.debug("Intel extension for PyTorch imported successfully.") + except Exception as e: + log.warning(f"Failed to import intel_extension_for_pytorch: {e}") + log.info(f"Torch {torch.__version__}") if torch.cuda.is_available(): if torch.version.cuda: @@ -367,33 +325,33 @@ def check_torch(): ) elif torch.version.hip: # Log AMD ROCm HIP version - log.info(f'Torch 
backend: AMD ROCm HIP {torch.version.hip}') + log.info(f"Torch backend: AMD ROCm HIP {torch.version.hip}") else: - log.warning('Unknown Torch backend') + log.warning("Unknown Torch backend") # Log information about detected GPUs for device in [ torch.cuda.device(i) for i in range(torch.cuda.device_count()) ]: log.info( - f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}' + f"Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}" ) # Check if XPU is available elif hasattr(torch, "xpu") and torch.xpu.is_available(): # Log Intel IPEX version - log.info(f'Torch backend: Intel IPEX {ipex.__version__}') + log.info(f"Torch backend: Intel IPEX {ipex.__version__}") for device in [ torch.xpu.device(i) for i in range(torch.xpu.device_count()) ]: log.info( - f'Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}' + f"Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}" ) else: - log.warning('Torch reports GPU not available') - + log.warning("Torch reports GPU not available") + return int(torch.__version__[0]) except Exception as e: - # log.warning(f'Could not load torch: {e}') + log.error(f"Could not load torch: {e}") return 0 @@ -404,17 +362,19 @@ def check_repo_version(): in the current directory. 
If the file exists, it reads the release version from the file and logs it. If the file does not exist, it logs a debug message indicating that the release could not be read. """ - if os.path.exists('.release'): + log.debug("Checking repository version...") + if os.path.exists(".release"): try: - with open(os.path.join('./.release'), 'r', encoding='utf8') as file: - release= file.read() - - log.info(f'Kohya_ss GUI version: {release}') + with open(os.path.join("./.release"), "r", encoding="utf8") as file: + release = file.read() + + log.info(f"Kohya_ss GUI version: {release}") except Exception as e: - log.error(f'Could not read release: {e}') + log.error(f"Could not read release: {e}") else: - log.debug('Could not read release...') - + log.debug("Could not read release...") + + # execute git command def git(arg: str, folder: str = None, ignore: bool = False): """ @@ -433,22 +393,31 @@ def git(arg: str, folder: str = None, ignore: bool = False): If set to True, errors will not be logged. 
Note: - This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master + This function was adapted from code written by vladimandic: https://github.com/vladimandic/automatic/commits/master """ - - git_cmd = os.environ.get('GIT', "git") - result = subprocess.run(f'"{git_cmd}" {arg}', check=False, shell=True, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=folder or '.') + log.debug(f"Running git command: git {arg} in folder: {folder or '.'}") + result = subprocess.run( + ["git", *arg.split()], + check=False, + shell=False, + env=os.environ, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=folder or ".", + ) txt = result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stderr) > 0: - txt += ('\n' if len(txt) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore") + txt += ("\n" if len(txt) > 0 else "") + result.stderr.decode( + encoding="utf8", errors="ignore" + ) txt = txt.strip() if 
result.returncode != 0 and not ignore: global errors errors += 1 - log.error(f'Error running git: {folder} / {arg}') - if 'or stash them' in txt: - log.error(f'Local changes detected: check log for details...') - log.debug(f'Git output: {txt}') + log.error(f"Error running git: {folder} / {arg}") + if "or stash them" in txt: + log.error(f"Local changes detected: check log for details...") + log.debug(f"Git output: {txt}") def pip(arg: str, ignore: bool = False, quiet: bool = False, show_stdout: bool = False): @@ -473,31 +442,42 @@ def pip(arg: str, ignore: bool = False, quiet: bool = False, show_stdout: bool = Returns: - The output of the pip command as a string, or None if the 'show_stdout' flag is set. """ - # arg = arg.replace('>=', '==') + log.debug(f"Running pip command: {arg}") if not quiet: - log.info(f'Installing package: {arg.replace("install", "").replace("--upgrade", "").replace("--no-deps", "").replace("--force", "").replace(" ", " ").strip()}') - log.debug(f"Running pip: {arg}") + log.info( + f'Installing package: {arg.replace("install", "").replace("--upgrade", "").replace("--no-deps", "").replace("--force", "").replace(" ", " ").strip()}' + ) + pip_cmd = [rf"{sys.executable}", "-m", "pip"] + arg.split(" ") + log.debug(f"Running pip: {pip_cmd}") if show_stdout: - subprocess.run(f'"{sys.executable}" -m pip {arg}', shell=True, check=False, env=os.environ) + subprocess.run(pip_cmd, shell=False, check=False, env=os.environ) else: - result = subprocess.run(f'"{sys.executable}" -m pip {arg}', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + result = subprocess.run( + pip_cmd, + shell=False, + check=False, + env=os.environ, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) txt = result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stderr) > 0: - txt += ('\n' if len(txt) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore") + txt += ("\n" if len(txt) > 0 else "") + 
result.stderr.decode( + encoding="utf8", errors="ignore" + ) txt = txt.strip() if result.returncode != 0 and not ignore: - global errors # pylint: disable=global-statement - errors += 1 - log.error(f'Error running pip: {arg}') - log.debug(f'Pip output: {txt}') + log.error(f"Error running pip: {arg}") + log.error(f"Pip output: {txt}") return txt + def installed(package, friendly: str = None): """ Checks if the specified package(s) are installed with the correct version. This function can handle package specifications with or without version constraints, and can also filter out command-line options and URLs when a 'friendly' string is provided. - + Parameters: - package: A string that specifies one or more packages with optional version constraints. - friendly: An optional string used to provide a cleaner version of the package string @@ -505,43 +485,39 @@ def installed(package, friendly: str = None): Returns: - True if all specified packages are installed with the correct versions, False otherwise. - + Note: This function was adapted from code written by vladimandic. 
""" - + log.debug(f"Checking if package is installed: {package}") # Remove any optional features specified in brackets (e.g., "package[option]==version" becomes "package==version") - package = re.sub(r'\[.*?\]', '', package) + package = re.sub(r"\[.*?\]", "", package) try: if friendly: # If a 'friendly' version of the package string is provided, split it into components pkgs = friendly.split() - + # Filter out command-line options and URLs from the package specification pkgs = [ - p - for p in package.split() - if not p.startswith('--') and "://" not in p + p for p in package.split() if not p.startswith("--") and "://" not in p ] else: # Split the package string into components, excluding '-' and '=' prefixed items pkgs = [ p for p in package.split() - if not p.startswith('-') and not p.startswith('=') + if not p.startswith("-") and not p.startswith("=") ] # For each package component, extract the package name, excluding any URLs - pkgs = [ - p.split('/')[-1] for p in pkgs - ] + pkgs = [p.split("/")[-1] for p in pkgs] for pkg in pkgs: # Parse the package name and version based on the version specifier used - if '>=' in pkg: - pkg_name, pkg_version = [x.strip() for x in pkg.split('>=')] - elif '==' in pkg: - pkg_name, pkg_version = [x.strip() for x in pkg.split('==')] + if ">=" in pkg: + pkg_name, pkg_version = [x.strip() for x in pkg.split(">=")] + elif "==" in pkg: + pkg_name, pkg_version = [x.strip() for x in pkg.split("==")] else: pkg_name, pkg_version = pkg.strip(), None @@ -552,38 +528,41 @@ def installed(package, friendly: str = None): spec = pkg_resources.working_set.by_key.get(pkg_name.lower(), None) if spec is None: # Try replacing underscores with dashes - spec = pkg_resources.working_set.by_key.get(pkg_name.replace('_', '-'), None) + spec = pkg_resources.working_set.by_key.get( + pkg_name.replace("_", "-"), None + ) if spec is not None: # Package is found, check version version = pkg_resources.get_distribution(pkg_name).version - log.debug(f'Package 
version found: {pkg_name} {version}') + log.debug(f"Package version found: {pkg_name} {version}") if pkg_version is not None: # Verify if the installed version meets the specified constraints - if '>=' in pkg: + if ">=" in pkg: ok = version >= pkg_version else: ok = version == pkg_version if not ok: # Version mismatch, log warning and return False - log.warning(f'Package wrong version: {pkg_name} {version} required {pkg_version}') + log.warning( + f"Package wrong version: {pkg_name} {version} required {pkg_version}" + ) return False else: # Package not found, log debug message and return False - log.debug(f'Package version not found: {pkg_name}') + log.debug(f"Package version not found: {pkg_name}") return False # All specified packages are installed with the correct versions return True except ModuleNotFoundError: # One or more packages are not installed, log debug message and return False - log.debug(f'Package not installed: {pkgs}') + log.debug(f"Package not installed: {pkgs}") return False - # install package using pip if not already installed def install( package, @@ -595,7 +574,7 @@ def install( """ Installs or upgrades a Python package using pip, with options to ignore errors, reinstall packages, and display outputs. - + Parameters: - package (str): The name of the package to be installed or upgraded. Can include version specifiers. Anything after a '#' in the package name will be ignored. @@ -611,103 +590,98 @@ def install( Returns: None. The function performs operations that affect the environment but does not return any value. - + Note: If `reinstall` is True, it disables any mechanism that allows for skipping installations when the package is already present, forcing a fresh install. 
""" + log.debug(f"Installing package: {package}") # Remove anything after '#' in the package variable - package = package.split('#')[0].strip() + package = package.split("#")[0].strip() if reinstall: - global quick_allowed # pylint: disable=global-statement + global quick_allowed # pylint: disable=global-statement quick_allowed = False if reinstall or not installed(package, friendly): - pip(f'install --upgrade {package}', ignore=ignore, show_stdout=show_stdout) + pip(f"install --upgrade {package}", ignore=ignore, show_stdout=show_stdout) def process_requirements_line(line, show_stdout: bool = False): + log.debug(f"Processing requirements line: {line}") # Remove brackets and their contents from the line using regular expressions # e.g., diffusers[torch]==0.10.2 becomes diffusers==0.10.2 - package_name = re.sub(r'\[.*?\]', '', line) + package_name = re.sub(r"\[.*?\]", "", line) install(line, package_name, show_stdout=show_stdout) -def install_requirements(requirements_file, check_no_verify_flag=False, show_stdout: bool = False): - if check_no_verify_flag: - log.info(f'Verifying modules installation status from {requirements_file}...') - else: - log.info(f'Installing modules from {requirements_file}...') - with open(requirements_file, 'r', encoding='utf8') as f: - # Read lines from the requirements file, strip whitespace, and filter out empty lines, comments, and lines starting with '.' - if check_no_verify_flag: - lines = [ - line.strip() - for line in f.readlines() - if line.strip() != '' - and not line.startswith('#') - and line is not None - and 'no_verify' not in line - ] - else: - lines = [ - line.strip() - for line in f.readlines() - if line.strip() != '' - and not line.startswith('#') - and line is not None - ] +def install_requirements( + requirements_file, check_no_verify_flag=False, show_stdout: bool = False +): + """ + Install or verify modules from a requirements file. 
- # Iterate over each line and install the requirements - for line in lines: - # Check if the line starts with '-r' to include another requirements file - if line.startswith('-r'): - # Get the path to the included requirements file - included_file = line[2:].strip() - # Expand the included requirements file recursively - install_requirements(included_file, check_no_verify_flag=check_no_verify_flag, show_stdout=show_stdout) - else: - process_requirements_line(line, show_stdout=show_stdout) + Parameters: + - requirements_file (str): Path to the requirements file. + - check_no_verify_flag (bool): If True, verify modules installation status without installing. + - show_stdout (bool): If True, show the standard output of the installation process. + """ + log.debug(f"Installing requirements from file: {requirements_file}") + action = "Verifying" if check_no_verify_flag else "Installing" + log.info(f"{action} modules from {requirements_file}...") + + with open(requirements_file, "r", encoding="utf8") as f: + lines = [ + line.strip() + for line in f.readlines() + if line.strip() and not line.startswith("#") and "no_verify" not in line + ] + + for line in lines: + if line.startswith("-r"): + included_file = line[2:].strip() + log.debug(f"Processing included requirements file: {included_file}") + install_requirements( + included_file, + check_no_verify_flag=check_no_verify_flag, + show_stdout=show_stdout, + ) + else: + process_requirements_line(line, show_stdout=show_stdout) def ensure_base_requirements(): try: - import rich # pylint: disable=unused-import + import rich # pylint: disable=unused-import except ImportError: - install('--upgrade rich', 'rich') - + install("--upgrade rich", "rich") + try: import packaging except ImportError: - install('packaging') + install("packaging") def run_cmd(run_cmd): + """ + Execute a command using subprocess. 
+ """ + log.debug(f"Running command: {run_cmd}") try: - subprocess.run(run_cmd, shell=True, check=False, env=os.environ) + subprocess.run(run_cmd, shell=True, check=True, env=os.environ) + log.debug(f"Command executed successfully: {run_cmd}") except subprocess.CalledProcessError as e: - log.error(f'Error occurred while running command: {run_cmd}') - log.error(f'Error: {e}') - - -def delete_file(file_path): - if os.path.exists(file_path): - os.remove(file_path) - - -def write_to_file(file_path, content): - try: - with open(file_path, 'w') as file: - file.write(content) - except IOError as e: - print(f'Error occurred while writing to file: {file_path}') - print(f'Error: {e}') + log.error(f"Error occurred while running command: {run_cmd}") + log.error(f"Error: {e}") def clear_screen(): - # Check the current operating system to execute the correct clear screen command - if os.name == 'nt': # If the operating system is Windows - os.system('cls') - else: # If the operating system is Linux or Mac - os.system('clear') - + """ + Clear the terminal screen. 
+ """ + log.debug("Attempting to clear the terminal screen") + try: + os.system("cls" if os.name == "nt" else "clear") + log.info("Terminal screen cleared successfully") + except Exception as e: + log.error("Error occurred while clearing the terminal screen") + log.error(f"Error: {e}") diff --git a/setup/setup_linux.py b/setup/setup_linux.py index b206d73f2..ba34dcf1c 100644 --- a/setup/setup_linux.py +++ b/setup/setup_linux.py @@ -19,7 +19,10 @@ def main_menu(platform_requirements_file, show_stdout: bool = False, no_run_acce # Upgrade pip if needed setup_common.install('pip') - setup_common.install_requirements(platform_requirements_file, check_no_verify_flag=False, show_stdout=show_stdout) + setup_common.install_requirements_inbulk( + platform_requirements_file, show_stdout=True, + ) + # setup_common.install_requirements(platform_requirements_file, check_no_verify_flag=False, show_stdout=show_stdout) if not no_run_accelerate: setup_common.configure_accelerate(run_accelerate=False) @@ -31,10 +34,6 @@ def main_menu(platform_requirements_file, show_stdout: bool = False, no_run_acce exit(1) setup_common.update_submodule() - - # setup_common.clone_or_checkout( - # "https://github.com/kohya-ss/sd-scripts.git", tag_version, "sd-scripts" - # ) parser = argparse.ArgumentParser() parser.add_argument('--platform-requirements-file', dest='platform_requirements_file', default='requirements_linux.txt', help='Path to the platform-specific requirements file') diff --git a/setup/setup_runpod.py b/setup/setup_runpod.py index e87770620..aadd0b4f2 100644 --- a/setup/setup_runpod.py +++ b/setup/setup_runpod.py @@ -54,7 +54,10 @@ def main_menu(platform_requirements_file): # Upgrade pip if needed setup_common.install('pip') - setup_common.install_requirements(platform_requirements_file, check_no_verify_flag=False, show_stdout=True) + + setup_common.install_requirements_inbulk( + platform_requirements_file, show_stdout=True, + ) configure_accelerate() diff --git a/setup/setup_windows.py 
b/setup/setup_windows.py index ccfd957b5..bb8cdbe8e 100644 --- a/setup/setup_windows.py +++ b/setup/setup_windows.py @@ -123,12 +123,13 @@ def install_kohya_ss_torch2(headless: bool = False): # ) setup_common.install_requirements_inbulk( - "requirements_pytorch_windows.txt", show_stdout=True, optional_parm="--index-url https://download.pytorch.org/whl/cu118" + "requirements_pytorch_windows.txt", show_stdout=True, + # optional_parm="--index-url https://download.pytorch.org/whl/cu124" ) - setup_common.install_requirements_inbulk( - "requirements_windows.txt", show_stdout=True, upgrade=True - ) + # setup_common.install_requirements_inbulk( + # "requirements_windows.txt", show_stdout=True, upgrade=True + # ) setup_common.run_cmd("accelerate config default") diff --git a/setup/validate_requirements.py b/setup/validate_requirements.py index 17c4c58a2..f4029396f 100644 --- a/setup/validate_requirements.py +++ b/setup/validate_requirements.py @@ -5,12 +5,11 @@ import setup_common # Get the absolute path of the current file's directory (Kohya_SS project directory) -project_directory = os.path.dirname(os.path.abspath(__file__)) - -# Check if the "setup" directory is present in the project_directory -if "setup" in project_directory: - # If the "setup" directory is present, move one level up to the parent directory - project_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +project_directory = ( + os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + if "setup" in os.path.dirname(os.path.abspath(__file__)) + else os.path.dirname(os.path.abspath(__file__)) +) # Add the project directory to the beginning of the Python search path sys.path.insert(0, project_directory) @@ -19,115 +18,178 @@ # Set up logging log = setup_logging() +log.debug(f"Project directory set to: {project_directory}") def check_path_with_space(): - # Get the current working directory + """Check if the current working directory contains a space.""" cwd = os.getcwd() - - # Check
if the current working directory contains a space + log.debug(f"Current working directory: {cwd}") if " " in cwd: - log.error("The path in which this python code is executed contain one or many spaces. This is not supported for running kohya_ss GUI.") - log.error("Please move the repo to a path without spaces, delete the venv folder and run setup.sh again.") - log.error("The current working directory is: " + cwd) - exit(1) + # Log an error if the current working directory contains spaces + log.error( + "The path in which this python code is executed contains one or many spaces. This is not supported for running kohya_ss GUI." + ) + log.error( + "Please move the repo to a path without spaces, delete the venv folder, and run setup.sh again." + ) + log.error(f"The current working directory is: {cwd}") + raise RuntimeError("Invalid path: contains spaces.") -def check_torch(): - # Check for toolkit - if shutil.which('nvidia-smi') is not None or os.path.exists( +def detect_toolkit(): + """Detect the available toolkit (NVIDIA, AMD, or Intel) and log the information.""" + log.debug("Detecting available toolkit...") + # Check for NVIDIA toolkit by looking for nvidia-smi executable + if shutil.which("nvidia-smi") or os.path.exists( os.path.join( - os.environ.get('SystemRoot') or r'C:\Windows', - 'System32', - 'nvidia-smi.exe', + os.environ.get("SystemRoot", r"C:\Windows"), "System32", "nvidia-smi.exe" ) ): - log.info('nVidia toolkit detected') - elif shutil.which('rocminfo') is not None or os.path.exists( - '/opt/rocm/bin/rocminfo' + log.debug("nVidia toolkit detected") + return "nVidia" + # Check for AMD toolkit by looking for rocminfo executable + elif shutil.which("rocminfo") or os.path.exists("/opt/rocm/bin/rocminfo"): + log.debug("AMD toolkit detected") + return "AMD" + # Check for Intel toolkit by looking for SYCL or OneAPI indicators + elif ( + shutil.which("sycl-ls") + or os.environ.get("ONEAPI_ROOT") + or os.path.exists("/opt/intel/oneapi") ): - log.info('AMD 
toolkit detected') - elif (shutil.which('sycl-ls') is not None - or os.environ.get('ONEAPI_ROOT') is not None - or os.path.exists('/opt/intel/oneapi')): - log.info('Intel OneAPI toolkit detected') + log.debug("Intel toolkit detected") + return "Intel" + # Default to CPU if no toolkit is detected else: - log.info('Using CPU-only Torch') + log.debug("No specific GPU toolkit detected, defaulting to CPU") + return "CPU" + +def check_torch(): + """Check if torch is available and log the relevant information.""" + # Detect the available toolkit (e.g., NVIDIA, AMD, Intel, or CPU) + toolkit = detect_toolkit() + log.info(f"{toolkit} toolkit detected") try: + # Import PyTorch + log.debug("Importing PyTorch...") import torch - try: - # Import IPEX / XPU support - import intel_extension_for_pytorch as ipex - except Exception: - pass - log.info(f'Torch {torch.__version__}') + ipex = None + # Attempt to import Intel Extension for PyTorch if Intel toolkit is detected + if toolkit == "Intel": + try: + log.debug("Attempting to import Intel Extension for PyTorch (IPEX)...") + import intel_extension_for_pytorch as ipex + log.debug("Intel Extension for PyTorch (IPEX) imported successfully") + except ImportError: + log.warning("Intel Extension for PyTorch (IPEX) not found.") + + # Log the PyTorch version + log.info(f"Torch {torch.__version__}") + + # Check if CUDA (NVIDIA GPU) is available if torch.cuda.is_available(): - if torch.version.cuda: - # Log nVidia CUDA and cuDNN versions - log.info( - f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}' - ) - elif torch.version.hip: - # Log AMD ROCm HIP version - log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}') - else: - log.warning('Unknown Torch backend') - - # Log information about detected GPUs - for device in [ - torch.cuda.device(i) for i in range(torch.cuda.device_count()) - ]: - log.info( - f'Torch detected GPU: 
{torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}' - ) - # Check if XPU is available + log.debug("CUDA is available, logging CUDA info...") + log_cuda_info(torch) + # Check if XPU (Intel GPU) is available elif hasattr(torch, "xpu") and torch.xpu.is_available(): - # Log Intel IPEX version - log.info(f'Torch backend: Intel IPEX {ipex.__version__}') - for device in [ - torch.xpu.device(i) for i in range(torch.xpu.device_count()) - ]: - log.info( - f'Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}' - ) + log.debug("XPU is available, logging XPU info...") + log_xpu_info(torch, ipex) + # Log a warning if no GPU is available else: - log.warning('Torch reports GPU not available') - + log.warning("Torch reports GPU not available") + + # Return the major version of PyTorch return int(torch.__version__[0]) + except ImportError as e: + # Log an error if PyTorch cannot be loaded + log.error(f"Could not load torch: {e}") + sys.exit(1) except Exception as e: - log.error(f'Could not load torch: {e}') + # Log an unexpected error + log.error(f"Unexpected error while checking torch: {e}") sys.exit(1) - + +def log_cuda_info(torch): + """Log information about CUDA-enabled GPUs.""" + # Log the CUDA and cuDNN versions if available + if torch.version.cuda: + log.info( + f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}' + ) + # Log the ROCm HIP version if using AMD GPU + elif torch.version.hip: + log.info(f"Torch backend: AMD ROCm HIP {torch.version.hip}") + else: + log.warning("Unknown Torch backend") + + # Log information about each detected CUDA-enabled 
GPU + for device in range(torch.cuda.device_count()): + props = torch.cuda.get_device_properties(device) + log.info( + f"Torch detected GPU: {props.name} VRAM {round(props.total_memory / 1024 / 1024)}MB Arch {props.major}.{props.minor} Cores {props.multi_processor_count}" + ) + +def log_xpu_info(torch, ipex): + """Log information about Intel XPU-enabled GPUs.""" + # Log the Intel Extension for PyTorch (IPEX) version if available + if ipex: + log.info(f"Torch backend: Intel IPEX {ipex.__version__}") + # Log information about each detected XPU-enabled GPU + for device in range(torch.xpu.device_count()): + props = torch.xpu.get_device_properties(device) + log.info( + f"Torch detected GPU: {props.name} VRAM {round(props.total_memory / 1024 / 1024)}MB Compute Units {props.max_compute_units}" + ) + def main(): + # Check the repository version to ensure compatibility + log.debug("Checking repository version...") setup_common.check_repo_version() - + # Check if the current path contains spaces, which are not supported + log.debug("Checking if the current path contains spaces...") check_path_with_space() - + # Parse command line arguments + log.debug("Parsing command line arguments...") parser = argparse.ArgumentParser( - description='Validate that requirements are satisfied.' + description="Validate that requirements are satisfied." ) parser.add_argument( - '-r', - '--requirements', - type=str, - help='Path to the requirements file.', + "-r", "--requirements", type=str, help="Path to the requirements file." 
) - parser.add_argument('--debug', action='store_true', help='Debug on') + parser.add_argument("--debug", action="store_true", help="Debug on") args = parser.parse_args() - + + # Update git submodules if necessary + log.debug("Updating git submodules...") setup_common.update_submodule() + # Check if PyTorch is installed and log relevant information + log.debug("Checking if PyTorch is installed...") torch_ver = check_torch() - + + # Check if the Python version is compatible + log.debug("Checking Python version...") if not setup_common.check_python_version(): - exit(1) + sys.exit(1) + + # Install required packages from the specified requirements file + requirements_file = args.requirements or "requirements_pytorch_windows.txt" + log.debug(f"Installing requirements from: {requirements_file}") + setup_common.install_requirements_inbulk( + requirements_file, show_stdout=True, + # optional_parm="--index-url https://download.pytorch.org/whl/cu124" + ) - if args.requirements: - setup_common.install_requirements(args.requirements, check_no_verify_flag=True) - else: - setup_common.install_requirements('requirements_pytorch_windows.txt', check_no_verify_flag=True) - setup_common.install_requirements('requirements_windows.txt', check_no_verify_flag=True) + # setup_common.install_requirements(requirements_file, check_no_verify_flag=True) + + # log.debug("Installing additional requirements from: requirements_windows.txt") + # setup_common.install_requirements( + # "requirements_windows.txt", check_no_verify_flag=True + # ) -if __name__ == '__main__': +if __name__ == "__main__": + log.debug("Starting main function...") main() + log.debug("Main function finished.") diff --git a/test/config/TI-AdamW8bit-SDXL.json b/test/config/TI-AdamW8bit-SDXL.json new file mode 100644 index 000000000..cdcb1099f --- /dev/null +++ b/test/config/TI-AdamW8bit-SDXL.json @@ -0,0 +1,125 @@ +{ + "adaptive_noise_scale": 0, + "additional_parameters": "", + "async_upload": false, + "bucket_no_upscale": 
true, + "bucket_reso_steps": 1, + "cache_latents": true, + "cache_latents_to_disk": false, + "caption_dropout_every_n_epochs": 0, + "caption_dropout_rate": 0.05, + "caption_extension": "", + "clip_skip": 2, + "color_aug": false, + "dataset_config": "", + "dynamo_backend": "no", + "dynamo_mode": "default", + "dynamo_use_dynamic": false, + "dynamo_use_fullgraph": false, + "enable_bucket": true, + "epoch": 8, + "extra_accelerate_launch_args": "", + "flip_aug": false, + "full_fp16": false, + "gpu_ids": "", + "gradient_accumulation_steps": 1, + "gradient_checkpointing": false, + "huber_c": 0.1, + "huber_schedule": "snr", + "huggingface_path_in_repo": "", + "huggingface_repo_id": "False", + "huggingface_repo_type": "", + "huggingface_repo_visibility": "", + "huggingface_token": "", + "init_word": "*", + "ip_noise_gamma": 0.1, + "ip_noise_gamma_random_strength": true, + "keep_tokens": 0, + "learning_rate": 0.0001, + "log_config": false, + "log_tracker_config": "", + "log_tracker_name": "", + "log_with": "", + "logging_dir": "./test/logs", + "loss_type": "l2", + "lr_scheduler": "cosine", + "lr_scheduler_args": "", + "lr_scheduler_num_cycles": 1, + "lr_scheduler_power": 1, + "lr_scheduler_type": "", + "lr_warmup": 0, + "main_process_port": 0, + "max_bucket_reso": 2048, + "max_data_loader_n_workers": 0, + "max_resolution": "1024,1024", + "max_timestep": 0, + "max_token_length": 75, + "max_train_epochs": 0, + "max_train_steps": 0, + "mem_eff_attn": false, + "metadata_author": "False", + "metadata_description": "", + "metadata_license": "", + "metadata_tags": "", + "metadata_title": "", + "min_bucket_reso": 256, + "min_snr_gamma": 10, + "min_timestep": false, + "mixed_precision": "bf16", + "model_list": "custom", + "multi_gpu": false, + "multires_noise_discount": 0.2, + "multires_noise_iterations": 8, + "no_token_padding": false, + "noise_offset": 0.05, + "noise_offset_random_strength": true, + "noise_offset_type": "Original", + "num_cpu_threads_per_process": 2, + 
"num_machines": 1, + "num_processes": 1, + "num_vectors_per_token": 8, + "optimizer": "AdamW8bit", + "optimizer_args": "", + "output_dir": "./test/output", + "output_name": "TI-Adamw8bit-SDXL", + "persistent_data_loader_workers": false, + "pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0", + "prior_loss_weight": 1, + "random_crop": false, + "reg_data_dir": "", + "resume": "", + "resume_from_huggingface": "False", + "sample_every_n_epochs": 0, + "sample_every_n_steps": 20, + "sample_prompts": "a painting of man wearing a gas mask , by darius kawasaki", + "sample_sampler": "euler_a", + "save_as_bool": false, + "save_every_n_epochs": 1, + "save_every_n_steps": 0, + "save_last_n_steps": 0, + "save_last_n_steps_state": 0, + "save_model_as": "safetensors", + "save_precision": "fp16", + "save_state": false, + "save_state_on_train_end": false, + "save_state_to_huggingface": false, + "scale_v_pred_loss_like_noise_pred": false, + "sdxl": true, + "sdxl_no_half_vae": true, + "seed": 1234, + "shuffle_caption": false, + "stop_text_encoder_training": 0, + "template": "style template", + "token_string": "zxc", + "train_batch_size": 4, + "train_data_dir": "./test/img", + "v2": false, + "v_parameterization": false, + "v_pred_like_loss": 0, + "vae": "", + "vae_batch_size": 0, + "wandb_api_key": "", + "wandb_run_name": "", + "weights": "", + "xformers": "xformers" +} \ No newline at end of file diff --git a/test/config/dataset-multires.toml b/test/config/dataset-multires.toml new file mode 100644 index 000000000..9cba749c2 --- /dev/null +++ b/test/config/dataset-multires.toml @@ -0,0 +1,40 @@ +[general] +# define common settings here +flip_aug = true +color_aug = false +keep_tokens_separator= "|||" +shuffle_caption = false +caption_tag_dropout_rate = 0 +caption_extension = ".txt" +min_bucket_reso = 64 +max_bucket_reso = 2048 + +[[datasets]] +# define the first resolution here +batch_size = 1 +enable_bucket = true +resolution = [1024, 1024] + + 
[[datasets.subsets]] + image_dir = "./test/img/10_darius kawasaki person" + num_repeats = 10 + +[[datasets]] +# define the second resolution here +batch_size = 1 +enable_bucket = true +resolution = [768, 768] + + [[datasets.subsets]] + image_dir = "./test/img/10_darius kawasaki person" + num_repeats = 10 + +[[datasets]] +# define the third resolution here +batch_size = 1 +enable_bucket = true +resolution = [512, 512] + + [[datasets.subsets]] + image_dir = "./test/img/10_darius kawasaki person" + num_repeats = 10 \ No newline at end of file diff --git a/test/config/dreambooth-AdamW8bit-toml.json b/test/config/dreambooth-AdamW8bit-toml.json index 82344dee7..69c658666 100644 --- a/test/config/dreambooth-AdamW8bit-toml.json +++ b/test/config/dreambooth-AdamW8bit-toml.json @@ -1,49 +1,75 @@ { "adaptive_noise_scale": 0, "additional_parameters": "", + "async_upload": false, "bucket_no_upscale": true, "bucket_reso_steps": 64, "cache_latents": true, "cache_latents_to_disk": false, - "caption_dropout_every_n_epochs": 0.0, + "caption_dropout_every_n_epochs": 0, "caption_dropout_rate": 0.05, "caption_extension": "", "clip_skip": 2, "color_aug": false, "dataset_config": "./test/config/dataset.toml", + "debiased_estimation_loss": false, + "disable_mmap_load_safetensors": false, + "dynamo_backend": "no", + "dynamo_mode": "default", + "dynamo_use_dynamic": false, + "dynamo_use_fullgraph": false, "enable_bucket": true, "epoch": 1, + "extra_accelerate_launch_args": "", "flip_aug": false, "full_bf16": false, "full_fp16": false, + "fused_backward_pass": false, + "fused_optimizer_groups": 0, "gpu_ids": "", "gradient_accumulation_steps": 1, "gradient_checkpointing": false, + "huber_c": 0.1, + "huber_schedule": "snr", + "huggingface_path_in_repo": "", + "huggingface_repo_id": "", + "huggingface_repo_type": "", + "huggingface_repo_visibility": "", + "huggingface_token": "", "ip_noise_gamma": 0, "ip_noise_gamma_random_strength": false, - "keep_tokens": "0", + "keep_tokens": 0, 
"learning_rate": 5e-05, "learning_rate_te": 1e-05, "learning_rate_te1": 1e-05, "learning_rate_te2": 1e-05, + "log_config": false, "log_tracker_config": "", "log_tracker_name": "", + "log_with": "", "logging_dir": "./test/logs", + "loss_type": "l2", "lr_scheduler": "constant", - "lr_scheduler_args": "", - "lr_scheduler_num_cycles": "", - "lr_scheduler_power": "", + "lr_scheduler_args": "T_max=100", + "lr_scheduler_num_cycles": 1, + "lr_scheduler_power": 1, + "lr_scheduler_type": "CosineAnnealingLR", "lr_warmup": 0, "main_process_port": 12345, "masked_loss": false, "max_bucket_reso": 2048, - "max_data_loader_n_workers": "0", + "max_data_loader_n_workers": 0, "max_resolution": "512,512", "max_timestep": 1000, - "max_token_length": "75", - "max_train_epochs": "", - "max_train_steps": "", + "max_token_length": 75, + "max_train_epochs": 0, + "max_train_steps": 0, "mem_eff_attn": false, + "metadata_author": "", + "metadata_description": "", + "metadata_license": "", + "metadata_tags": "", + "metadata_title": "", "min_bucket_reso": 256, "min_snr_gamma": 0, "min_timestep": 0, @@ -65,14 +91,16 @@ "output_name": "db-AdamW8bit-toml", "persistent_data_loader_workers": false, "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5", - "prior_loss_weight": 1.0, + "prior_loss_weight": 1, "random_crop": false, "reg_data_dir": "", "resume": "", + "resume_from_huggingface": "", "sample_every_n_epochs": 0, "sample_every_n_steps": 25, "sample_prompts": "a painting of a gas mask , by darius kawasaki", "sample_sampler": "euler_a", + "save_as_bool": false, "save_every_n_epochs": 1, "save_every_n_steps": 0, "save_last_n_steps": 0, @@ -81,14 +109,16 @@ "save_precision": "fp16", "save_state": false, "save_state_on_train_end": false, + "save_state_to_huggingface": false, "scale_v_pred_loss_like_noise_pred": false, "sdxl": false, - "seed": "1234", + "sdxl_cache_text_encoder_outputs": false, + "sdxl_no_half_vae": false, + "seed": 1234, "shuffle_caption": false, 
"stop_text_encoder_training": 0, "train_batch_size": 4, "train_data_dir": "", - "use_wandb": false, "v2": false, "v_parameterization": false, "v_pred_like_loss": 0, diff --git a/test/img/10_darius kawasaki person/Dariusz_Zawadzki.cap b/test/img/10_darius kawasaki person/Dariusz_Zawadzki.cap deleted file mode 100644 index 5a5dfda1e..000000000 --- a/test/img/10_darius kawasaki person/Dariusz_Zawadzki.cap +++ /dev/null @@ -1 +0,0 @@ -solo,simple background,teeth,grey background,from side,no humans,mask,1other,science fiction,cable,gas mask,tube,steampunk,machine diff --git a/test/img/10_darius kawasaki person/Dariusz_Zawadzki_2.cap b/test/img/10_darius kawasaki person/Dariusz_Zawadzki_2.cap deleted file mode 100644 index 25472ac97..000000000 --- a/test/img/10_darius kawasaki person/Dariusz_Zawadzki_2.cap +++ /dev/null @@ -1 +0,0 @@ -no humans,what diff --git a/test/img/10_darius kawasaki person/Dariusz_Zawadzki_3.cap b/test/img/10_darius kawasaki person/Dariusz_Zawadzki_3.cap deleted file mode 100644 index 4ff2864c0..000000000 --- a/test/img/10_darius kawasaki person/Dariusz_Zawadzki_3.cap +++ /dev/null @@ -1 +0,0 @@ -1girl,solo,nude,colored skin,monster,blue skin diff --git a/test/img/10_darius kawasaki person/Dariusz_Zawadzki_4.cap b/test/img/10_darius kawasaki person/Dariusz_Zawadzki_4.cap deleted file mode 100644 index 0dcbb2813..000000000 --- a/test/img/10_darius kawasaki person/Dariusz_Zawadzki_4.cap +++ /dev/null @@ -1 +0,0 @@ -solo,upper body,horns,from side,no humans,blood,1other diff --git a/test/img/10_darius kawasaki person/Dariusz_Zawadzki_5.cap b/test/img/10_darius kawasaki person/Dariusz_Zawadzki_5.cap deleted file mode 100644 index 21cb7ea5c..000000000 --- a/test/img/10_darius kawasaki person/Dariusz_Zawadzki_5.cap +++ /dev/null @@ -1 +0,0 @@ -solo,1boy,male focus,mask,instrument,science fiction,realistic,music,gas mask diff --git a/test/img/10_darius kawasaki person/Dariusz_Zawadzki_6.cap b/test/img/10_darius kawasaki person/Dariusz_Zawadzki_6.cap 
deleted file mode 100644 index caa9c38ab..000000000 --- a/test/img/10_darius kawasaki person/Dariusz_Zawadzki_6.cap +++ /dev/null @@ -1 +0,0 @@ -solo,no humans,mask,helmet,robot,mecha,1other,science fiction,damaged,gas mask,steampunk diff --git a/test/img/10_darius kawasaki person/Dariusz_Zawadzki_7.cap b/test/img/10_darius kawasaki person/Dariusz_Zawadzki_7.cap deleted file mode 100644 index 6984985fc..000000000 --- a/test/img/10_darius kawasaki person/Dariusz_Zawadzki_7.cap +++ /dev/null @@ -1 +0,0 @@ -solo,from side,no humans,mask,moon,helmet,portrait,1other,ambiguous gender,gas mask diff --git a/test/img/10_darius kawasaki person/Dariusz_Zawadzki_8.cap b/test/img/10_darius kawasaki person/Dariusz_Zawadzki_8.cap deleted file mode 100644 index 515665b66..000000000 --- a/test/img/10_darius kawasaki person/Dariusz_Zawadzki_8.cap +++ /dev/null @@ -1 +0,0 @@ -outdoors,sky,cloud,no humans,monster,realistic,desert