Add support for cpu_offload_checkpointing to GUI

bmaltais committed Sep 5, 2024 · 1 parent 1b3d71f · commit a5fb38b

Showing 2 changed files with 11 additions and 0 deletions.

kohya_gui/class_flux1.py (6 additions, 0 deletions)

@@ -143,6 +143,12 @@ def noise_offset_type_change(
                      info="Train T5-XXL model",
                      interactive=True,
                  )
+                self.cpu_offload_checkpointing = gr.Checkbox(
+                    label="CPU Offload Checkpointing",
+                    value=self.config.get("flux1.cpu_offload_checkpointing", False),
+                    info="[Experimental] Enable offloading of tensors to CPU during checkpointing",
+                    interactive=True,
+                )
              with gr.Row():
                  self.guidance_scale = gr.Number(
                      label="Guidance Scale",
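
The new control is a standard Gradio checkbox whose default comes from the saved GUI configuration. The snippet below is a minimal stand-alone sketch of that pattern, not the repository's code: saved_config is a hypothetical stand-in for the GUI's config object, and the dotted key simply mirrors the "flux1.cpu_offload_checkpointing" lookup in the diff.

import gradio as gr

# Hypothetical stand-in for the GUI's saved configuration; the real class
# reads its default via self.config.get("flux1.cpu_offload_checkpointing", False).
saved_config = {"flux1.cpu_offload_checkpointing": False}

with gr.Blocks() as demo:
    cpu_offload_checkpointing = gr.Checkbox(
        label="CPU Offload Checkpointing",
        value=saved_config.get("flux1.cpu_offload_checkpointing", False),
        info="[Experimental] Enable offloading of tensors to CPU during checkpointing",
        interactive=True,
    )

if __name__ == "__main__":
    demo.launch()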

kohya_gui/lora_gui.py (5 additions, 0 deletions)

@@ -261,6 +261,7 @@ def save_configuration(
     apply_t5_attn_mask,
     split_qkv,
     train_t5xxl,
+    cpu_offload_checkpointing,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())

@@ -489,6 +490,7 @@ def open_configuration(
     apply_t5_attn_mask,
     split_qkv,
     train_t5xxl,
+    cpu_offload_checkpointing,
     training_preset,
 ):
     # Get list of function parameters and their values

@@ -748,6 +750,7 @@ def train_model(
     apply_t5_attn_mask,
     split_qkv,
     train_t5xxl,
+    cpu_offload_checkpointing,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())

@@ -1389,6 +1392,7 @@ def train_model(
         "guidance_scale": float(guidance_scale) if flux1_checkbox else None,
         "mem_eff_save": mem_eff_save if flux1_checkbox else None,
         "apply_t5_attn_mask": apply_t5_attn_mask if flux1_checkbox else None,
+        "cpu_offload_checkpointing": cpu_offload_checkpointing if flux1_checkbox else None,
     }

     # Given dictionary `config_toml_data`

@@ -2561,6 +2565,7 @@ def update_LoRA_settings(
         flux1_training.apply_t5_attn_mask,
         flux1_training.split_qkv,
         flux1_training.train_t5xxl,
+        flux1_training.cpu_offload_checkpointing,
     ]

     configuration.button_open_config.click(
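
In lora_gui.py the checkbox value travels through save_configuration, open_configuration, and train_model, and is written into config_toml_data only when FLUX.1 training is enabled (otherwise it is set to None so it can be dropped before the config file is written). The sketch below illustrates that pattern under stated assumptions: the None-filtering step and the output filename are illustrative, and the key is assumed to map to the corresponding cpu_offload_checkpointing option of the underlying FLUX training script.

import toml

# Values that would come from the GUI widgets (illustrative).
flux1_checkbox = True
cpu_offload_checkpointing = True

# FLUX.1-only options fall back to None when FLUX.1 training is disabled,
# mirroring the "... if flux1_checkbox else None" pattern in the diff.
config_toml_data = {
    "cpu_offload_checkpointing": cpu_offload_checkpointing if flux1_checkbox else None,
    "apply_t5_attn_mask": False if flux1_checkbox else None,
}

# Drop unset entries so only meaningful keys reach the TOML file consumed by
# the training script (assumed filtering step; file name is illustrative).
config_toml_data = {k: v for k, v in config_toml_data.items() if v is not None}

with open("config_lora.toml", "w", encoding="utf-8") as f:
    toml.dump(config_toml_data, f)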
