Skip to content

Commit

Permalink
UniPC for diffusion sampling (#2684)
Browse files Browse the repository at this point in the history
* feat: Add unipc multistep scheduler

* chore: Clippy and formatting

* chore: Update comments

* chore: Avoid unsafety in float ordering

* refactor: Update Scheduler::step mutability requirements

* fix: Corrector img2img

* chore: Update unipc ref link to latest diffusers release

* chore: Deduplicate float ordering

* fix: Panic when running with dev profile
  • Loading branch information
nicksenger authored Jan 1, 2025
1 parent b12c7c2 commit cbaa0ad
Show file tree
Hide file tree
Showing 6 changed files with 1,011 additions and 5 deletions.
4 changes: 2 additions & 2 deletions candle-examples/examples/stable-diffusion/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ fn run(args: Args) -> Result<()> {
),
};

let scheduler = sd_config.build_scheduler(n_steps)?;
let mut scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
if let Some(seed) = seed {
device.set_seed(seed)?;
Expand Down Expand Up @@ -539,7 +539,7 @@ fn run(args: Args) -> Result<()> {
};

for idx in 0..num_samples {
let timesteps = scheduler.timesteps();
let timesteps = scheduler.timesteps().to_vec();
let latents = match &init_latent_dist {
Some(init_latent_dist) => {
let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?;
Expand Down
2 changes: 1 addition & 1 deletion candle-transformers/src/models/stable_diffusion/ddim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ impl DDIMScheduler {

impl Scheduler for DDIMScheduler {
/// Performs a backward step during inference.
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ impl Scheduler for EulerAncestralDiscreteScheduler {
}

/// Performs a backward step during inference.
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
let step_index = self
.timesteps
.iter()
Expand Down
1 change: 1 addition & 0 deletions candle-transformers/src/models/stable_diffusion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pub mod resnet;
pub mod schedulers;
pub mod unet_2d;
pub mod unet_2d_blocks;
pub mod uni_pc;
pub mod utils;
pub mod vae;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub trait Scheduler {

fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor>;

fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
}

/// This represents how beta ranges from its minimum value to the maximum
Expand Down
Loading

0 comments on commit cbaa0ad

Please sign in to comment.