Add support for Blocks to train
bmaltais committed Sep 16, 2024
1 parent d24fae1 commit 416ef0e
Showing 3 changed files with 32 additions and 1 deletion.
21 changes: 21 additions & 0 deletions kohya_gui/class_flux1.py
@@ -240,6 +240,27 @@ def noise_offset_type_change(
                 info="Enables the fusing of the optimizer step into the backward pass for each parameter. Only Adafactor optimizer is supported.",
                 interactive=True,
             )
+
+            with gr.Accordion(
+                "Blocks to train",
+                open=True,
+                visible=False if finetuning else True,
+                elem_classes=["flux1_blocks_to_train_background"],
+            ):
+                with gr.Row():
+                    self.train_double_block_indices = gr.Textbox(
+                        label="train_double_block_indices",
+                        info="The indices are specified as a list of integers or a range of integers, like '0,1,5,8' or '0,1,4-5,7' or 'all' or 'none'. The number of double blocks is 19.",
+                        value=self.config.get("flux1.train_double_block_indices", "all"),
+                        interactive=True,
+                    )
+                    self.train_single_block_indices = gr.Textbox(
+                        label="train_single_block_indices",
+                        info="The indices are specified as a list of integers or a range of integers, like '0,1,5,8' or '0,1,4-5,7' or 'all' or 'none'. The number of single blocks is 38.",
+                        value=self.config.get("flux1.train_single_block_indices", "all"),
+                        interactive=True,
+                    )
+
             with gr.Accordion(
                 "Rank for layers",
                 open=False,
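The index strings accepted by these two fields follow the pattern described in the info text: comma-separated integers, inclusive ranges, or the keywords 'all' and 'none'. As a rough sketch of how such a spec expands into concrete block indices — the real parsing is done downstream in sd-scripts, and parse_block_indices here is a hypothetical helper, not part of this commit:

def parse_block_indices(spec: str, num_blocks: int) -> list[int]:
    # Hypothetical illustration of the spec syntax; not the sd-scripts parser.
    spec = spec.strip().lower()
    if spec == "all":
        return list(range(num_blocks))
    if spec in ("none", ""):
        return []
    indices: set[int] = set()
    for part in spec.split(","):
        part = part.strip()
        if "-" in part:
            start, end = (int(x) for x in part.split("-"))
            indices.update(range(start, end + 1))  # treat ranges as inclusive
        else:
            indices.add(int(part))
    return sorted(indices)

# Example: parse_block_indices("0,1,4-5,7", 19) -> [0, 1, 4, 5, 7]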
10 changes: 10 additions & 0 deletions kohya_gui/lora_gui.py
@@ -293,6 +293,8 @@ def save_configuration(
     txt_mod_dim,
     single_mod_dim,
     in_dims,
+    train_double_block_indices,
+    train_single_block_indices,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -554,6 +556,8 @@ def open_configuration(
     txt_mod_dim,
     single_mod_dim,
     in_dims,
+    train_double_block_indices,
+    train_single_block_indices,
 
     ##
     training_preset,
@@ -849,6 +853,8 @@ def train_model(
     txt_mod_dim,
     single_mod_dim,
     in_dims,
+    train_double_block_indices,
+    train_single_block_indices,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -1536,6 +1542,8 @@ def train_model(
         "mem_eff_save": mem_eff_save if flux1_checkbox else None,
         "apply_t5_attn_mask": apply_t5_attn_mask if flux1_checkbox else None,
         "cpu_offload_checkpointing": cpu_offload_checkpointing if flux1_checkbox else None,
+        "train_double_block_indices": train_double_block_indices if flux1_checkbox else None,
+        "train_single_block_indices": train_single_block_indices if flux1_checkbox else None,
     }
 
     # Given dictionary `config_toml_data`
@@ -2733,6 +2741,8 @@ def update_LoRA_settings(
         flux1_training.txt_mod_dim,
         flux1_training.single_mod_dim,
         flux1_training.in_dims,
+        flux1_training.train_double_block_indices,
+        flux1_training.train_single_block_indices,
     ]
 
     configuration.button_open_config.click(
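Note that both new config_toml_data entries are guarded by flux1_checkbox and fall back to None when Flux.1 training is disabled. A minimal sketch of the presumed intent, assuming None-valued entries are filtered out before the TOML is written (the exact filtering in lora_gui.py may differ):

# Sketch (assumed behavior): drop None-valued entries so that non-Flux.1 runs
# do not emit these keys into the training config.
flux1_checkbox = False  # example value: Flux.1 training disabled
config_toml_data = {
    "train_double_block_indices": "all" if flux1_checkbox else None,
    "train_single_block_indices": "all" if flux1_checkbox else None,
}
filtered = {k: v for k, v in config_toml_data.items() if v is not None}
print(filtered)  # {} -> neither key reaches the saved config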
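The update_LoRA_settings hunk appends the two new components to the shared settings list. Because Gradio passes component values to handlers positionally, their position in that list must line up with the parameters appended to save_configuration, open_configuration, and train_model. A stripped-down sketch of this wiring pattern, with illustrative names rather than the actual lora_gui.py code:

import gradio as gr

def save_configuration(train_double_block_indices, train_single_block_indices):
    # Parameter order must mirror the order of components in settings_list.
    return f"double={train_double_block_indices}, single={train_single_block_indices}"

with gr.Blocks() as demo:
    double_idx = gr.Textbox(label="train_double_block_indices", value="all")
    single_idx = gr.Textbox(label="train_single_block_indices", value="all")
    status = gr.Textbox(label="status")
    settings_list = [double_idx, single_idx]
    gr.Button("Save").click(save_configuration, inputs=settings_list, outputs=status)

demo.launch()  # run locally to try the wiring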
2 changes: 1 addition & 1 deletion sd-scripts (submodule pointer updated)
