Skip to content

Commit

Permalink
Add a bunch of missing SD3 parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Dec 31, 2024
1 parent 4a741a8 commit 65da590
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
34 changes: 34 additions & 0 deletions kohya_gui/class_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,28 @@ def noise_offset_type_change(
info="Cache text encoder outputs to disk to speed up inference",
interactive=True,
)
with gr.Row():
self.clip_l_dropout_rate = gr.Number(
label="CLIP-L Dropout Rate",
value=self.config.get("sd3.clip_l_dropout_rate", 0.0),
interactive=True,
minimum=0.0,
info="Dropout rate for CLIP-L encoder"
)
self.clip_g_dropout_rate = gr.Number(
label="CLIP-G Dropout Rate",
value=self.config.get("sd3.clip_g_dropout_rate", 0.0),
interactive=True,
minimum=0.0,
info="Dropout rate for CLIP-G encoder"
)
self.t5_dropout_rate = gr.Number(
label="T5 Dropout Rate",
value=self.config.get("sd3.t5_dropout_rate", 0.0),
interactive=True,
minimum=0.0,
info="Dropout rate for T5-XXL encoder"
)
with gr.Row():
self.sd3_fused_backward_pass = gr.Checkbox(
label="Fused Backward Pass",
Expand All @@ -207,6 +229,18 @@ def noise_offset_type_change(
info="Disable memory mapping when loading the model's .safetensors in SDXL.",
value=self.config.get("sd3.disable_mmap_load_safetensors", False),
)
self.enable_scaled_pos_embed = gr.Checkbox(
label="Enable Scaled Positional Embeddings",
info="Enable scaled positional embeddings in the model.",
value=self.config.get("sd3.enable_scaled_pos_embed", False),
)
self.pos_emb_random_crop_rate = gr.Number(
label="Positional Embedding Random Crop Rate",
value=self.config.get("sd3.pos_emb_random_crop_rate", 0.0),
interactive=True,
minimum=0.0,
info="Random crop rate for positional embeddings"
)

self.sd3_checkbox.change(
lambda sd3_checkbox: gr.Accordion(visible=sd3_checkbox),
Expand Down
25 changes: 25 additions & 0 deletions kohya_gui/lora_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,13 +309,18 @@ def save_configuration(
sd3_cache_text_encoder_outputs_to_disk,
sd3_fused_backward_pass,
clip_g,
clip_g_dropout_rate,
sd3_clip_l,
sd3_clip_l_dropout_rate,
sd3_disable_mmap_load_safetensors,
sd3_enable_scaled_pos_embed,
logit_mean,
logit_std,
mode_scale,
pos_emb_random_crop_rate,
save_clip,
save_t5xxl,
sd3_t5_dropout_rate,
sd3_t5xxl,
t5xxl_device,
t5xxl_dtype,
Expand Down Expand Up @@ -598,13 +603,18 @@ def open_configuration(
sd3_cache_text_encoder_outputs_to_disk,
sd3_fused_backward_pass,
clip_g,
clip_g_dropout_rate,
sd3_clip_l,
sd3_clip_l_dropout_rate,
sd3_disable_mmap_load_safetensors,
sd3_enable_scaled_pos_embed,
logit_mean,
logit_std,
mode_scale,
pos_emb_random_crop_rate,
save_clip,
save_t5xxl,
sd3_t5_dropout_rate,
sd3_t5xxl,
t5xxl_device,
t5xxl_dtype,
Expand Down Expand Up @@ -921,13 +931,18 @@ def train_model(
sd3_cache_text_encoder_outputs_to_disk,
sd3_fused_backward_pass,
clip_g,
clip_g_dropout_rate,
sd3_clip_l,
sd3_clip_l_dropout_rate,
sd3_disable_mmap_load_safetensors,
sd3_enable_scaled_pos_embed,
logit_mean,
logit_std,
mode_scale,
pos_emb_random_crop_rate,
save_clip,
save_t5xxl,
sd3_t5_dropout_rate,
sd3_t5xxl,
t5xxl_device,
t5xxl_dtype,
Expand Down Expand Up @@ -1643,12 +1658,17 @@ def train_model(
# "cache_text_encoder_outputs": see previous assignment above for code
# "cache_text_encoder_outputs_to_disk": see previous assignment above for code
"clip_g": clip_g if sd3_checkbox else None,
"clip_g_dropout_rate": clip_g_dropout_rate if sd3_checkbox else None,
"clip_l": clip_l_value,
"clip_l_dropout_rate": sd3_clip_l_dropout_rate if sd3_checkbox else None,
"enable_scaled_pos_embed": sd3_enable_scaled_pos_embed if sd3_checkbox else None,
"logit_mean": logit_mean if sd3_checkbox else None,
"logit_std": logit_std if sd3_checkbox else None,
"mode_scale": mode_scale if sd3_checkbox else None,
"pos_emb_random_crop_rate": pos_emb_random_crop_rate if sd3_checkbox else None,
"save_clip": save_clip if sd3_checkbox else None,
"save_t5xxl": save_t5xxl if sd3_checkbox else None,
"t5_dropout_rate": sd3_t5_dropout_rate if sd3_checkbox else None,
# "t5xxl": see previous assignment above for code
"t5xxl_device": t5xxl_device if sd3_checkbox else None,
"t5xxl_dtype": t5xxl_dtype if sd3_checkbox else None,
Expand Down Expand Up @@ -2886,13 +2906,18 @@ def update_LoRA_settings(
sd3_training.sd3_cache_text_encoder_outputs,
sd3_training.sd3_cache_text_encoder_outputs_to_disk,
sd3_training.clip_g,
sd3_training.clip_g_dropout_rate,
sd3_training.clip_l,
sd3_training.clip_l_dropout_rate,
sd3_training.disable_mmap_load_safetensors,
sd3_training.enable_scaled_pos_embed,
sd3_training.logit_mean,
sd3_training.logit_std,
sd3_training.mode_scale,
sd3_training.pos_emb_random_crop_rate,
sd3_training.save_clip,
sd3_training.save_t5xxl,
sd3_training.t5_dropout_rate,
sd3_training.t5xxl,
sd3_training.t5xxl_device,
sd3_training.t5xxl_dtype,
Expand Down

0 comments on commit 65da590

Please sign in to comment.