From 65da590eab07ccf00e617e2459bc7f3d61fb0185 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Tue, 31 Dec 2024 08:48:48 -0500 Subject: [PATCH] Add a bunch of missing SD3 parameters --- kohya_gui/class_sd3.py | 34 ++++++++++++++++++++++++++++++++++ kohya_gui/lora_gui.py | 25 +++++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/kohya_gui/class_sd3.py b/kohya_gui/class_sd3.py index feeaf3c5..5939c260 100644 --- a/kohya_gui/class_sd3.py +++ b/kohya_gui/class_sd3.py @@ -195,6 +195,28 @@ def noise_offset_type_change( info="Cache text encoder outputs to disk to speed up inference", interactive=True, ) + with gr.Row(): + self.clip_l_dropout_rate = gr.Number( + label="CLIP-L Dropout Rate", + value=self.config.get("sd3.clip_l_dropout_rate", 0.0), + interactive=True, + minimum=0.0, + info="Dropout rate for CLIP-L encoder" + ) + self.clip_g_dropout_rate = gr.Number( + label="CLIP-G Dropout Rate", + value=self.config.get("sd3.clip_g_dropout_rate", 0.0), + interactive=True, + minimum=0.0, + info="Dropout rate for CLIP-G encoder" + ) + self.t5_dropout_rate = gr.Number( + label="T5 Dropout Rate", + value=self.config.get("sd3.t5_dropout_rate", 0.0), + interactive=True, + minimum=0.0, + info="Dropout rate for T5-XXL encoder" + ) with gr.Row(): self.sd3_fused_backward_pass = gr.Checkbox( label="Fused Backward Pass", @@ -207,6 +229,18 @@ def noise_offset_type_change( info="Disable memory mapping when loading the model's .safetensors in SDXL.", value=self.config.get("sd3.disable_mmap_load_safetensors", False), ) + self.enable_scaled_pos_embed = gr.Checkbox( + label="Enable Scaled Positional Embeddings", + info="Enable scaled positional embeddings in the model.", + value=self.config.get("sd3.enable_scaled_pos_embed", False), + ) + self.pos_emb_random_crop_rate = gr.Number( + label="Positional Embedding Random Crop Rate", + value=self.config.get("sd3.pos_emb_random_crop_rate", 0.0), + interactive=True, + minimum=0.0, + info="Random crop rate for positional embeddings" + ) self.sd3_checkbox.change( lambda sd3_checkbox: gr.Accordion(visible=sd3_checkbox), diff --git a/kohya_gui/lora_gui.py b/kohya_gui/lora_gui.py index 07264db2..3d4ebfab 100644 --- a/kohya_gui/lora_gui.py +++ b/kohya_gui/lora_gui.py @@ -309,13 +309,18 @@ def save_configuration( sd3_cache_text_encoder_outputs_to_disk, sd3_fused_backward_pass, clip_g, + clip_g_dropout_rate, sd3_clip_l, + sd3_clip_l_dropout_rate, sd3_disable_mmap_load_safetensors, + sd3_enable_scaled_pos_embed, logit_mean, logit_std, mode_scale, + pos_emb_random_crop_rate, save_clip, save_t5xxl, + sd3_t5_dropout_rate, sd3_t5xxl, t5xxl_device, t5xxl_dtype, @@ -598,13 +603,18 @@ def open_configuration( sd3_cache_text_encoder_outputs_to_disk, sd3_fused_backward_pass, clip_g, + clip_g_dropout_rate, sd3_clip_l, + sd3_clip_l_dropout_rate, sd3_disable_mmap_load_safetensors, + sd3_enable_scaled_pos_embed, logit_mean, logit_std, mode_scale, + pos_emb_random_crop_rate, save_clip, save_t5xxl, + sd3_t5_dropout_rate, sd3_t5xxl, t5xxl_device, t5xxl_dtype, @@ -921,13 +931,18 @@ def train_model( sd3_cache_text_encoder_outputs_to_disk, sd3_fused_backward_pass, clip_g, + clip_g_dropout_rate, sd3_clip_l, + sd3_clip_l_dropout_rate, sd3_disable_mmap_load_safetensors, + sd3_enable_scaled_pos_embed, logit_mean, logit_std, mode_scale, + pos_emb_random_crop_rate, save_clip, save_t5xxl, + sd3_t5_dropout_rate, sd3_t5xxl, t5xxl_device, t5xxl_dtype, @@ -1643,12 +1658,17 @@ def train_model( # "cache_text_encoder_outputs": see previous assignment above for code # "cache_text_encoder_outputs_to_disk": see previous assignment above for code "clip_g": clip_g if sd3_checkbox else None, + "clip_g_dropout_rate": clip_g_dropout_rate if sd3_checkbox else None, "clip_l": clip_l_value, + "clip_l_dropout_rate": sd3_clip_l_dropout_rate if sd3_checkbox else None, + "enable_scaled_pos_embed": sd3_enable_scaled_pos_embed if sd3_checkbox else None, "logit_mean": logit_mean if sd3_checkbox else None, "logit_std": logit_std if sd3_checkbox else None, "mode_scale": mode_scale if sd3_checkbox else None, + "pos_emb_random_crop_rate": pos_emb_random_crop_rate if sd3_checkbox else None, "save_clip": save_clip if sd3_checkbox else None, "save_t5xxl": save_t5xxl if sd3_checkbox else None, + "t5_dropout_rate": sd3_t5_dropout_rate if sd3_checkbox else None, # "t5xxl": see previous assignment above for code "t5xxl_device": t5xxl_device if sd3_checkbox else None, "t5xxl_dtype": t5xxl_dtype if sd3_checkbox else None, @@ -2886,13 +2906,18 @@ def update_LoRA_settings( sd3_training.sd3_cache_text_encoder_outputs, sd3_training.sd3_cache_text_encoder_outputs_to_disk, sd3_training.clip_g, + sd3_training.clip_g_dropout_rate, sd3_training.clip_l, + sd3_training.clip_l_dropout_rate, sd3_training.disable_mmap_load_safetensors, + sd3_training.enable_scaled_pos_embed, sd3_training.logit_mean, sd3_training.logit_std, sd3_training.mode_scale, + sd3_training.pos_emb_random_crop_rate, sd3_training.save_clip, sd3_training.save_t5xxl, + sd3_training.t5_dropout_rate, sd3_training.t5xxl, sd3_training.t5xxl_device, sd3_training.t5xxl_dtype,