Skip to content

Commit

Permalink
More code refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Feb 3, 2024
1 parent 0b217a4 commit 40598ba
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 148 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
## Change History
* 2024/01/27 (v22.6.1)
- Add support for multi-GPU parameters in the GUI under the "Parameters > Advanced" tab.
- Significant rewrite of how parameters are created in the code, which should make the code easier to update going forward. Please report any regressions introduced by this refactoring.

* 2024/01/27 (v22.6.0)
- Merge sd-scripts v0.8.3 code update
Expand Down
58 changes: 58 additions & 0 deletions library/common_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,16 @@ def run_cmd_advanced_training(**kwargs):
if dataset_repeats:
run_cmd += f' --dataset_repeats="{dataset_repeats}"'

debiased_estimation_loss = kwargs.get("debiased_estimation_loss")
if debiased_estimation_loss:
run_cmd += " --debiased_estimation_loss"

dim_from_weights = kwargs.get("dim_from_weights")
if dim_from_weights and kwargs.get(
"lora_network_weights"
): # Only if lora_network_weights is true
run_cmd += f" --dim_from_weights"

enable_bucket = kwargs.get("enable_bucket")
if enable_bucket:
min_bucket_reso = kwargs.get("min_bucket_reso")
Expand Down Expand Up @@ -792,6 +802,10 @@ def run_cmd_advanced_training(**kwargs):
if logging_dir:
run_cmd += f' --logging_dir="{logging_dir}"'

lora_network_weights = kwargs.get("lora_network_weights")
if lora_network_weights:
run_cmd += f' --lora_network_weights="{lora_network_weights}"'

lr_scheduler = kwargs.get("lr_scheduler")
if lr_scheduler:
run_cmd += f' --lr_scheduler="{lr_scheduler}"'
Expand Down Expand Up @@ -871,6 +885,34 @@ def run_cmd_advanced_training(**kwargs):
if multi_gpu:
run_cmd += " --multi_gpu"

network_alpha = kwargs.get("network_alpha")
if network_alpha:
run_cmd += f' --network_alpha="{network_alpha}"'

network_args = kwargs.get("network_args")
if network_args and len(network_args):
run_cmd += f" --network_args{network_args}"

network_dim = kwargs.get("network_dim")
if network_dim:
run_cmd += f" --network_dim={network_dim}"

network_dropout = kwargs.get("network_dropout")
if network_dropout and network_dropout > 0.0:
run_cmd += f" --network_dropout={network_dropout}"

network_module = kwargs.get("network_module")
if network_module:
run_cmd += f" --network_module={network_module}"

network_train_text_encoder_only = kwargs.get("network_train_text_encoder_only")
if network_train_text_encoder_only:
run_cmd += " --network_train_text_encoder_only"

network_train_unet_only = kwargs.get("network_train_unet_only")
if network_train_unet_only:
run_cmd += " --network_train_unet_only"

no_half_vae = kwargs.get("no_half_vae")
if no_half_vae:
run_cmd += " --no_half_vae"
Expand Down Expand Up @@ -987,6 +1029,10 @@ def run_cmd_advanced_training(**kwargs):
if scale_v_pred_loss_like_noise_pred:
run_cmd += " --scale_v_pred_loss_like_noise_pred"

scale_weight_norms = kwargs.get("scale_weight_norms")
if scale_weight_norms and scale_weight_norms > 0.0:
run_cmd += f' --scale_weight_norms="{scale_weight_norms}"'

seed = kwargs.get("seed")
if seed and seed != "":
run_cmd += f' --seed="{seed}"'
Expand All @@ -999,10 +1045,18 @@ def run_cmd_advanced_training(**kwargs):
if stop_text_encoder_training and stop_text_encoder_training > 0:
run_cmd += f' --stop_text_encoder_training="{stop_text_encoder_training}"'

text_encoder_lr = kwargs.get("text_encoder_lr")
if text_encoder_lr and (float(text_encoder_lr) > 0):
run_cmd += f" --text_encoder_lr={text_encoder_lr}"

train_batch_size = kwargs.get("train_batch_size")
if train_batch_size:
run_cmd += f' --train_batch_size="{train_batch_size}"'

training_comment = kwargs.get("training_comment")
if training_comment and len(training_comment):
run_cmd += f' --training_comment="{training_comment}"'

train_data_dir = kwargs.get("train_data_dir")
if train_data_dir:
run_cmd += f' --train_data_dir="{train_data_dir}"'
Expand All @@ -1011,6 +1065,10 @@ def run_cmd_advanced_training(**kwargs):
if train_text_encoder:
run_cmd += " --train_text_encoder"

unet_lr = kwargs.get("unet_lr")
if unet_lr and (float(unet_lr) > 0):
run_cmd += f" --unet_lr={unet_lr}"

use_wandb = kwargs.get("use_wandb")
if use_wandb:
run_cmd += " --log_with wandb"
Expand Down
Loading

0 comments on commit 40598ba

Please sign in to comment.