From cbaa0ad46f0eda2f3d9bcf8a42d6271e6760e578 Mon Sep 17 00:00:00 2001
From: Nick Senger
Date: Wed, 1 Jan 2025 12:34:17 -0800
Subject: [PATCH] UniPC for diffusion sampling (#2684)

* feat: Add unipc multistep scheduler

* chore: Clippy and formatting

* chore: Update comments

* chore: Avoid unsafety in float ordering

* refactor: Update Scheduler::step mutability requirements

* fix: Corrector img2img

* chore: Update unipc ref link to latest diffusers release

* chore: Deduplicate float ordering

* fix: Panic when running with dev profile
---
 .../examples/stable-diffusion/main.rs         |    4 +-
 .../src/models/stable_diffusion/ddim.rs       |    2 +-
 .../euler_ancestral_discrete.rs               |    2 +-
 .../src/models/stable_diffusion/mod.rs        |    1 +
 .../src/models/stable_diffusion/schedulers.rs |    2 +-
 .../src/models/stable_diffusion/uni_pc.rs     | 1005 +++++++++++++++++
 6 files changed, 1011 insertions(+), 5 deletions(-)
 create mode 100644 candle-transformers/src/models/stable_diffusion/uni_pc.rs
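
Note on the API change: UniPC keeps per-step state (buffered model outputs and
the last sample), so `Scheduler::step` now takes `&mut self` and callers need a
mutable binding, e.g. (sketch, mirroring the change to the example below):

    let mut scheduler = sd_config.build_scheduler(n_steps)?;
    let latents = scheduler.step(&noise_pred, timestep, &latents)?;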

diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index b6585afa32..ebf0bfcb25 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -477,7 +477,7 @@ fn run(args: Args) -> Result<()> {
         ),
     };
 
-    let scheduler = sd_config.build_scheduler(n_steps)?;
+    let mut scheduler = sd_config.build_scheduler(n_steps)?;
     let device = candle_examples::device(cpu)?;
     if let Some(seed) = seed {
         device.set_seed(seed)?;
@@ -539,7 +539,7 @@ fn run(args: Args) -> Result<()> {
     };
 
     for idx in 0..num_samples {
-        let timesteps = scheduler.timesteps();
+        let timesteps = scheduler.timesteps().to_vec();
         let latents = match &init_latent_dist {
             Some(init_latent_dist) => {
                 let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?;
diff --git a/candle-transformers/src/models/stable_diffusion/ddim.rs b/candle-transformers/src/models/stable_diffusion/ddim.rs
index d804ed56c7..ae2b40db1e 100644
--- a/candle-transformers/src/models/stable_diffusion/ddim.rs
+++ b/candle-transformers/src/models/stable_diffusion/ddim.rs
@@ -127,7 +127,7 @@ impl DDIMScheduler {
 
 impl Scheduler for DDIMScheduler {
     /// Performs a backward step during inference.
-    fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
+    fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
         let timestep = if timestep >= self.alphas_cumprod.len() {
             timestep - 1
         } else {
diff --git a/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs b/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs
index c27e983a34..250161ccad 100644
--- a/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs
+++ b/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs
@@ -171,7 +171,7 @@ impl Scheduler for EulerAncestralDiscreteScheduler {
     }
 
     /// Performs a backward step during inference.
-    fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
+    fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
         let step_index = self
             .timesteps
             .iter()
diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs
index 6d89f9cd43..4c685209cb 100644
--- a/candle-transformers/src/models/stable_diffusion/mod.rs
+++ b/candle-transformers/src/models/stable_diffusion/mod.rs
@@ -47,6 +47,7 @@ pub mod resnet;
 pub mod schedulers;
 pub mod unet_2d;
 pub mod unet_2d_blocks;
+pub mod uni_pc;
 pub mod utils;
 pub mod vae;
 
diff --git a/candle-transformers/src/models/stable_diffusion/schedulers.rs b/candle-transformers/src/models/stable_diffusion/schedulers.rs
index 1d39037f8f..1ce94ca278 100644
--- a/candle-transformers/src/models/stable_diffusion/schedulers.rs
+++ b/candle-transformers/src/models/stable_diffusion/schedulers.rs
@@ -19,7 +19,7 @@ pub trait Scheduler {
 
     fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor>;
 
-    fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
+    fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
 }
 
 /// This represents how beta ranges from its minimum value to the maximum
diff --git a/candle-transformers/src/models/stable_diffusion/uni_pc.rs b/candle-transformers/src/models/stable_diffusion/uni_pc.rs
new file mode 100644
index 0000000000..c83417f34d
--- /dev/null
+++ b/candle-transformers/src/models/stable_diffusion/uni_pc.rs
@@ -0,0 +1,1005 @@
+//! # UniPC Scheduler
+//!
+//! UniPC is a training-free framework designed for the fast sampling of diffusion models, consisting of a
+//! corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders.
+//!
+//! UniPC is model-agnostic by design, supporting pixel-space and latent-space DPMs for both unconditional and
+//! conditional sampling. It can also be applied to both noise-prediction and data-prediction models. Compared
+//! with prior methods, UniPC converges faster thanks to its increased order of accuracy. Both quantitative and
+//! qualitative results show that UniPC improves sampling quality, especially at very low step counts (5-10).
+//!
+//! For more information, see the original publication:
+//! UniPC: A Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models, W. Zhao et al., 2023.
+//! https://arxiv.org/abs/2302.04867
+//!
+//! This work is based largely on the UniPC implementation from the diffusers python package:
+//! https://raw.githubusercontent.com/huggingface/diffusers/e8aacda762e311505ba05ae340af23b149e37af3/src/diffusers/schedulers/scheduling_unipc_multistep.py
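+//!
+//! ## Example
+//!
+//! A minimal denoising loop (sketch only; `unet`, `text_embeddings`, `n_steps`
+//! and the initial `latents` are placeholders for the caller's model and inputs):
+//!
+//! ```ignore
+//! use candle_transformers::models::stable_diffusion::schedulers::SchedulerConfig;
+//! use candle_transformers::models::stable_diffusion::uni_pc::UniPCSchedulerConfig;
+//!
+//! let mut scheduler = UniPCSchedulerConfig::default().build(n_steps)?;
+//! let mut latents = (latents * scheduler.init_noise_sigma())?;
+//! let timesteps = scheduler.timesteps().to_vec();
+//! for &t in timesteps.iter() {
+//!     let input = scheduler.scale_model_input(latents.clone(), t)?;
+//!     let noise_pred = unet.forward(&input, t as f64, &text_embeddings)?;
+//!     latents = scheduler.step(&noise_pred, t, &latents)?;
+//! }
+//! ```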
+
+use std::collections::HashSet;
+use std::ops::Neg;
+
+use super::schedulers::PredictionType;
+use super::{
+    schedulers::{Scheduler, SchedulerConfig},
+    utils::{interp, linspace},
+};
+use candle::{Error, IndexOp, Result, Tensor};
+
+#[derive(Debug, Clone, Copy)]
+pub enum SigmaSchedule {
+    Karras(KarrasSigmaSchedule),
+    Exponential(ExponentialSigmaSchedule),
+}
+
+impl SigmaSchedule {
+    fn sigma_t(&self, t: f64) -> f64 {
+        match self {
+            Self::Karras(x) => x.sigma_t(t),
+            Self::Exponential(x) => x.sigma_t(t),
+        }
+    }
+}
+
+impl Default for SigmaSchedule {
+    fn default() -> Self {
+        Self::Karras(KarrasSigmaSchedule::default())
+    }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct KarrasSigmaSchedule {
+    pub sigma_min: f64,
+    pub sigma_max: f64,
+    pub rho: f64,
+}
+
+impl KarrasSigmaSchedule {
+    fn sigma_t(&self, t: f64) -> f64 {
+        let (min_inv_rho, max_inv_rho) = (
+            self.sigma_min.powf(1.0 / self.rho),
+            self.sigma_max.powf(1.0 / self.rho),
+        );
+
+        (max_inv_rho + ((1.0 - t) * (min_inv_rho - max_inv_rho))).powf(self.rho)
+    }
+}
+
+impl Default for KarrasSigmaSchedule {
+    fn default() -> Self {
+        Self {
+            sigma_max: 10.0,
+            sigma_min: 0.1,
+            rho: 4.0,
+        }
+    }
+}
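+
+// The Karras schedule (Karras et al., 2022) interpolates in sigma^(1/rho) space:
+//     sigma(t) = (sigma_max^(1/rho) + (1 - t) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
+// so t = 1 maps to sigma_max and t = 0 maps to sigma_min.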
+
+#[derive(Debug, Clone, Copy)]
+pub struct ExponentialSigmaSchedule {
+    sigma_min: f64,
+    sigma_max: f64,
+}
+
+impl ExponentialSigmaSchedule {
+    fn sigma_t(&self, t: f64) -> f64 {
+        (t * (self.sigma_max.ln() - self.sigma_min.ln()) + self.sigma_min.ln()).exp()
+    }
+}
+
+impl Default for ExponentialSigmaSchedule {
+    fn default() -> Self {
+        Self {
+            sigma_max: 80.0,
+            sigma_min: 0.1,
+        }
+    }
+}
+
+#[derive(Debug, Default, Clone, Copy)]
+pub enum SolverType {
+    #[default]
+    Bh1,
+    Bh2,
+}
+
+#[derive(Debug, Default, Clone, Copy)]
+pub enum AlgorithmType {
+    #[default]
+    DpmSolverPlusPlus,
+    SdeDpmSolverPlusPlus,
+}
+
+#[derive(Debug, Default, Clone, Copy)]
+pub enum FinalSigmasType {
+    #[default]
+    Zero,
+    SigmaMin,
+}
+
+#[derive(Debug, Clone)]
+pub enum TimestepSchedule {
+    /// Timesteps will be determined by interpolation of sigmas
+    FromSigmas,
+    /// Timesteps will be separated by regular intervals
+    Linspace,
+}
+
+impl TimestepSchedule {
+    fn timesteps(
+        &self,
+        sigma_schedule: &SigmaSchedule,
+        num_inference_steps: usize,
+        num_training_steps: usize,
+    ) -> Result<Vec<usize>> {
+        match self {
+            Self::FromSigmas => {
+                let sigmas: Tensor = linspace(1., 0., num_inference_steps)?
+                    .to_vec1()?
+                    .into_iter()
+                    .map(|t| sigma_schedule.sigma_t(t))
+                    .collect::<Vec<_>>()
+                    .try_into()?;
+                let log_sigmas = sigmas.log()?.to_vec1::<f64>()?;
+                let timesteps = interp(
+                    &log_sigmas.iter().copied().rev().collect::<Vec<_>>(),
+                    &linspace(
+                        log_sigmas[log_sigmas.len() - 1] - 0.001,
+                        log_sigmas[0] + 0.001,
+                        num_inference_steps,
+                    )?
+                    .to_vec1::<f64>()?,
+                    &linspace(0., num_training_steps as f64, num_inference_steps)?
+                        .to_vec1::<f64>()?,
+                )
+                .into_iter()
+                .map(|f| (num_training_steps - 1) - (f as usize))
+                .collect::<Vec<_>>();
+
+                Ok(timesteps)
+            }
+
+            Self::Linspace => {
+                Ok(
+                    linspace((num_training_steps - 1) as f64, 0., num_inference_steps)?
+                        .to_vec1::<f64>()?
+                        .into_iter()
+                        .map(|f| f as usize)
+                        .collect(),
+                )
+            }
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub enum CorrectorConfiguration {
+    Disabled,
+    Enabled { skip_steps: HashSet<usize> },
+}
+
+impl Default for CorrectorConfiguration {
+    fn default() -> Self {
+        Self::Enabled {
+            skip_steps: [0, 1, 2].into_iter().collect(),
+        }
+    }
+}
+
+impl CorrectorConfiguration {
+    pub fn new(disabled_steps: impl IntoIterator<Item = usize>) -> Self {
+        Self::Enabled {
+            skip_steps: disabled_steps.into_iter().collect(),
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct UniPCSchedulerConfig {
+    /// Configure the UniC corrector. Enabled by default, skipping the first three steps (0, 1 and 2)
+    pub corrector: CorrectorConfiguration,
+    /// Determines how sigma relates to a given timestep
+    pub sigma_schedule: SigmaSchedule,
+    /// Determines how the inference timesteps are spaced
+    pub timestep_schedule: TimestepSchedule,
+    /// The solver order which can be `1` or higher. It is recommended to use `solver_order=2` for guided
+    /// sampling, and `solver_order=3` for unconditional sampling.
+    pub solver_order: usize,
+    /// Prediction type of the scheduler function
+    pub prediction_type: PredictionType,
+    pub num_training_timesteps: usize,
+    /// Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+    /// as Stable Diffusion.
+    pub thresholding: bool,
+    /// The ratio for the dynamic thresholding method. Valid only when `thresholding = true`.
+    pub dynamic_thresholding_ratio: f64,
+    /// The threshold value for dynamic thresholding.
+    pub sample_max_value: f64,
+    pub solver_type: SolverType,
+    /// Whether to use lower-order solvers in the final steps.
+    pub lower_order_final: bool,
+}
+
+impl Default for UniPCSchedulerConfig {
+    fn default() -> Self {
+        Self {
+            corrector: Default::default(),
+            timestep_schedule: TimestepSchedule::FromSigmas,
+            sigma_schedule: SigmaSchedule::Karras(Default::default()),
+            prediction_type: PredictionType::Epsilon,
+            num_training_timesteps: 1000,
+            solver_order: 2,
+            thresholding: false,
+            dynamic_thresholding_ratio: 0.995,
+            sample_max_value: 1.0,
+            solver_type: SolverType::Bh1,
+            lower_order_final: true,
+        }
+    }
+}
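+
+// Example (sketch): overriding a few fields for guided sampling while keeping
+// the remaining defaults; the concrete values here are illustrative only.
+//
+// let config = UniPCSchedulerConfig {
+//     solver_order: 2,
+//     corrector: CorrectorConfiguration::new([0, 1, 2]),
+//     ..Default::default()
+// };
+// let mut scheduler = config.build(30)?;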
+
+impl SchedulerConfig for UniPCSchedulerConfig {
+    fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
+        Ok(Box::new(EdmDpmMultistepScheduler::new(
+            self.clone(),
+            inference_steps,
+        )?))
+    }
+}
+
+struct State {
+    model_outputs: Vec<Option<Tensor>>,
+    lower_order_nums: usize,
+    order: usize,
+    last_sample: Option<Tensor>,
+}
+
+impl State {
+    fn new(solver_order: usize) -> Self {
+        Self {
+            model_outputs: vec![None; solver_order],
+            lower_order_nums: 0,
+            order: 0,
+            last_sample: None,
+        }
+    }
+
+    fn lower_order_nums(&self) -> usize {
+        self.lower_order_nums
+    }
+
+    fn update_lower_order_nums(&mut self, n: usize) {
+        self.lower_order_nums = n;
+    }
+
+    fn model_outputs(&self) -> &[Option<Tensor>] {
+        self.model_outputs.as_slice()
+    }
+
+    fn update_model_output(&mut self, idx: usize, output: Option<Tensor>) {
+        self.model_outputs[idx] = output;
+    }
+
+    fn last_sample(&self) -> Option<&Tensor> {
+        self.last_sample.as_ref()
+    }
+
+    fn update_last_sample(&mut self, sample: Tensor) {
+        let _ = self.last_sample.replace(sample);
+    }
+
+    fn order(&self) -> usize {
+        self.order
+    }
+
+    fn update_order(&mut self, order: usize) {
+        self.order = order;
+    }
+}
+
+pub struct EdmDpmMultistepScheduler {
+    schedule: Schedule,
+    config: UniPCSchedulerConfig,
+    state: State,
+}
+
+impl EdmDpmMultistepScheduler {
+    pub fn new(config: UniPCSchedulerConfig, num_inference_steps: usize) -> Result<Self> {
+        let schedule = Schedule::new(
+            config.timestep_schedule.clone(),
+            config.sigma_schedule,
+            num_inference_steps,
+            config.num_training_timesteps,
+        )?;
+
+        Ok(Self {
+            schedule,
+            state: State::new(config.solver_order),
+            config,
+        })
+    }
+
+    fn step_index(&self, timestep: usize) -> usize {
+        let index_candidates = self
+            .schedule
+            .timesteps()
+            .iter()
+            .enumerate()
+            .filter(|(_, t)| (*t == &timestep))
+            .map(|(i, _)| i)
+            .collect::<Vec<_>>();
+
+        match index_candidates.len() {
+            0 => 0,
+            1 => index_candidates[0],
+            _ => index_candidates[1],
+        }
+    }
+
+    fn timestep(&self, step_idx: usize) -> usize {
+        self.schedule
+            .timesteps()
+            .get(step_idx)
+            .copied()
+            .unwrap_or(0)
+    }
+
+    fn convert_model_output(
+        &self,
+        model_output: &Tensor,
+        sample: &Tensor,
+        timestep: usize,
+    ) -> Result<Tensor> {
+        let (alpha_t, sigma_t) = (
+            self.schedule.alpha_t(timestep),
+            self.schedule.sigma_t(timestep),
+        );
+
+        let x0_pred = match self.config.prediction_type {
+            PredictionType::Epsilon => ((sample - (model_output * sigma_t))? / alpha_t)?,
+            PredictionType::Sample => model_output.clone(),
+            PredictionType::VPrediction => ((alpha_t * sample)? - (sigma_t * model_output)?)?,
+        };
+
+        if self.config.thresholding {
+            self.threshold_sample(x0_pred)
+        } else {
+            Ok(x0_pred)
+        }
+    }
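+
+    // Dynamic thresholding (Saharia et al., 2022): clamp the predicted x0 to
+    // +/- s, where s is the `dynamic_thresholding_ratio` quantile of |x0|
+    // (bounded below by `sample_max_value`), then rescale. The quantile is
+    // estimated with the streaming P^2 estimator from the `stats` module below.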
+    fn threshold_sample(&self, sample: Tensor) -> Result<Tensor> {
+        let shape = sample.shape().clone().into_dims();
+        let v = sample
+            .abs()?
+            .reshape((shape[0], shape[1] * shape[2..].iter().product::<usize>()))?
+            .to_dtype(candle::DType::F64)?
+            .to_vec2::<f64>()?;
+        let q = stats::Quantile::new(self.config.dynamic_thresholding_ratio)
+            .with_samples(v.into_iter().flatten());
+        let (threshold, max) = (q.quantile().max(self.config.sample_max_value), q.max());
+
+        sample.clamp(-threshold, threshold)? / (threshold / max).sqrt().min(1.)
+    }
+
+    fn multistep_uni_p_bh_update(&self, sample: &Tensor, timestep: usize) -> Result<Tensor> {
+        let step_index = self.step_index(timestep);
+        let ns = &self.schedule;
+        let model_outputs = self.state.model_outputs();
+        let Some(m0) = &model_outputs[model_outputs.len() - 1] else {
+            return Err(Error::Msg(
+                "Expected model output for predictor update".to_string(),
+            ));
+        };
+
+        let (t0, tt) = (timestep, self.timestep(self.step_index(timestep) + 1));
+        let (sigma_t, sigma_s0) = (ns.sigma_t(tt), ns.sigma_t(t0));
+        let (alpha_t, _alpha_s0) = (ns.alpha_t(tt), ns.alpha_t(t0));
+        let (lambda_t, lambda_s0) = (ns.lambda_t(tt), ns.lambda_t(t0));
+
+        let h = lambda_t - lambda_s0;
+        let device = sample.device();
+
+        let (mut rks, mut d1s) = (vec![], vec![]);
+        for i in 1..self.state.order() {
+            let ti = self.timestep(step_index.saturating_sub(i + 1));
+            let Some(mi) = model_outputs
+                .get(model_outputs.len().saturating_sub(i + 1))
+                .into_iter()
+                .flatten()
+                .next()
+            else {
+                return Err(Error::Msg(
+                    "Expected model output for predictor update".to_string(),
+                ));
+            };
+            let (alpha_si, sigma_si) = (ns.alpha_t(ti), ns.sigma_t(ti));
+            let lambda_si = alpha_si.ln() - sigma_si.ln();
+            let rk = (lambda_si - lambda_s0) / h;
+            rks.push(rk);
+            d1s.push(((mi - m0)? / rk)?);
+        }
+        rks.push(1.0);
+        let rks = Tensor::new(rks, device)?;
+        let (mut r, mut b) = (vec![], vec![]);
+
+        let hh = h.neg();
+        let h_phi_1 = hh.exp_m1();
+        let mut h_phi_k = h_phi_1 / hh - 1.;
+        let mut factorial_i = 1.;
+
+        let b_h = match self.config.solver_type {
+            SolverType::Bh1 => hh,
+            SolverType::Bh2 => hh.exp_m1(),
+        };
+
+        for i in 1..self.state.order() + 1 {
+            r.push(rks.powf(i as f64 - 1.)?);
+            b.push(h_phi_k * factorial_i / b_h);
+            factorial_i *= i as f64 + 1.;
+            h_phi_k = h_phi_k / hh - 1. / factorial_i;
+        }
+
+        let (r, b) = (Tensor::stack(&r, 0)?, Tensor::new(b, device)?);
+        let (d1s, rhos_p) = match d1s.len() {
+            0 => (None, None),
+            _ => {
+                let rhos_p = match self.state.order() {
+                    2 => Tensor::new(&[0.5f64], m0.device())?.to_dtype(m0.dtype())?,
+                    _ => {
+                        let ((r1, r2), b1) = (r.dims2()?, b.dims1()?);
+                        let inverse = linalg::inverse(&r.i((..(r1 - 1), ..(r2 - 1)))?)?;
+                        let b = b.i(..(b1 - 1))?;
+                        b.broadcast_mul(&inverse)?.sum(1)?.to_dtype(m0.dtype())?
+                    }
+                };
+
+                (Some(Tensor::stack(&d1s, 1)?), Some(rhos_p))
+            }
+        };
+
+        let x_t_ = ((sigma_t / sigma_s0 * sample)? - (alpha_t * h_phi_1 * m0)?)?;
+        if let (Some(d1s), Some(rhos_p)) = (d1s, rhos_p) {
+            use linalg::{Permutation, TensordotFixedPosition, TensordotGeneral};
+            let output_shape = m0.shape().clone();
+            let pred_res = TensordotGeneral {
+                lhs_permutation: Permutation { dims: vec![0] },
+                rhs_permutation: Permutation {
+                    dims: vec![1, 0, 2, 3, 4],
+                },
+                tensordot_fixed_position: TensordotFixedPosition {
+                    len_uncontracted_lhs: 1,
+                    len_uncontracted_rhs: output_shape.dims().iter().product::<usize>(),
+                    len_contracted_axes: d1s.dim(1)?,
+                    output_shape,
+                },
+                output_permutation: Permutation {
+                    dims: vec![0, 1, 2, 3],
+                },
+            }
+            .eval(&rhos_p, &d1s)?;
+            x_t_ - (alpha_t * b_h * pred_res)?
+        } else {
+            Ok(x_t_)
+        }
+    }
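+
+    // UniC (corrector): refines the sample produced by the previous UniP step,
+    // additionally using the freshly converted model output at the current
+    // timestep. Like the predictor above, it solves R * rhos = b for the
+    // combination weights (hard-coding 0.5 at the lowest order) and mixes in
+    // divided differences D1_i = (m_i - m0) / r_i of the buffered outputs
+    // (cf. Zhao et al., 2023).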
+    fn multistep_uni_c_bh_update(
+        &self,
+        model_output: &Tensor,
+        model_outputs: &[Option<Tensor>],
+        last_sample: &Tensor,
+        sample: &Tensor,
+        timestep: usize,
+    ) -> Result<Tensor> {
+        let step_index = self.step_index(timestep);
+        let Some(m0) = model_outputs.last().into_iter().flatten().next() else {
+            return Err(Error::Msg(
+                "Expected model output for corrector update".to_string(),
+            ));
+        };
+        let model_t = model_output;
+        let (x, _xt) = (last_sample, sample);
+
+        let (t0, tt, ns) = (
+            self.timestep(self.step_index(timestep) - 1),
+            timestep,
+            &self.schedule,
+        );
+        let (sigma_t, sigma_s0) = (ns.sigma_t(tt), ns.sigma_t(t0));
+        let (alpha_t, _alpha_s0) = (ns.alpha_t(tt), ns.alpha_t(t0));
+        let (lambda_t, lambda_s0) = (ns.lambda_t(tt), ns.lambda_t(t0));
+
+        let h = lambda_t - lambda_s0;
+        let device = sample.device();
+
+        let (mut rks, mut d1s) = (vec![], vec![]);
+        for i in 1..self.state.order() {
+            let ti = self.timestep(step_index.saturating_sub(i + 1));
+            let Some(mi) = model_outputs
+                .get(model_outputs.len().saturating_sub(i + 1))
+                .into_iter()
+                .flatten()
+                .next()
+            else {
+                return Err(Error::Msg(
+                    "Expected model output for corrector update".to_string(),
+                ));
+            };
+            let (alpha_si, sigma_si) = (ns.alpha_t(ti), ns.sigma_t(ti));
+            let lambda_si = alpha_si.ln() - sigma_si.ln();
+            let rk = (lambda_si - lambda_s0) / h;
+            rks.push(rk);
+            d1s.push(((mi - m0)? / rk)?);
+        }
+        rks.push(1.0);
+        let rks = Tensor::new(rks, device)?;
+        let (mut r, mut b) = (vec![], vec![]);
+
+        let hh = h.neg();
+        let h_phi_1 = hh.exp_m1();
+        let mut h_phi_k = h_phi_1 / hh - 1.;
+        let mut factorial_i = 1.;
+
+        let b_h = match self.config.solver_type {
+            SolverType::Bh1 => hh,
+            SolverType::Bh2 => hh.exp_m1(),
+        };
+
+        for i in 1..self.state.order() + 1 {
+            r.push(rks.powf(i as f64 - 1.)?);
+            b.push(h_phi_k * factorial_i / b_h);
+            factorial_i *= i as f64 + 1.;
+            h_phi_k = h_phi_k / hh - 1. / factorial_i;
+        }
+
+        let (r, b) = (Tensor::stack(&r, 0)?, Tensor::new(b, device)?);
+        let d1s = match d1s.len() {
+            0 => None,
+            _ => Some(Tensor::stack(&d1s, 1)?),
+        };
+        let rhos_c = match self.state.order() {
+            1 => Tensor::new(&[0.5f64], m0.device())?.to_dtype(m0.dtype())?,
+            _ => {
+                let inverse = linalg::inverse(&r)?;
+                b.broadcast_mul(&inverse)?.sum(1)?.to_dtype(m0.dtype())?
+            }
+        };
+
+        let x_t_ = ((sigma_t / sigma_s0 * x)? - (alpha_t * h_phi_1 * m0)?)?;
+        let corr_res = d1s
+            .map(|d1s| {
+                use linalg::{Permutation, TensordotFixedPosition, TensordotGeneral};
+                let output_shape = x_t_.shape().clone();
+                TensordotGeneral {
+                    lhs_permutation: Permutation { dims: vec![0] },
+                    rhs_permutation: Permutation {
+                        dims: vec![1, 0, 2, 3, 4],
+                    },
+                    tensordot_fixed_position: TensordotFixedPosition {
+                        len_uncontracted_lhs: 1,
+                        len_uncontracted_rhs: output_shape.dims().iter().product::<usize>(),
+                        len_contracted_axes: d1s.dim(1)?,
+                        output_shape,
+                    },
+                    output_permutation: Permutation {
+                        dims: vec![0, 1, 2, 3],
+                    },
+                }
+                .eval(&rhos_c.i(..rhos_c.dims()[0] - 1)?, &d1s)
+            })
+            .unwrap_or_else(|| Tensor::zeros_like(m0))?;
+
+        let d1_t = (model_t - m0)?;
+        let x_t = (x_t_
+            - (alpha_t
+                * b_h
+                * (corr_res + rhos_c.i(rhos_c.dims()[0] - 1)?.broadcast_mul(&d1_t)?)?)?)?;
+
+        Ok(x_t)
+    }
+}
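+
+// One `step` call performs: an optional UniC correction of the incoming sample
+// (using the buffered outputs and the previous sample), a rotation of the
+// model-output buffer, and a UniP prediction towards the next timestep. This
+// state mutation is why `Scheduler::step` takes `&mut self`.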
+impl Scheduler for EdmDpmMultistepScheduler {
+    fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
+        let step_index = self.step_index(timestep);
+        let model_output_converted = &self.convert_model_output(model_output, sample, timestep)?;
+        let sample = match (&self.config.corrector, self.state.last_sample()) {
+            (CorrectorConfiguration::Enabled { skip_steps: s }, Some(last_sample))
+                if !s.contains(&step_index) && step_index > 0 =>
+            {
+                &self.multistep_uni_c_bh_update(
+                    model_output_converted,
+                    self.state.model_outputs(),
+                    last_sample,
+                    sample,
+                    timestep,
+                )?
+            }
+            (CorrectorConfiguration::Enabled { .. }, _) | (CorrectorConfiguration::Disabled, _) => {
+                sample
+            }
+        };
+
+        let mut model_outputs = self.state.model_outputs().to_vec();
+        for i in 0..self.config.solver_order.saturating_sub(1) {
+            self.state
+                .update_model_output(i, model_outputs[i + 1].take());
+        }
+        self.state.update_model_output(
+            model_outputs.len() - 1,
+            Some(model_output_converted.clone()),
+        );
+
+        let mut this_order = self.config.solver_order;
+        if self.config.lower_order_final {
+            this_order = self
+                .config
+                .solver_order
+                .min(self.schedule.timesteps.len() - step_index);
+        }
+        self.state
+            .update_order(this_order.min(self.state.lower_order_nums() + 1));
+
+        self.state.update_last_sample(sample.clone());
+        let prev_sample = self.multistep_uni_p_bh_update(sample, timestep)?;
+
+        let lower_order_nums = self.state.lower_order_nums();
+        if lower_order_nums < self.config.solver_order {
+            self.state.update_lower_order_nums(lower_order_nums + 1);
+        }
+
+        Ok(prev_sample)
+    }
+
+    fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
+        Ok(sample)
+    }
+
+    fn timesteps(&self) -> &[usize] {
+        &self.schedule.timesteps
+    }
+
+    fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
+        let (alpha_t, sigma_t) = (
+            self.schedule.alpha_t(timestep),
+            self.schedule.sigma_t(timestep),
+        );
+
+        (alpha_t * original)? + (sigma_t * noise)?
+    }
+
+    fn init_noise_sigma(&self) -> f64 {
+        self.schedule.sigma_t(self.schedule.num_training_steps())
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Schedule {
+    timesteps: Vec<usize>,
+    num_training_steps: usize,
+    sigma_schedule: SigmaSchedule,
+    #[allow(unused)]
+    timestep_schedule: TimestepSchedule,
+}
+
+impl Schedule {
+    fn new(
+        timestep_schedule: TimestepSchedule,
+        sigma_schedule: SigmaSchedule,
+        num_inference_steps: usize,
+        num_training_steps: usize,
+    ) -> Result<Self> {
+        Ok(Self {
+            timesteps: timestep_schedule.timesteps(
+                &sigma_schedule,
+                num_inference_steps,
+                num_training_steps,
+            )?,
+            timestep_schedule,
+            sigma_schedule,
+            num_training_steps,
+        })
+    }
+
+    fn timesteps(&self) -> &[usize] {
+        &self.timesteps
+    }
+
+    fn num_training_steps(&self) -> usize {
+        self.num_training_steps
+    }
+
+    fn t(&self, step: usize) -> f64 {
+        (step as f64 + 1.) / self.num_training_steps as f64
+    }
+
+    fn alpha_t(&self, t: usize) -> f64 {
+        (1. / (self.sigma_schedule.sigma_t(self.t(t)).powi(2) + 1.)).sqrt()
+    }
+
+    fn sigma_t(&self, t: usize) -> f64 {
+        self.sigma_schedule.sigma_t(self.t(t)) * self.alpha_t(t)
+    }
+
+    fn lambda_t(&self, t: usize) -> f64 {
+        self.alpha_t(t).ln() - self.sigma_t(t).ln()
+    }
+}
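+
+// Coefficient conventions: with sigma(t) given by the sigma schedule,
+//     alpha_t  = 1 / sqrt(sigma(t)^2 + 1)
+//     sigma_t  = sigma(t) * alpha_t
+//     lambda_t = ln(alpha_t) - ln(sigma_t)
+// lambda_t is the half log-SNR whose differences (h) drive the multistep updates.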
+
+mod stats {
+    //! This is a slightly modified form of the P² quantile implementation from https://github.com/vks/average.
+    //! Also see: http://www.cs.wustl.edu/~jain/papers/ftp/psqr.pdf
+    use num_traits::{Float, ToPrimitive};
+
+    #[derive(Debug, Clone)]
+    pub struct Quantile {
+        q: [f64; 5],
+        n: [i64; 5],
+        m: [f64; 5],
+        dm: [f64; 5],
+        max: Option<f64>,
+    }
+
+    impl Quantile {
+        pub fn new(p: f64) -> Quantile {
+            assert!((0. ..=1.).contains(&p));
+            Quantile {
+                q: [0.; 5],
+                n: [1, 2, 3, 4, 0],
+                m: [1., 1. + 2. * p, 1. + 4. * p, 3. + 2. * p, 5.],
+                dm: [0., p / 2., p, (1. + p) / 2., 1.],
+                max: None,
+            }
+        }
+
+        pub fn max(&self) -> f64 {
+            self.max.unwrap_or(f64::NAN)
+        }
+
+        fn p(&self) -> f64 {
+            self.dm[2]
+        }
+
+        fn parabolic(&self, i: usize, d: f64) -> f64 {
+            let s = d.round() as i64;
+            self.q[i]
+                + d / (self.n[i + 1] - self.n[i - 1]).to_f64().unwrap()
+                    * ((self.n[i] - self.n[i - 1] + s).to_f64().unwrap()
+                        * (self.q[i + 1] - self.q[i])
+                        / (self.n[i + 1] - self.n[i]).to_f64().unwrap()
+                        + (self.n[i + 1] - self.n[i] - s).to_f64().unwrap()
+                            * (self.q[i] - self.q[i - 1])
+                            / (self.n[i] - self.n[i - 1]).to_f64().unwrap())
+        }
+
+        fn linear(&self, i: usize, d: f64) -> f64 {
+            let sum = if d < 0. { i - 1 } else { i + 1 };
+            self.q[i] + d * (self.q[sum] - self.q[i]) / (self.n[sum] - self.n[i]).to_f64().unwrap()
+        }
+
+        pub fn quantile(&self) -> f64 {
+            if self.len() >= 5 {
+                return self.q[2];
+            }
+
+            if self.is_empty() {
+                return f64::NAN;
+            }
+            let mut heights: [f64; 4] = [self.q[0], self.q[1], self.q[2], self.q[3]];
+            let len = self.len() as usize;
+            debug_assert!(len < 5);
+            sort_floats(&mut heights[..len]);
+            let desired_index = (len as f64) * self.p() - 1.;
+            let mut index = desired_index.ceil();
+            if desired_index == index && index >= 0. {
+                let index = index.round() as usize;
+                debug_assert!(index < 5);
+                if index < len - 1 {
+                    return 0.5 * self.q[index] + 0.5 * self.q[index + 1];
+                }
+            }
+            index = index.max(0.);
+            let mut index = index.round() as usize;
+            debug_assert!(index < 5);
+            index = index.min(len - 1);
+            self.q[index]
+        }
+
+        fn len(&self) -> u64 {
+            self.n[4] as u64
+        }
+
+        fn is_empty(&self) -> bool {
+            self.len() == 0
+        }
+
+        pub fn add(&mut self, x: f64) {
+            self.max = self.max.map(|y| y.max(x)).or(Some(x));
+
+            if self.n[4] < 5 {
+                self.q[self.n[4] as usize] = x;
+                self.n[4] += 1;
+                if self.n[4] == 5 {
+                    sort_floats(&mut self.q);
+                }
+                return;
+            }
+
+            let mut k: usize;
+            if x < self.q[0] {
+                self.q[0] = x;
+                k = 0;
+            } else {
+                k = 4;
+                for i in 1..5 {
+                    if x < self.q[i] {
+                        k = i;
+                        break;
+                    }
+                }
+                if self.q[4] < x {
+                    self.q[4] = x;
+                }
+            };
+
+            for i in k..5 {
+                self.n[i] += 1;
+            }
+            for i in 0..5 {
+                self.m[i] += self.dm[i];
+            }
+
+            for i in 1..4 {
+                let d = self.m[i] - self.n[i].to_f64().unwrap();
+                if d >= 1. && self.n[i + 1] - self.n[i] > 1
+                    || d <= -1. && self.n[i - 1] - self.n[i] < -1
+                {
+                    let d = Float::signum(d);
+                    let q_new = self.parabolic(i, d);
+                    if self.q[i - 1] < q_new && q_new < self.q[i + 1] {
+                        self.q[i] = q_new;
+                    } else {
+                        self.q[i] = self.linear(i, d);
+                    }
+                    let delta = d.round() as i64;
+                    debug_assert_eq!(delta.abs(), 1);
+                    self.n[i] += delta;
+                }
+            }
+        }
+
+        pub fn with_samples(mut self, samples: impl IntoIterator<Item = f64>) -> Self {
+            for sample in samples {
+                self.add(sample);
+            }
+
+            self
+        }
+    }
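+
+    // Example (sketch): a streaming 99.5th-percentile estimate, as used by
+    // `threshold_sample`:
+    //
+    // let q = Quantile::new(0.995).with_samples([0.1f64, 3.0, 0.7, 2.5, 0.2]);
+    // let (threshold, max) = (q.quantile(), q.max());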
+
+    fn sort_floats(v: &mut [f64]) {
+        v.sort_unstable_by(|a, b| a.total_cmp(b));
+    }
+}
+
+mod linalg {
+    use candle::{IndexOp, Result, Shape, Tensor};
+
+    pub fn inverse(m: &Tensor) -> Result<Tensor> {
+        adjoint(m)? / determinant(m)?.to_scalar::<f64>()?
+    }
+
+    pub fn adjoint(m: &Tensor) -> Result<Tensor> {
+        cofactor(m)?.transpose(0, 1)
+    }
+
+    pub fn cofactor(m: &Tensor) -> Result<Tensor> {
+        let s = m.shape().dim(0)?;
+        if s == 2 {
+            let mut v = vec![];
+            for i in 0..2 {
+                let mut x = vec![];
+                for j in 0..2 {
+                    x.push((m.i((i, j))? * (-1.0f64).powi(i as i32 + j as i32))?)
+                }
+                v.push(Tensor::stack(&x, 0)?.unsqueeze(0)?);
+            }
+            return Tensor::stack(&v, 1)?.squeeze(0);
+        }
+
+        let minors = minors(m)?;
+        let mut v = vec![];
+        for i in 0..s {
+            let mut x = vec![];
+            for j in 0..s {
+                let det = (determinant(&minors.i((i, j))?)?
+                    * ((-1.0f64).powi(i as i32) * (-1.0f64).powi(j as i32)))?;
+                x.push(det);
+            }
+            v.push(Tensor::stack(&x, 0)?.unsqueeze(0)?);
+        }
+
+        Tensor::stack(&v, 1)?.squeeze(0)
+    }
+
+    pub fn determinant(m: &Tensor) -> Result<Tensor> {
+        let s = m.shape().dim(0)?;
+        if s == 2 {
+            return (m.i((0, 0))? * m.i((1, 1))?)? - (m.i((0, 1))? * m.i((1, 0))?);
+        }
+
+        let cofactor = cofactor(m)?;
+        let m0 = m.i((0, 0))?;
+        let det = (0..s)
+            .map(|i| (m.i((0, i))? * cofactor.i((0, i))?))
+            .try_fold(m0.zeros_like()?, |acc, cur| (acc + cur?))?;
+
+        Ok(det)
+    }
+
+    pub fn minors(m: &Tensor) -> Result<Tensor> {
+        let s = m.shape().dim(0)?;
+        if s == 1 {
+            return m.i((0, 0));
+        }
+
+        let mut v = vec![];
+        for i in 0..s {
+            let msub = Tensor::cat(&[m.i((..i, ..))?, m.i(((i + 1).., ..))?], 0)?;
+            let mut x = vec![];
+            for j in 0..s {
+                let t = Tensor::cat(&[msub.i((.., ..j))?, msub.i((.., (j + 1)..))?], 1)?;
+                x.push(t);
+            }
+            v.push(Tensor::stack(&x, 0)?.unsqueeze(0)?);
+        }
+
+        Tensor::stack(&v, 1)?.squeeze(0)
+    }
+
+    #[derive(Debug)]
+    pub struct TensordotGeneral {
+        pub lhs_permutation: Permutation,
+        pub rhs_permutation: Permutation,
+        pub tensordot_fixed_position: TensordotFixedPosition,
+        pub output_permutation: Permutation,
+    }
+
+    impl TensordotGeneral {
+        pub fn eval(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor> {
+            let permuted_lhs = self.lhs_permutation.eval(lhs)?;
+            let permuted_rhs = self.rhs_permutation.eval(rhs)?;
+            let tensordotted = self
+                .tensordot_fixed_position
+                .eval(&permuted_lhs, &permuted_rhs)?;
+            self.output_permutation.eval(&tensordotted)
+        }
+    }
+
+    #[derive(Debug)]
+    pub struct TensordotFixedPosition {
+        pub len_uncontracted_lhs: usize,
+        pub len_uncontracted_rhs: usize,
+        pub len_contracted_axes: usize,
+        pub output_shape: Shape,
+    }
+
+    impl TensordotFixedPosition {
+        fn eval(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor> {
+            let lhs_view = lhs.reshape((self.len_uncontracted_lhs, self.len_contracted_axes))?;
+            let rhs_view = rhs.reshape((self.len_contracted_axes, self.len_uncontracted_rhs))?;
+
+            lhs_view.matmul(&rhs_view)?.reshape(&self.output_shape)
+        }
+    }
+
+    #[derive(Debug)]
+    pub struct Permutation {
+        pub dims: Vec<usize>,
+    }
+
+    impl Permutation {
+        fn eval(&self, tensor: &Tensor) -> Result<Tensor> {
+            tensor.permute(self.dims.as_slice())
+        }
+    }
+}
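+
+// Sketch of a sanity check for the naive `linalg` helpers above (not part of
+// the original patch; `determinant` expands along the first row for s > 2):
+//
+// #[cfg(test)]
+// mod tests {
+//     use candle::{Device, Tensor};
+//
+//     #[test]
+//     fn det_2x2() -> candle::Result<()> {
+//         let m = Tensor::new(&[[4.0f64, 7.0], [2.0, 6.0]], &Device::Cpu)?;
+//         let det = super::linalg::determinant(&m)?.to_scalar::<f64>()?;
+//         assert_eq!(det, 10.0); // 4 * 6 - 7 * 2
+//         Ok(())
+//     }
+// }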