Commit
* Image2Image Encoder
* Scheduler and pipeline
* Fix scheduler
* CLI
* Remove CLI comment
* Disable DPM multistep solver with image2image
* Clamp initial timestep
* Store timesteps in reverse order for consistency
* Report actual number of steps
* uint32
* PR comments
* Remove old initializer
* PR comments
* Change name and add error handling; also fix hard-coded 512
* Add fix for JPEGs

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Showing 9 changed files with 517 additions and 97 deletions.
swift/StableDiffusion/pipeline/AlphasCumprodCalculation.swift — 29 additions, 0 deletions
@@ -0,0 +1,29 @@
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.

import Foundation

public struct AlphasCumprodCalculation {
    public var sqrtAlphasCumprod: Float
    public var sqrtOneMinusAlphasCumprod: Float

    public init(
        sqrtAlphasCumprod: Float,
        sqrtOneMinusAlphasCumprod: Float
    ) {
        self.sqrtAlphasCumprod = sqrtAlphasCumprod
        self.sqrtOneMinusAlphasCumprod = sqrtOneMinusAlphasCumprod
    }

    public init(
        alphasCumprod: [Float],
        timesteps: Int = 1_000,
        steps: Int,
        strength: Float
    ) {
        let tEnc = Int(strength * Float(steps))
        let initTimestep = min(max(0, timesteps - timesteps / steps * (steps - tEnc) + 1), timesteps - 1)
        self.sqrtAlphasCumprod = alphasCumprod[initTimestep].squareRoot()
        self.sqrtOneMinusAlphasCumprod = (1 - alphasCumprod[initTimestep]).squareRoot()
    }
}
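For orientation, a minimal usage sketch of the strength-based initializer. The alphasCumprod values below are illustrative placeholders only; in the pipeline this array comes from the scheduler.

// Illustrative placeholder schedule; the real values come from the scheduler's betas.
let alphasCumprod: [Float] = (0..<1_000).map { 1.0 - Float($0) / 1_000 }

// 50 inference steps at strength 0.75: tEnc = 37, so denoising starts
// part-way through the schedule rather than from pure noise.
let calc = AlphasCumprodCalculation(
    alphasCumprod: alphasCumprod,
    steps: 50,
    strength: 0.75
)
// calc.sqrtAlphasCumprod / calc.sqrtOneMinusAlphasCumprod become the encoder
// inputs sqrt_alphas_cumprod / sqrt_one_minus_alphas_cumprod.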
@@ -0,0 +1,120 @@
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.

import Foundation
import Accelerate
import CoreML

@available(iOS 16.0, macOS 13.0, *)
extension CGImage {

    typealias PixelBufferPFx1 = vImage.PixelBuffer<vImage.PlanarF>
    typealias PixelBufferP8x3 = vImage.PixelBuffer<vImage.Planar8x3>
    typealias PixelBufferIFx3 = vImage.PixelBuffer<vImage.InterleavedFx3>
    typealias PixelBufferI8x3 = vImage.PixelBuffer<vImage.Interleaved8x3>

    public enum ShapedArrayError: String, Swift.Error {
        case wrongNumberOfChannels
        case incorrectFormatsConvertingToShapedArray
        case vImageConverterNotInitialized
    }

    public static func fromShapedArray(_ array: MLShapedArray<Float32>) throws -> CGImage {

        // array is [N,C,H,W], where C == 3
        let channelCount = array.shape[1]
        guard channelCount == 3 else {
            throw ShapedArrayError.wrongNumberOfChannels
        }

        let height = array.shape[2]
        let width = array.shape[3]

        // Normalize each channel into a float between 0 and 1.0
        let floatChannels = (0..<channelCount).map { i in

            // Normalized channel output
            let cOut = PixelBufferPFx1(width: width, height: height)

            // Reference this channel in the array and normalize
            array[0][i].withUnsafeShapedBufferPointer { ptr, _, strides in
                let cIn = PixelBufferPFx1(data: .init(mutating: ptr.baseAddress!),
                                          width: width, height: height,
                                          byteCountPerRow: strides[0] * 4)
                // Map [-1.0, 1.0] -> [0.0, 1.0]
                cIn.multiply(by: 0.5, preBias: 1.0, postBias: 0.0, destination: cOut)
            }
            return cOut
        }

        // Convert to interleaved and then to UInt8
        let floatImage = PixelBufferIFx3(planarBuffers: floatChannels)
        let uint8Image = PixelBufferI8x3(width: width, height: height)
        floatImage.convert(to: uint8Image) // maps [0.0, 1.0] -> [0, 255] and clips

        // Convert uint8x3 to an RGB CGImage (no alpha)
        let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue)
        let cgImage = uint8Image.makeCGImage(cgImageFormat:
            .init(bitsPerComponent: 8,
                  bitsPerPixel: 3 * 8,
                  colorSpace: CGColorSpaceCreateDeviceRGB(),
                  bitmapInfo: bitmapInfo)!)!

        return cgImage
    }

    public var plannerRGBShapedArray: MLShapedArray<Float32> {
        get throws {
            guard
                var sourceFormat = vImage_CGImageFormat(cgImage: self),
                var mediumFormat = vImage_CGImageFormat(
                    bitsPerComponent: 8 * MemoryLayout<UInt8>.size,
                    bitsPerPixel: 8 * MemoryLayout<UInt8>.size * 4,
                    colorSpace: CGColorSpaceCreateDeviceRGB(),
                    bitmapInfo: CGBitmapInfo(rawValue: CGImageAlphaInfo.first.rawValue)),
                let width = vImagePixelCount(exactly: self.width),
                let height = vImagePixelCount(exactly: self.height)
            else {
                throw ShapedArrayError.incorrectFormatsConvertingToShapedArray
            }

            var sourceImageBuffer = try vImage_Buffer(cgImage: self)

            var mediumDestination = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: mediumFormat.bitsPerPixel)

            let converter = vImageConverter_CreateWithCGImageFormat(
                &sourceFormat,
                &mediumFormat,
                nil,
                vImage_Flags(kvImagePrintDiagnosticsToConsole),
                nil)

            guard let converter = converter?.takeRetainedValue() else {
                throw ShapedArrayError.vImageConverterNotInitialized
            }

            vImageConvert_AnyToAny(converter, &sourceImageBuffer, &mediumDestination, nil, vImage_Flags(kvImagePrintDiagnosticsToConsole))

            var destinationA = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size))
            var destinationR = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size))
            var destinationG = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size))
            var destinationB = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size))

            var minFloat: [Float] = [-1.0, -1.0, -1.0, -1.0]
            var maxFloat: [Float] = [1.0, 1.0, 1.0, 1.0]

            // Split ARGB8888 into planar floats, mapping [0, 255] -> [-1.0, 1.0]
            vImageConvert_ARGB8888toPlanarF(&mediumDestination, &destinationA, &destinationR, &destinationG, &destinationB, &maxFloat, &minFloat, .zero)

            let redData = Data(bytes: destinationR.data, count: Int(width) * Int(height) * MemoryLayout<Float>.size)
            let greenData = Data(bytes: destinationG.data, count: Int(width) * Int(height) * MemoryLayout<Float>.size)
            let blueData = Data(bytes: destinationB.data, count: Int(width) * Int(height) * MemoryLayout<Float>.size)

            let imageData = redData + greenData + blueData

            // NCHW layout, matching fromShapedArray above: [batch, channels, height, width]
            let shapedArray = MLShapedArray<Float32>(data: imageData, shape: [1, 3, self.height, self.width])

            return shapedArray
        }
    }
}
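A hedged round-trip sketch of the two conversions above. The image loading helper is an assumption for illustration (any source of a CGImage works, and the image is assumed to already match the model's expected size, e.g. 512x512); top-level error handling is elided.

import CoreGraphics
import CoreML
import Foundation
import ImageIO

// Hypothetical helper: load a CGImage from disk with ImageIO.
func loadCGImage(at url: URL) -> CGImage? {
    guard let source = CGImageSourceCreateWithURL(url as CFURL, nil) else { return nil }
    return CGImageSourceCreateImageAtIndex(source, 0, nil)
}

// CGImage -> MLShapedArray<Float32> in [-1, 1], laid out as [1, 3, H, W]
let image = loadCGImage(at: URL(fileURLWithPath: "init_image.png"))!
let sample = try image.plannerRGBShapedArray

// And the reverse direction, as used when turning latent-decoded arrays back into pixels.
let roundTripped = try CGImage.fromShapedArray(sample)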
@@ -0,0 +1,131 @@
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.

import Foundation
import CoreML

@available(iOS 16.0, macOS 13.1, *)
/// Encoder, currently supports image2image
public struct Encoder: ResourceManaging {

    public enum FeatureName: String {
        case sample = "sample"
        case diagonalNoise = "diagonal_noise"
        case noise = "noise"
        case sqrtAlphasCumprod = "sqrt_alphas_cumprod"
        case sqrtOneMinusAlphasCumprod = "sqrt_one_minus_alphas_cumprod"
    }

    public enum Error: String, Swift.Error {
        case latentOutputNotValid
        case batchLatentOutputEmpty
        case sampleInputShapeNotCorrect
        case noiseInputShapeNotCorrect
    }

    /// VAE encoder model plus the post-processing math and noise addition from the scheduler
    var model: ManagedMLModel

    /// Create encoder from Core ML model
    ///
    /// - Parameters:
    ///   - url: Location of compiled VAE encoder Core ML model
    ///   - configuration: configuration to be used when the model is loaded
    /// - Returns: An encoder that will lazily load its required resources when needed or requested
    public init(modelAt url: URL, configuration: MLModelConfiguration) {
        self.model = ManagedMLModel(modelAt: url, configuration: configuration)
    }

    /// Ensure the model has been loaded into memory
    public func loadResources() throws {
        try model.loadResources()
    }

    /// Unload the underlying model to free up memory
    public func unloadResources() {
        model.unloadResources()
    }

    /// Prediction queue
    let queue = DispatchQueue(label: "encoder.predict")

    /// Encode an image into its latent representation for image2image
    /// - Parameters:
    ///   - image: image used for image2image
    ///   - diagonalNoise: random noise for `DiagonalGaussianDistribution` operation
    ///   - noise: random noise for initial latent space based on strength argument
    ///   - alphasCumprodStep: scheduler values traditionally calculated in the pipeline in the PyTorch Diffusers library
    /// - Returns: The encoded latent space as MLShapedArray
    public func encode(
        image: CGImage,
        diagonalNoise: MLShapedArray<Float32>,
        noise: MLShapedArray<Float32>,
        alphasCumprodStep: AlphasCumprodCalculation
    ) throws -> MLShapedArray<Float32> {
        let sample = try image.plannerRGBShapedArray
        let sqrtAlphasCumprod = MLShapedArray(scalars: [alphasCumprodStep.sqrtAlphasCumprod], shape: [1, 1])
        let sqrtOneMinusAlphasCumprod = MLShapedArray(scalars: [alphasCumprodStep.sqrtOneMinusAlphasCumprod], shape: [1, 1])

        let dict: [String: Any] = [
            FeatureName.sample.rawValue: MLMultiArray(sample),
            FeatureName.diagonalNoise.rawValue: MLMultiArray(diagonalNoise),
            FeatureName.noise.rawValue: MLMultiArray(noise),
            FeatureName.sqrtAlphasCumprod.rawValue: MLMultiArray(sqrtAlphasCumprod),
            FeatureName.sqrtOneMinusAlphasCumprod.rawValue: MLMultiArray(sqrtOneMinusAlphasCumprod),
        ]
        let featureProvider = try MLDictionaryFeatureProvider(dictionary: dict)

        let batch = MLArrayBatchProvider(array: [featureProvider])

        // Batch predict with model
        let results = try queue.sync {
            try model.perform { model in
                if let feature = model.modelDescription.inputDescriptionsByName[FeatureName.sample.rawValue],
                   let shape = feature.multiArrayConstraint?.shape as? [Int]
                {
                    guard sample.shape == shape else {
                        // TODO: Consider auto resizing and cropping similar to what Vision or Core ML auto-generated Swift code can accomplish with `MLFeatureValue`
                        throw Error.sampleInputShapeNotCorrect
                    }
                }

                if let feature = model.modelDescription.inputDescriptionsByName[FeatureName.noise.rawValue],
                   let shape = feature.multiArrayConstraint?.shape as? [Int]
                {
                    guard noise.shape == shape else {
                        throw Error.noiseInputShapeNotCorrect
                    }
                }

                if let feature = model.modelDescription.inputDescriptionsByName[FeatureName.diagonalNoise.rawValue],
                   let shape = feature.multiArrayConstraint?.shape as? [Int]
                {
                    guard diagonalNoise.shape == shape else {
                        throw Error.noiseInputShapeNotCorrect
                    }
                }

                return try model.predictions(fromBatch: batch)
            }
        }

        let batchLatents: [MLShapedArray<Float32>] = try (0..<results.count).compactMap { i in
            let result = results.features(at: i)
            guard
                let outputName = result.featureNames.first,
                let output = result.featureValue(for: outputName)?.multiArrayValue
            else {
                throw Error.latentOutputNotValid
            }
            return MLShapedArray(output)
        }

        guard let latents = batchLatents.first else {
            throw Error.batchLatentOutputEmpty
        }

        return latents
    }

}
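A hedged sketch of how this encoder might be called. The model path, the [1, 4, 64, 64] latent shape (typical for a 512x512 Stable Diffusion model), the step/strength values, and the noise helper are assumptions for illustration; the real wiring, including Gaussian noise from the pipeline's random source, lives in the image2image pipeline and CLI added by this commit.

import CoreGraphics
import CoreML
import Foundation

// Sketch: encode a starting image into latents, assuming the caller already has
// a correctly sized CGImage and the scheduler's alphasCumprod array.
func encodeStartingImage(
    _ startingImage: CGImage,
    schedulerAlphasCumprod: [Float]
) throws -> MLShapedArray<Float32> {
    let vaeEncoderURL = URL(fileURLWithPath: "VAEEncoder.mlmodelc") // assumed compiled model path
    let encoder = Encoder(modelAt: vaeEncoderURL, configuration: MLModelConfiguration())
    try encoder.loadResources()

    // Illustrative stand-in for the pipeline's Gaussian random source.
    func placeholderNoise(shape: [Int]) -> MLShapedArray<Float32> {
        let count = shape.reduce(1, *)
        return MLShapedArray<Float32>(scalars: (0..<count).map { _ in Float.random(in: -1...1) }, shape: shape)
    }

    // Shapes assume a 512x512 model with a [1, 4, 64, 64] latent space.
    return try encoder.encode(
        image: startingImage,
        diagonalNoise: placeholderNoise(shape: [1, 4, 64, 64]),
        noise: placeholderNoise(shape: [1, 4, 64, 64]),
        alphasCumprodStep: AlphasCumprodCalculation(
            alphasCumprod: schedulerAlphasCumprod,
            steps: 50,
            strength: 0.75
        )
    )
}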