-
Notifications
You must be signed in to change notification settings - Fork 964
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement euler a scheduler #187
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
// For licensing see accompanying LICENSE.md file. | ||
// Copyright (C) 2022 Apple Inc. and The HuggingFace Team. All Rights Reserved. | ||
|
||
import Accelerate | ||
import CoreML | ||
|
||
/// A Scheduler used to compute a de-noised image
///
/// This implementation matches:
/// [Hugging Face Diffusers EulerAncestralDiscreteScheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py)
///
/// It is based on the [original k-diffusion implementation by Katherine Crowson](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72)
/// Limitations:
///  - Only implemented for Euler A algorithm (not Euler)
///  - Assumes model predicts epsilon
@available(iOS 16.2, macOS 13.1, *)
public final class EulerAncestralDiscreteScheduler: Scheduler {
    /// Number of diffusion steps the model was trained with
    public let trainStepCount: Int
    /// Number of inference steps to be performed
    public let inferenceStepCount: Int
    /// Training diffusion time steps index by inference time step
    public let betas: [Float]
    /// Discrete timesteps the denoising loop will visit, in descending order
    public let timeSteps: [Int]
    /// 1 - betas
    public let alphas: [Float]
    /// Cumulative product of alphas
    public let alphasCumProd: [Float]
    /// Noise level per inference step (descending), with a trailing 0 sentinel
    public let sigmas: [Float]
    /// Standard deviation of the initial latent noise (largest sigma)
    public let initNoiseSigma: Float
    /// Source of Gaussian noise for the ancestral sampling step
    private(set) var randomSource: RandomSource

    /// Create a scheduler that uses an Euler Ancestral discrete algorithm
    ///
    /// - Parameters:
    ///   - randomSource: Source of the Gaussian noise added at each ancestral step
    ///   - stepCount: Number of inference steps to schedule
    ///   - trainStepCount: Number of training diffusion steps
    ///   - betaSchedule: Method to schedule betas from betaStart to betaEnd
    ///   - betaStart: The starting value of beta for inference
    ///   - betaEnd: The end value for beta for inference
    /// - Returns: A scheduler ready for its first step
    public init(
        randomSource: RandomSource,
        stepCount: Int = 50,
        trainStepCount: Int = 1000,
        betaSchedule: BetaSchedule = .linear,
        betaStart: Float = 0.0001,
        betaEnd: Float = 0.02
    ) {
        self.randomSource = randomSource
        self.trainStepCount = trainStepCount
        self.inferenceStepCount = stepCount

        switch betaSchedule {
        case .linear:
            self.betas = linspace(betaStart, betaEnd, trainStepCount)
        case .scaledLinear:
            self.betas = linspace(pow(betaStart, 0.5), pow(betaEnd, 0.5), trainStepCount).map({ $0 * $0 })
        }

        self.alphas = betas.map({ 1.0 - $0 })

        var alphasCumProd = self.alphas
        for i in 1..<alphasCumProd.count {
            alphasCumProd[i] *= alphasCumProd[i - 1]
        }
        self.alphasCumProd = alphasCumProd

        // Per-train-step noise levels: sigma_t = sqrt((1 - cumprod_t) / cumprod_t),
        // increasing with t.
        let trainSigmas = vForce.sqrt(
            vDSP.divide(
                vDSP.subtract([Float](repeating: 1, count: alphasCumProd.count), alphasCumProd),
                alphasCumProd))

        // Spread `stepCount` (fractional) timesteps across the training range,
        // descending, matching diffusers' set_timesteps. Previously this used
        // `trainStepCount` points, so the schedule length ignored `stepCount`.
        let fractionalTimeSteps: [Float] = linspace(0, Float(trainStepCount - 1), stepCount).reversed()
        self.timeSteps = fractionalTimeSteps.map { Int(round($0)) }

        // Linearly interpolate the train sigmas at each fractional inference
        // timestep (diffusers uses np.interp), then append the terminal 0.
        var sigmas: [Float] = fractionalTimeSteps.map { t in
            let lower = min(Int(t), trainStepCount - 1)
            let upper = min(lower + 1, trainStepCount - 1)
            let weight = t - Float(lower)
            return trainSigmas[lower] * (1 - weight) + trainSigmas[upper] * weight
        }
        sigmas.append(0)
        self.sigmas = sigmas

        // Largest sigma; sigmas are descending so it is also sigmas[0]
        self.initNoiseSigma = sigmas.max()!
    }

    /// Compute the previous (less noisy) sample from the model output
    ///
    /// - Parameters:
    ///   - output: The predicted residual noise (epsilon) from the model
    ///   - t: The current timestep in the diffusion chain (must be one of `timeSteps`)
    ///   - s: The current input sample to the model
    /// - Returns: Predicted de-noised sample at the previous time step
    public func step(output: MLShapedArray<Float32>, timeStep t: Int, sample s: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
        guard let stepIndex = timeSteps.firstIndex(of: t) else {
            fatalError("Timestep \(t) is not part of the inference schedule")
        }
        let sigma = sigmas[stepIndex]

        // Compute predicted original sample (x0) from sigma-scaled predicted noise (for epsilon):
        //   predOriginalSample = sample - sigma * output
        let predOriginalSample = weightedSum([1.0, Double(-1.0 * sigma)], [s, output])

        // Ancestral step noise split: how much noise to remove deterministically
        // (sigmaDown) vs. re-inject stochastically (sigmaUp)
        let sigmaFrom = sigmas[stepIndex]
        let sigmaTo = sigmas[stepIndex + 1]
        let sigmaUp = sqrt(pow(sigmaTo, 2) * (pow(sigmaFrom, 2) - pow(sigmaTo, 2)) / pow(sigmaFrom, 2))
        let sigmaDown = sqrt(pow(sigmaTo, 2) - pow(sigmaUp, 2))

        // Convert to an ODE derivative:
        //   derivative = (sample - predOriginalSample) / sigma
        //   prevSample = sample + derivative * dt
        let derivative = weightedSum([Double(1 / sigma), Double(-1 / sigma)], [s, predOriginalSample])
        let dt = sigmaDown - sigma
        let prevSample = weightedSum([1.0, Double(dt)], [s, derivative])

        // Introduce ancestral noise: draw UNIT-variance Gaussian noise and scale
        // it by sigmaUp in the weighted sum below. (Drawing with
        // stdev = initNoiseSigma AND scaling by sigmaUp double-scales the noise
        // and does not match the reference diffusers implementation.)
        let noise = MLShapedArray<Float32>(converting: randomSource.normalShapedArray(output.shape, mean: 0.0, stdev: 1.0))

        // output = prevSample + noise * sigmaUp
        return weightedSum([1, Double(sigmaUp)], [prevSample, noise])
    }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,8 @@ public enum StableDiffusionScheduler { | |
case pndmScheduler | ||
/// Scheduler that uses a second order DPM-Solver++ algorithm | ||
case dpmSolverMultistepScheduler | ||
/// Scheduler that uses an Euler Ancestral discrete algorithm | ||
case eulerAncestralDiscreteScheduler | ||
} | ||
|
||
/// RNG compatible with StableDiffusionPipeline | ||
|
@@ -160,6 +162,7 @@ public struct StableDiffusionPipeline: ResourceManaging { | |
switch config.schedulerType { | ||
case .pndmScheduler: return PNDMScheduler(stepCount: config.stepCount) | ||
case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: config.stepCount) | ||
case .eulerAncestralDiscreteScheduler: return EulerAncestralDiscreteScheduler(randomSource: randomSource(from: config.rngType, seed: config.seed)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new scheduler does not take There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Related: Could you please register this scheduler in the CLI enum as well? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, this is not expected. Seems like I missed this. Thanks for pointing out. I'll register it in CLI enum. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great, looking forward to it :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've made the changes, but still couldn't seem to generate a proper image. Currently still looking into what's the issue. |
||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't the length of
self.timeSteps
adhere toself.inferenceStepCount
instead oftrainStepCount
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, you are right. I'll make this change.