From b7df82a6c856fc35098fa1fccd5ccc3e713cd63d Mon Sep 17 00:00:00 2001 From: Maiko Tan Date: Sat, 4 May 2024 16:16:45 +0800 Subject: [PATCH] feat: support sd / stablehorde scheduler option --- data/sd-samplers.json | 26 ++++++++------------------ src/config.ts | 24 ++++++++---------------- src/index.ts | 16 +++++++++++++--- 3 files changed, 29 insertions(+), 37 deletions(-) diff --git a/data/sd-samplers.json b/data/sd-samplers.json index b7f413c..4723807 100644 --- a/data/sd-samplers.json +++ b/data/sd-samplers.json @@ -1,31 +1,21 @@ { - "k_dpmpp_2m_ka": "DPM++ 2M Karras", - "k_dpmpp_sde_ka": "DPM++ SDE Karras", - "k_dpmpp_2m_sde_exp": "DPM++ 2M SDE Exponential", - "k_dpmpp_2m_sde_ka": "DPM++ 2M SDE", + "k_dpmpp_2m": "DPM++ 2M", + "k_dpmpp_sde": "DPM++ SDE", + "k_dpmpp_2m_sde": "DPM++ 2M SDE", + "k_dpmpp_2m_sde_heun": "DPM++ 2M SDE Heun", + "k_dpmpp_2s_a": "DPM++ 2S a", + "k_dpmpp_3m_sde": "DPM++ 3M SDE", "k_euler_a": "Euler a", "k_euler": "Euler", "k_lms": "LMS", "k_heun": "Heun", "k_dpm_2": "DPM2", "k_dpm_2_a": "DPM2 a", - "k_dpmpp_2s_a": "DPM++ 2S a", - "k_dpmpp_2m": "DPM++ 2M", - "k_dpmpp_sde": "DPM++ SDE", - "k_dpmpp_2m_sde_heun": "DPM++ 2M SDE Heun", - "k_dpmpp_2m_sde_heun_ka": "DPM++ 2M SDE Heun Karras", - "k_dpmpp_2m_sde_heun_exp": "DPM++ 2M SDE Heun Exponential", - "k_dpmpp_3m_sde": "DPM++ 3M SDE", - "k_dpmpp_3m_sde_ka": "DPM++ 3M SDE Karras", - "k_dpmpp_3m_sde_exp": "DPM++ 3M SDE Exponential", "k_dpm_fast": "DPM fast", "k_dpm_ad": "DPM adaptive", - "k_lms_ka": "LMS Karras", - "k_dpm_2_ka": "DPM2 Karras", - "k_dpm_2_a_ka": "DPM2 a Karras", - "k_dpmpp_2s_a_ka": "DPM++ 2S a Karras", "restart": "Restart", "ddim": "DDIM", "plms": "PLMS", - "unipc": "UniPC" + "unipc": "UniPC", + "k_lcm": "LCM" } diff --git a/src/config.ts b/src/config.ts index 3d6b7b0..24dce9e 100644 --- a/src/config.ts +++ b/src/config.ts @@ -32,7 +32,11 @@ type Orient = keyof typeof orientMap export const models = Object.keys(modelMap) as Model[] export const orients = Object.keys(orientMap) as Orient[] -export const scheduler = ['native', 'karras', 'exponential', 'polyexponential'] as const +export namespace scheduler { + export const nai = ['native', 'karras', 'exponential', 'polyexponential'] as const + export const sd = ['Automatic', 'Uniform', 'Karras', 'Exponential', 'Polyexponential', 'SGM Uniform'] as const + export const horde = ['karras'] as const +} export namespace sampler { export const nai = { @@ -71,20 +75,6 @@ export namespace sampler { dpmsolver: 'DPM solver', lcm: 'LCM', DDIM: 'DDIM', - k_lms_ka: 'LMS Karras', - k_heun_ka: 'Heun Karras', - k_euler_ka: 'Euler Karras', - k_euler_a_ka: 'Euler a Karras', - k_dpm_2_ka: 'DPM2 Karras', - k_dpm_2_a_ka: 'DPM2 a Karras', - k_dpm_fast_ka: 'DPM fast Karras', - k_dpm_adaptive_ka: 'DPM adaptive Karras', - k_dpmpp_2m_ka: 'DPM++ 2M Karras', - k_dpmpp_2s_a_ka: 'DPM++ 2S a Karras', - k_dpmpp_sde_ka: 'DPM++ SDE Karras', - dpmsolver_ka: 'DPM++ solver Karras', - lcm_ka: 'LCM Karras', - DDIM_ka: 'DDIM Karras', } export function createSchema(map: Dict) { @@ -310,11 +300,13 @@ export const Config = Schema.intersect([ upscaler: Schema.union(upscalers).description('默认的放大算法。').default('Lanczos'), restoreFaces: Schema.boolean().description('是否启用人脸修复。').default(false), hiresFix: Schema.boolean().description('是否启用高分辨率修复。').default(false), + scheduler: Schema.union(scheduler.sd).description('默认的调度器。').default('Automatic'), }), Schema.object({ type: Schema.const('stable-horde').required(), sampler: sampler.createSchema(sampler.horde), model: Schema.union(hordeModels).loose().description('默认的生成模型。'), + scheduler: Schema.union(scheduler.horde).description('默认的调度器。').default('karras'), }), Schema.object({ type: Schema.const('naifu').required(), @@ -330,7 +322,7 @@ export const Config = Schema.intersect([ sampler: sampler.createSchema(sampler.nai3), smea: Schema.boolean().description('默认启用 SMEA。'), smeaDyn: Schema.boolean().description('默认启用 SMEA 采样器的 DYN 变体。'), - scheduler: Schema.union(scheduler).description('默认的调度器。').default('native'), + scheduler: Schema.union(scheduler.nai).description('默认的调度器。').default('native'), }), Schema.object({ sampler: sampler.createSchema(sampler.nai) }), ]), diff --git a/src/index.ts b/src/index.ts index e3ea0bb..6c00453 100644 --- a/src/index.ts +++ b/src/index.ts @@ -109,7 +109,16 @@ export function apply(ctx: Context, config: Config) { .option('hiresFix', '-H', { hidden: () => config.type !== 'sd-webui' }) .option('smea', '-S', { hidden: () => config.model !== 'nai-v3' }) .option('smeaDyn', '-d', { hidden: () => config.model !== 'nai-v3' }) - .option('scheduler', '-C ', { hidden: () => config.model !== 'nai-v3', type: scheduler }) + .option('scheduler', '-C ', { + hidden: () => config.type === 'naifu', + type: ['token', 'login'].includes(config.type) + ? scheduler.nai + : config.type === 'sd-webui' + ? scheduler.sd + : config.type === 'stable-horde' + ? scheduler.horde + : [], + }) .option('decrisper', '-D', { hidden: thirdParty }) .option('undesired', '-u ') .option('noTranslator', '-T', { hidden: () => !ctx.translator || !config.translator }) @@ -347,6 +356,7 @@ export function apply(ctx: Context, config: Config) { case 'sd-webui': { return { sampler_index: sampler.sd[options.sampler], + scheduler: options.scheduler, init_images: image && [image.dataUrl], // sd-webui accepts data URLs with base64 encoded image restore_faces: config.restoreFaces ?? false, enable_hr: options.hiresFix ?? config.hiresFix ?? false, @@ -368,14 +378,14 @@ export function apply(ctx: Context, config: Config) { return { prompt: parameters.prompt, params: { - sampler_name: options.sampler.replace('_ka', ''), + sampler_name: options.sampler, cfg_scale: parameters.scale, denoising_strength: parameters.strength, seed: parameters.seed.toString(), height: parameters.height, width: parameters.width, post_processing: [], - karras: options.sampler.includes('_ka'), + karras: options.scheduler?.toLowerCase() === 'karras', hires_fix: options.hiresFix ?? config.hiresFix ?? false, steps: parameters.steps, n: parameters.n_samples,