Skip to content

Commit

Permalink
Add time log and README.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Sep 14, 2022
1 parent 30d9eb6 commit b39f29d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 7 deletions.
25 changes: 25 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Swift Diffusion

This is a single-file re-implementation of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) model. It includes the models for CLIP text encoder, UNet diffusion model and the decoder model. It also includes PLMS inference implementation. The implementation tries to match the Stable Diffusion outputs layer-by-layer, thus, given the same start point `x_T`, this implementation and Stable Diffusion will output the same image.

## Rationale

This re-implementation serves and an education for me to understand diffusion models. It is also necessary for my follow-up work to enable Stable Diffusion on mobile devices such as iPad / iPhone. Without a Swift re-implementation, doing mobile-focused optimization with Python would be difficult and impossible to ship in App Store. It is possible to do this differently, such as exporting to ONNX runtime and use that as the driver on mobile devices. That does limit what kind of optimizations you can apply though. As you can tell, running models that totals about 8GiB in-memory and 4GiB at-rest with full floating-point precision is not trivial on mobile devices. It might requires some non-conventional optimizations that may not be available through existing frameworks. Using something I am familiar with (a framework I built) would be a good starting point.

## Where We Are?

CLIP text model, UNet diffusion model and the decoder has been ported. The `examples:txt2img` target is useful with some path changesinside `examples/txt2img/main.swift`. Need to port the encoder over to enable `img2img`. Other targets, such as `examples:unet`, `examples:clip`, `examples:autoencoder` are the example programs to convert PyTorch weights to the one s4nnc uses.

## What's Next?

The next on my list is to implement the tokenizer. Thanks to PythonKit, right now, I am using the tokenizer from Hugging Face. After tokenizer implemented, the whole thing should be able to run without Python dependencies.

After that, I should change the convolution layout from NCHW to NHWC. That will enable bunch of optimizations in attention layer, mostly to avoid some of the transpose traffic. I can enable CPU mode either by converting convolution layout to NHWC, or implement NCHW convolution in s4nnc. The latter is long overdue, but doing former would be helpful for performance on CPU.

Right now, at run time, UNet model uses ~1.5GiB memory in additional to its 3.3GiB weights. A big chunk of that 1.5GiB is due to the dot product in attention layer. I already optimized away about 1GiB because previously, softmax doesn't run in-place properly (due to complex reasons relating to aliases and reshapes). I believe this is still a case for PyTorch code because there is no in-place softmax method. That dot product can be split further into smaller batches to save peak memory usage (along the token dimension of k). If these are done correctly, we should be able to reduce UNet memory usage to somewhere around 3.8GiB full floating-point. Another idea I have to further reduce the memory usage is to compress shortcut activations in UNet (these shortcut activations will be saved along downsample path and used in upsample path, thus, occupying for long time). But I am less sure how much memory that can save.

Converting the model to FP16 would save memory footprint instantly, but this will be close-to-the-last thing to do. Just by using FP16, UNet should use around 1.9GiB memory, which is very manageable on mobile devices now. Given that we can unload UNet model and load decoder from disk when it is done, this combined can, hopefully, finally, run stable diffusion on mobile. We can further quantize weights to int8 with the `LLM.int8()` transformers trick: https://arxiv.org/pdf/2208.07339.pdf.

## Is It Comparable?

Right now, I didn't run any specific optimizations. Further, the model loading as of today for s4nnc requires executing the model once, and we have some optimization runs (find the most efficient kernels etc.) that are not saved. That has been said, we can compare the execution time of txt2img from Swift v.s. the one from CompVis (there are more optimized forks available, but going through them to find the best would take time) of the diffusion process + decoding process. The Swift txt2img on GPU took about 20s while the CompVis took about 11s (both with one 2080 Ti). I haven't done full analysis on where the slowness is from, but likely on the GroupNorm operator.
16 changes: 9 additions & 7 deletions examples/txt2img/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -126,18 +126,18 @@ func ResBlock(b: Int, outChannels: Int, skipConnection: Bool) -> Model {
let emb = Input()
let inLayerNorm = GroupNorm(axis: 1, groups: 32, epsilon: 1e-5, reduce: [2, 3])
var out = inLayerNorm(x)
out = Swish()(out)
out = out.swish()
let inLayerConv2d = Convolution(
groups: 1, filters: outChannels, filterSize: [3, 3],
hint: Hint(stride: [1, 1], border: Hint.Border(begin: [1, 1], end: [1, 1])))
out = inLayerConv2d(out)
let embLayer = Dense(count: outChannels)
var embOut = Swish()(emb)
var embOut = emb.swish()
embOut = embLayer(embOut).reshaped([b, outChannels, 1, 1])
out = out + embOut
let outLayerNorm = GroupNorm(axis: 1, groups: 32, epsilon: 1e-5, reduce: [2, 3])
out = outLayerNorm(out)
out = Swish()(out)
out = out.swish()
// Dropout if needed in the future (for training).
let outLayerConv2d = Convolution(
groups: 1, filters: outChannels, filterSize: [3, 3],
Expand Down Expand Up @@ -411,7 +411,7 @@ func UNet(batchSize: Int) -> Model {
out = outputBlocks
let outNorm = GroupNorm(axis: 1, groups: 32, epsilon: 1e-5, reduce: [2, 3])
out = outNorm(out)
out = Swish()(out)
out = out.swish()
let outConv2d = Convolution(
groups: 1, filters: 4, filterSize: [3, 3],
hint: Hint(stride: [1, 1], border: Hint.Border(begin: [1, 1], end: [1, 1])))
Expand All @@ -425,14 +425,14 @@ func ResnetBlock(prefix: String, outChannels: Int, shortcut: Bool) -> Model {
let x = Input()
let norm1 = GroupNorm(axis: 1, groups: 32, epsilon: 1e-6, reduce: [2, 3])
var out = norm1(x)
out = Swish()(out)
out = out.swish()
let conv1 = Convolution(
groups: 1, filters: outChannels, filterSize: [3, 3],
hint: Hint(stride: [1, 1], border: Hint.Border(begin: [1, 1], end: [1, 1])))
out = conv1(out)
let norm2 = GroupNorm(axis: 1, groups: 32, epsilon: 1e-6, reduce: [2, 3])
out = norm2(out)
out = Swish()(out)
out = out.swish()
let conv2 = Convolution(
groups: 1, filters: outChannels, filterSize: [3, 3],
hint: Hint(stride: [1, 1], border: Hint.Border(begin: [1, 1], end: [1, 1])))
Expand Down Expand Up @@ -513,7 +513,7 @@ func Decoder(channels: [Int], numRepeat: Int, batchSize: Int, startWidth: Int, s
}
let normOut = GroupNorm(axis: 1, groups: 32, epsilon: 1e-6, reduce: [2, 3])
out = normOut(out)
out = Swish()(out)
out = out.swish()
let convOut = Convolution(
groups: 1, filters: 3, filterSize: [3, 3],
hint: Hint(stride: [1, 1], border: Hint.Border(begin: [1, 1], end: [1, 1])))
Expand Down Expand Up @@ -640,6 +640,7 @@ graph.withNoGrad {
}
let alphasCumprod = model.alphasCumprod
var oldEps = [DynamicGraph.Tensor<Float>]()
let startTime = Date()
// Now do PLMS sampling.
for i in 0..<model.steps {
let timestep = model.timesteps - model.timesteps / model.steps * (i + 1) + 1
Expand Down Expand Up @@ -690,6 +691,7 @@ graph.withNoGrad {
}
let z = 1.0 / scaleFactor * x
let img = decoder(inputs: z)[0].as(of: Float.self).toCPU()
print("Total time \(Date().timeIntervalSince(startTime))")
let image = ccv_dense_matrix_new(512, 512, Int32(CCV_8U | CCV_C3), nil, 0)
// I have better way to copy this out (basically, transpose and then ccv_shift). Doing this just for fun.
for y in 0..<512 {
Expand Down

0 comments on commit b39f29d

Please sign in to comment.