Skip to content

Add random source that matches PyTorch #124

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import CoreML
/// [NumPy's older randomkit.c](https://github.com/numpy/numpy/blob/v1.0/numpy/random/mtrand/randomkit.c)
///
@available(iOS 16.2, macOS 13.1, *)
struct NumPyRandomSource: RandomNumberGenerator {
struct NumPyRandomSource: RandomNumberGenerator, RandomSource {

struct State {
var key = [UInt32](repeating: 0, count: 624)
Expand Down
6 changes: 6 additions & 0 deletions swift/StableDiffusion/pipeline/RandomSource.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import CoreML

/// A seedable source of random values that can fill shaped arrays.
///
/// Conforming types hold mutable generator state, which is why the
/// requirement is `mutating`.
@available(iOS 16.2, macOS 13.1, *)
public protocol RandomSource {
/// Produces an `MLShapedArray<Double>` of the requested shape.
///
/// - Parameters:
///   - shape: Dimensions of the array to produce.
///   - mean: Mean of the normal distribution the scalars are drawn from.
///   - stdev: Standard deviation of that distribution.
/// - Returns: A shaped array whose scalars are the generated samples.
mutating func normalShapedArray(_ shape: [Int], mean: Double, stdev: Double) -> MLShapedArray<Double>
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ extension StableDiffusionPipeline {
public var disableSafety: Bool = false
/// The type of Scheduler to use.
public var schedulerType: StableDiffusionScheduler = .pndmScheduler
/// The type of RNG to use
public var rngType: StableDiffusionRNG = .numpyRNG

/// Given the configuration, what mode will be used for generation
public var mode: Mode {
Expand Down
30 changes: 23 additions & 7 deletions swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ public enum StableDiffusionScheduler {
case dpmSolverMultistepScheduler
}

/// RNG compatible with StableDiffusionPipeline
///
/// Each case names one random number generator implementation that the
/// pipeline can be configured to use.
public enum StableDiffusionRNG {
/// RNG that matches numpy implementation
case numpyRNG
/// RNG that matches PyTorch CPU implementation.
case torchRNG
}

/// A pipeline used to generate image samples from text input using stable diffusion
///
/// This implementation matches:
Expand Down Expand Up @@ -157,7 +165,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
throw Error.startingImageProvidedWithoutEncoder
}

let noiseTuples = generateImage2ImageLatentSamples(config.imageCount, stdev: 1, seed: config.seed)
let noiseTuples = generateImage2ImageLatentSamples(config.imageCount, rng: config.rngType, stdev: 1, seed: config.seed)
latents = try noiseTuples.map({
try encoder.encode(
image: startingImage,
Expand All @@ -168,7 +176,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
} else {
timestepStrength = nil
// Generate random latent samples from specified seed
latents = generateLatentSamples(config.imageCount, stdev: stdev, seed: config.seed)
latents = generateLatentSamples(config.imageCount, rng: config.rngType, stdev: stdev, seed: config.seed)
}

// De-noising loop
Expand Down Expand Up @@ -224,11 +232,19 @@ public struct StableDiffusionPipeline: ResourceManaging {
return try decodeToImages(latents, disableSafety: config.disableSafety)
}

func generateLatentSamples(_ count: Int, stdev: Float, seed: UInt32) -> [MLShapedArray<Float32>] {
/// Builds the seeded random source that corresponds to the requested RNG kind.
///
/// - Parameters:
///   - rng: Which generator implementation to instantiate.
///   - seed: Seed handed to the chosen generator's initializer.
/// - Returns: A freshly seeded generator, type-erased to `RandomSource`.
private func randomSource(from rng: StableDiffusionRNG, seed: UInt32) -> RandomSource {
    let source: RandomSource
    switch rng {
    case .torchRNG:
        source = TorchRandomSource(seed: seed)
    case .numpyRNG:
        source = NumPyRandomSource(seed: seed)
    }
    return source
}

func generateLatentSamples(_ count: Int, rng: StableDiffusionRNG, stdev: Float, seed: UInt32) -> [MLShapedArray<Float32>] {
var sampleShape = unet.latentSampleShape
sampleShape[0] = 1

var random = NumPyRandomSource(seed: seed)
var random = randomSource(from: rng, seed: seed)
let samples = (0..<count).map { _ in
MLShapedArray<Float32>(
converting: random.normalShapedArray(sampleShape, mean: 0.0, stdev: Double(stdev)))
Expand All @@ -245,11 +261,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// - diagonalAndLatentNoiseIsSame: Diffusions library does not seem to use the same noise for the `DiagonalGaussianDistribution` operation,
/// but I have seen implementations of pipelines where it is the same.
/// - Returns: An array of tuples of noise values with length of batch size.
func generateImage2ImageLatentSamples(_ count: Int, stdev: Float, seed: UInt32, diagonalAndLatentNoiseIsSame: Bool = false) -> [(diagonal: MLShapedArray<Float32>, latentNoise: MLShapedArray<Float32>)] {
func generateImage2ImageLatentSamples(_ count: Int, rng: StableDiffusionRNG, stdev: Float, seed: UInt32, diagonalAndLatentNoiseIsSame: Bool = false) -> [(diagonal: MLShapedArray<Float32>, latentNoise: MLShapedArray<Float32>)] {
var sampleShape = unet.latentSampleShape
sampleShape[0] = 1

var random = NumPyRandomSource(seed: UInt32(truncatingIfNeeded: seed))
var random = randomSource(from: rng, seed: seed)
let samples = (0..<count).map { _ in
if diagonalAndLatentNoiseIsSame {
let noise = MLShapedArray<Float32>(
Expand Down
152 changes: 152 additions & 0 deletions swift/StableDiffusion/pipeline/TorchRandomSource.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.

import Foundation
import CoreML

/// A random source consistent with PyTorch
///
/// This implementation matches:
/// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/DistributionsHelper.h
/// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/DistributionTemplates.h
/// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/DistributionKernels.cpp
/// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/TransformationHelper.h
///
@available(iOS 16.2, macOS 13.1, *)
struct TorchRandomSource: RandomNumberGenerator, RandomSource {

    /// Mutable generator state.
    struct State {
        /// The 624-word Mersenne Twister state vector.
        var key = [UInt32](repeating: 0, count: 624)
        /// Index of the next word of `key` to temper and return; equal to
        /// `key.count` when the vector must be regenerated first.
        var pos: Int = 0
        /// Second Box-Muller sample cached by `nextGauss()`, if one is pending.
        var nextGauss: Double? = nil
    }

    var state: State

    /// Initialize with a random seed
    ///
    /// - Parameters:
    ///   - seed: Seed for underlying Mersenne Twister 19937 generator
    /// - Returns: random source
    init(seed: UInt32) {
        state = .init()
        var s = seed
        for i in 0..<state.key.count {
            state.key[i] = s
            // The arithmetic is widened to 64 bits and masked back down so the
            // multiplication wraps modulo 2^32 instead of trapping on overflow.
            s = UInt32((UInt64(1_812_433_253) * UInt64(s ^ (s >> 30)) + UInt64(i) + 1) & 0xffff_ffff)
        }
        // Force a regeneration of the state vector on the first draw.
        state.pos = state.key.count
        state.nextGauss = nil
    }

    /// Generate next UInt32 using fast 32bit Mersenne Twister
    mutating func nextUInt32() -> UInt32 {
        let n = 624
        let m = 397
        let matrixA: UInt64 = 0x9908_b0df
        let upperMask: UInt32 = 0x8000_0000
        let lowerMask: UInt32 = 0x7fff_ffff

        var y: UInt32
        if state.pos == state.key.count {
            // All words consumed: regenerate the whole state vector.
            // `UInt32((UInt64(~(y & 1)) + 1) & matrixA)` evaluates to `matrixA`
            // when the low bit of `y` is set and to 0 otherwise.
            for i in 0..<(n - m) {
                y = (state.key[i] & upperMask) | (state.key[i + 1] & lowerMask)
                state.key[i] = state.key[i + m] ^ (y >> 1) ^ UInt32((UInt64(~(y & 1)) + 1) & matrixA)
            }
            for i in (n - m)..<(n - 1) {
                y = (state.key[i] & upperMask) | (state.key[i + 1] & lowerMask)
                state.key[i] = state.key[i + (m - n)] ^ (y >> 1) ^ UInt32((UInt64(~(y & 1)) + 1) & matrixA)
            }
            y = (state.key[n - 1] & upperMask) | (state.key[0] & lowerMask)
            state.key[n - 1] = state.key[m - 1] ^ (y >> 1) ^ UInt32((UInt64(~(y & 1)) + 1) & matrixA)
            state.pos = 0
        }
        y = state.key[state.pos]
        state.pos += 1

        // Tempering.
        y ^= (y >> 11)
        y ^= (y << 7) & 0x9d2c_5680
        y ^= (y << 15) & 0xefc6_0000
        y ^= (y >> 18)

        return y
    }

    /// Generate next UInt64 from two consecutive 32-bit draws; the first draw
    /// becomes the high word and the second the low word.
    mutating func next() -> UInt64 {
        let high = nextUInt32()
        let low = nextUInt32()
        return (UInt64(high) << 32) | UInt64(low)
    }

    /// Generate next random double value
    ///
    /// Keeps the low 53 bits of a 64-bit draw and scales by 2^-53, giving a
    /// value in `[0, 1)`.
    mutating func nextDouble() -> Double {
        let a = next()
        return Double(a & 9_007_199_254_740_991) * (1.0 / 9007199254740992.0)
    }

    /// Generate next random float value
    ///
    /// Keeps the low 24 bits of a single 32-bit draw and scales by 2^-24,
    /// giving a value in `[0, 1)`.
    mutating func nextFloat() -> Float {
        let a = nextUInt32()
        return Float(a & 16_777_215) * (1.0 / 16777216.0)
    }

    /// Generate next random value from a standard normal
    ///
    /// Each Box-Muller evaluation yields two samples; the second one is cached
    /// in `state.nextGauss` and returned by the following call.
    mutating func nextGauss() -> Double {
        if let nextGauss = state.nextGauss {
            state.nextGauss = nil
            return nextGauss
        }
        // Box-Muller transform
        let u1: Double = nextDouble()
        // `1 - u` maps [0, 1) onto (0, 1] so the logarithm never sees zero.
        let u2: Double = 1 - nextDouble()
        let radius = sqrt(-2.0 * log(u2))
        let theta = 2.0 * .pi * u1
        state.nextGauss = radius * sin(theta)
        return radius * cos(theta)
    }

    /// Applies the Box-Muller transform in place to 16 consecutive uniform
    /// values starting at `start`, pairing element `j` with element `j + 8`.
    private static func boxMullerTransform16(
        _ data: inout [Double], at start: Int, mean: Double, stdev: Double
    ) {
        for j in 0..<8 {
            let u1 = 1 - data[start + j]
            let u2 = data[start + j + 8]
            let radius = sqrt(-2.0 * log(u1))
            let theta = 2.0 * Double.pi * u2
            data[start + j] = radius * cos(theta) * stdev + mean
            data[start + j + 8] = radius * sin(theta) * stdev + mean
        }
    }

    /// Generates an array of random values from a normal distribution with given mean and standard deviation.
    /// This simulates torch.randn([1, 4, 64, 64], dtype=torch.float), note that for dtype=torch.double, it
    /// will be slightly different.
    ///
    /// - Parameters:
    ///   - count: Number of samples to generate.
    ///   - mean: Mean of the distribution.
    ///   - stdev: Standard deviation of the distribution.
    /// - Returns: `count` samples.
    mutating func normalArray(count: Int, mean: Double = 0.0, stdev: Double = 1.0) -> [Double] {
        // If it is smaller than 16 elements, Torch generates from Box-Muller transform directly.
        // Note that even if this is used to generate Float, it will use Double underneath.
        guard count >= 16 else {
            return (0..<count).map { _ in nextGauss() * stdev + mean }
        }
        // Otherwise, Torch first fills an array with uniform samples, then applies the
        // Box-Muller transformation over this array in chunks of 16.
        var data = (0..<count).map { _ in Double(nextFloat()) }
        for start in stride(from: 0, to: count - 15, by: 16) {
            Self.boxMullerTransform16(&data, at: start, mean: mean, stdev: stdev)
        }
        if count % 16 != 0 {
            // The trailing partial chunk was left untransformed above, so the last
            // 16 values are redrawn and transformed as one full chunk. The redraw
            // uses the same Float uniform as the initial fill, keeping the tail
            // consistent with the dtype=torch.float stream this routine simulates.
            let start = count - 16
            for i in start..<count {
                data[i] = Double(nextFloat())
            }
            Self.boxMullerTransform16(&data, at: start, mean: mean, stdev: stdev)
        }
        return data
    }

    /// Generate a shaped array with scalars from a normal distribution with given mean and standard deviation.
    ///
    /// - Parameters:
    ///   - shape: Dimensions of the resulting array; the product of its entries is the sample count.
    ///   - mean: Mean of the distribution.
    ///   - stdev: Standard deviation of the distribution.
    mutating func normalShapedArray(_ shape: [Int], mean: Double = 0.0, stdev: Double = 1.0) -> MLShapedArray<Double> {
        let count = shape.reduce(1, *)
        return .init(scalars: normalArray(count: count, mean: mean, stdev: stdev), shape: shape)
    }
}
15 changes: 15 additions & 0 deletions swift/StableDiffusionCLI/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ struct StableDiffusionSample: ParsableCommand {
@Option(help: "Scheduler to use, one of {pndm, dpmpp}")
var scheduler: SchedulerOption = .pndm

@Option(help: "Random number generator to use, one of {numpy, torch}")
var rng: RNGOption = .numpy

@Flag(help: "Disable safety checking")
var disableSafety: Bool = false

Expand Down Expand Up @@ -126,6 +129,7 @@ struct StableDiffusionSample: ParsableCommand {
pipelineConfig.seed = seed
pipelineConfig.guidanceScale = guidanceScale
pipelineConfig.schedulerType = scheduler.stableDiffusionScheduler
pipelineConfig.rngType = rng.stableDiffusionRNG

let images = try pipeline.generateImages(
configuration: pipelineConfig,
Expand Down Expand Up @@ -250,6 +254,17 @@ enum SchedulerOption: String, ExpressibleByArgument {
}
}

/// Command-line spelling of the available random number generators.
@available(iOS 16.2, macOS 13.1, *)
enum RNGOption: String, ExpressibleByArgument {
    case numpy
    case torch

    /// The pipeline-level RNG selection that corresponds to this option.
    var stableDiffusionRNG: StableDiffusionRNG {
        let mapped: StableDiffusionRNG
        switch self {
        case .numpy:
            mapped = .numpyRNG
        case .torch:
            mapped = .torchRNG
        }
        return mapped
    }
}

if #available(iOS 16.2, macOS 13.1, *) {
StableDiffusionSample.main()
} else {
Expand Down