Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Web] Multistep DPM-solver for web side #3

Merged
merged 1 commit into from
Mar 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion scripts/build_site.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ echo "Copy files..."
cp web/stable_diffusion.html site/_includes
cp web/stable_diffusion.js site/dist

cp dist/scheduler_consts.json site/dist
cp dist/scheduler_pndm_consts.json site/dist
cp dist/scheduler_dpm_solver_multistep_consts.json site/dist
cp dist/stable_diffusion_webgpu.wasm site/dist

cp dist/tvmjs_runtime.wasi.js site/dist
Expand Down
3 changes: 2 additions & 1 deletion scripts/rpc_debug_deploy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ cp web/stable_diffusion.html ${TVM_HOME}/web/dist/www/rpc_plugin.html
cp web/stable_diffusion.js ${TVM_HOME}/web/dist/www/dist/
cp web/local-config.json ${TVM_HOME}/web/dist/www/stable-diffusion-config.json

cp dist/scheduler_consts.json ${TVM_HOME}/web/dist/www/dist/
cp dist/scheduler_pndm_consts.json ${TVM_HOME}/web/dist/www/dist/
cp dist/scheduler_dpm_solver_multistep_consts.json ${TVM_HOME}/web/dist/www/dist/
cp dist/stable_diffusion_webgpu.wasm ${TVM_HOME}/web/dist/www/dist/
cp -rf dist/tokenizers-wasm ${TVM_HOME}/web/dist/www/dist/

Expand Down
5 changes: 4 additions & 1 deletion web/gh-page-config.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
{
"schedulerConstUrl": "dist/scheduler_consts.json",
"schedulerConstUrl": [
"dist/scheduler_dpm_solver_multistep_consts.json",
"dist/scheduler_pndm_consts.json"
],
"wasmUrl": "dist/stable_diffusion_webgpu.wasm",
"cacheUrl": "https://huggingface.co/mlc-ai/web-sd/resolve/main/web-sd-shards-v1-5/",
"tokenizer": "openai/clip-vit-large-patch14"
Expand Down
5 changes: 4 additions & 1 deletion web/local-config.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
{
"schedulerConstUrl": "dist/scheduler_consts.json",
"schedulerConstUrl": [
"dist/scheduler_dpm_solver_multistep_consts.json",
"dist/scheduler_pndm_consts.json"
],
"wasmUrl": "dist/stable_diffusion_webgpu.wasm",
"cacheUrl": "web-sd-shards-v1-5/",
"tokenizer": "openai/clip-vit-large-patch14"
Expand Down
8 changes: 8 additions & 0 deletions web/stable_diffusion.html
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@
</div>

<div>
Select scheduler -
<select name="scheduler" id="schedulerId">
<option value="0">Multi-step DPM Solver (20 steps)</option>
<option value="1">PNDM (50 steps)</option>
</select>

<br>

Render intermediate steps (may slow down execution) -
<select name="vae-cycle" id="vaeCycle">
<option value="-1">No</option>
Expand Down
168 changes: 135 additions & 33 deletions web/stable_diffusion.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class TVMPNDMScheduler {

// prebuild constants
// principle: always detach for class members
// to avoid recyling output scope.
// to avoid recycling output scope.
function loadConsts(output, dtype, input) {
for (let t = 0; t < input.length; ++t) {
output.push(
Expand Down Expand Up @@ -42,7 +42,7 @@ class TVMPNDMScheduler {
for (let i = 0; i < 5; ++i) {
this.schedulerFunc.push(
tvm.detachFromCurrentScope(
vm.getFunction("scheduler_step_" + i.toString())
vm.getFunction("pndm_scheduler_step_" + i.toString())
)
);
}
Expand Down Expand Up @@ -101,6 +101,85 @@ class TVMPNDMScheduler {
}
}

/**
* Wrapper to handle multistep DPM-solver scheduler
*/
class TVMDPMSolverMultistepScheduler {
constructor(schedulerConsts, latentShape, tvm, device, vm) {
this.timestep = [];
this.alpha = [];
this.sigma = [];
this.c0 = [];
this.c1 = [];
this.c2 = [];
this.lastModelOutput = undefined;
this.convertModelOutputFunc = undefined;
this.stepFunc = undefined;
this.tvm = tvm;

// prebuild constants
// principle: always detach for class members
// to avoid recycling output scope.
function loadConsts(output, dtype, input) {
for (let t = 0; t < input.length; ++t) {
output.push(
tvm.detachFromCurrentScope(
tvm.empty([], dtype, device).copyFrom([input[t]])
)
);
}
}
loadConsts(this.timestep, "int32", schedulerConsts["timesteps"]);
loadConsts(this.alpha, "float32", schedulerConsts["alpha"]);
loadConsts(this.sigma, "float32", schedulerConsts["sigma"]);
loadConsts(this.c0, "float32", schedulerConsts["c0"]);
loadConsts(this.c1, "float32", schedulerConsts["c1"]);
loadConsts(this.c2, "float32", schedulerConsts["c2"]);

this.lastModelOutput = this.tvm.detachFromCurrentScope(
this.tvm.empty(latentShape, "float32", device)
)
this.convertModelOutputFunc = tvm.detachFromCurrentScope(
vm.getFunction("dpm_solver_multistep_scheduler_convert_model_output")
)
this.stepFunc = tvm.detachFromCurrentScope(
vm.getFunction("dpm_solver_multistep_scheduler_step")
)
}

dispose() {
for (let t = 0; t < this.timestep.length; ++t) {
this.timestep[t].dispose();
this.alpha[t].dispose();
this.sigma[t].dispose();
this.c0[t].dispose();
this.c1[t].dispose();
this.c2[t].dispose();
}

this.lastModelOutput.dispose();
this.convertModelOutputFunc.dispose();
this.stepFunc.dispose();
}

step(modelOutput, sample, counter) {
modelOutput = this.convertModelOutputFunc(sample, modelOutput, this.alpha[counter], this.sigma[counter])
const prevLatents = this.stepFunc(
sample,
modelOutput,
this.lastModelOutput,
this.c0[counter],
this.c1[counter],
this.c2[counter],
);
this.lastModelOutput = this.tvm.detachFromCurrentScope(
modelOutput
);

return prevLatents;
}
}

class StableDiffusionPipeline {
constructor(tvm, tokenizer, schedulerConsts, cacheMetadata) {
if (cacheMetadata == undefined) {
Expand Down Expand Up @@ -181,10 +260,20 @@ class StableDiffusionPipeline {
* @param prompt Input prompt.
* @param negPrompt Input negative prompt.
* @param progressCallback Callback to check progress.
* @param schedulerId The integer ID of the scheduler to use.
* - 0 for multi-step DPM solver,
* - 1 for PNDM solver.
* @param vaeCycle optionally draw VAE result every cycle iterations.
* @param beginRenderVae Begin rendering VAE after skipping these warmup runs.
*/
async generate(prompt, negPrompt="", progressCallback = undefined, vaeCycle = -1, beginRenderVae = 10) {
async generate(
prompt,
negPrompt = "",
progressCallback = undefined,
schedulerId = 0,
vaeCycle = -1,
beginRenderVae = 10
) {
// Principle: beginScope/endScope in synchronized blocks,
// this helps to recycle intermediate memories
// detach states that needs to go across async boundaries.
Expand All @@ -194,10 +283,21 @@ class StableDiffusionPipeline {
this.tvm.beginScope();
// get latents
const latentShape = [1, 4, 64, 64];
scheduler = new TVMPNDMScheduler(
this.schedulerConsts, latentShape, this.tvm, this.device, this.vm);

var unetNumSteps;
if (schedulerId == 0) {
scheduler = new TVMDPMSolverMultistepScheduler(
this.schedulerConsts[0], latentShape, this.tvm, this.device, this.vm);
unetNumSteps = this.schedulerConsts[0]["num_steps"];
} else {
scheduler = new TVMPNDMScheduler(
this.schedulerConsts[1], latentShape, this.tvm, this.device, this.vm);
unetNumSteps = this.schedulerConsts[1]["num_steps"];
}
const totalNumSteps = unetNumSteps + 2;

if (progressCallback !== undefined) {
progressCallback("clip", 0, 1);
progressCallback("clip", 0, 1, totalNumSteps);
}

const embeddings = this.tvm.withNewScope(() => {
Expand Down Expand Up @@ -229,13 +329,12 @@ class StableDiffusionPipeline {
});
await this.device.sync();
}
const numSteps = 50;
vaeCycle = vaeCycle == -1 ? numSteps: vaeCycle;
vaeCycle = vaeCycle == -1 ? unetNumSteps : vaeCycle;
let lastSync = undefined;

for (let counter = 0; counter < numSteps; ++counter) {
for (let counter = 0; counter < unetNumSteps; ++counter) {
if (progressCallback !== undefined) {
progressCallback("unet", counter, numSteps);
progressCallback("unet", counter, unetNumSteps, totalNumSteps);
}
const timestep = scheduler.timestep[counter];
// recycle noisePred, track latents manually
Expand All @@ -258,8 +357,8 @@ class StableDiffusionPipeline {

// Optionally, we can draw intermediate result of VAE.
if ((counter + 1) % vaeCycle == 0 &&
(counter + 1) != numSteps &&
counter >= beginRenderVae) {
(counter + 1) != unetNumSteps &&
counter >= beginRenderVae) {
this.tvm.withNewScope(() => {
const image = this.vaeToImage(latents, this.vaeParams);
this.tvm.showImage(this.imageToRGBA(image));
Expand All @@ -273,7 +372,7 @@ class StableDiffusionPipeline {
// Stage 2: VAE and draw image
//-----------------------------
if (progressCallback !== undefined) {
progressCallback("vae", 0, 1);
progressCallback("vae", 0, 1, totalNumSteps);
}
this.tvm.withNewScope(() => {
const image = this.vaeToImage(latents, this.vaeParams);
Expand All @@ -282,7 +381,7 @@ class StableDiffusionPipeline {
latents.dispose();
await this.device.sync();
if (progressCallback !== undefined) {
progressCallback("vae", 1, 1);
progressCallback("vae", 1, 1, totalNumSteps);
}
}

Expand Down Expand Up @@ -314,7 +413,7 @@ class StableDiffusionInstance {
}

if (document.getElementById("log") !== undefined) {
this.logger = function(message) {
this.logger = function (message) {
console.log(message);
const d = document.createElement("div");
d.innerHTML = message;
Expand Down Expand Up @@ -346,10 +445,10 @@ class StableDiffusionInstance {
} else {
document.getElementById(
"gpu-tracker-label").innerHTML = "This browser env do not support WebGPU";
this.reset();
throw Error("This browser env do not support WebGPU");
this.reset();
throw Error("This browser env do not support WebGPU");
}
} catch(err) {
} catch (err) {
document.getElementById("gpu-tracker-label").innerHTML = (
"Find an error initializing the WebGPU device " + err.toString()
);
Expand Down Expand Up @@ -381,46 +480,48 @@ class StableDiffusionInstance {
throw Error("asyncInitTVM is not called");
}
if (this.pipeline !== undefined) return;
const schedulerConst = await(await fetch(schedulerConstUrl)).json();
var schedulerConst = []
for (let i = 0; i < schedulerConstUrl.length; ++i) {
schedulerConst.push(await (await fetch(schedulerConstUrl[i])).json())
}
const tokenizer = await tvmjsGlobalEnv.getTokenizer(tokenizerName);
this.pipeline = this.tvm.withNewScope(() => {
return new StableDiffusionPipeline(this.tvm, tokenizer, schedulerConst, this.tvm.cacheMetadata);
});
}

/**
* Async intitialize config
* Async initialize config
*/
async #asyncInitConfig() {
if (this.config !== undefined) return;
this.config = await(await fetch("stable-diffusion-config.json")).json();
this.config = await (await fetch("stable-diffusion-config.json")).json();
}

/**
* Function to create progress callback tracker.
* @returns A progress callback tracker.
*/
#getProgressCallback() {
#getProgressCallback() {
const tstart = performance.now();
function progressCallback(stage, counter, numSteps) {
const totalSteps = 50 + 2;
function progressCallback(stage, counter, numSteps, totalNumSteps) {
const timeElapsed = (performance.now() - tstart) / 1000;
let text = "Generating ... at stage " + stage;
if (stage == "unet") {
counter += 1;
text += " step [" + counter + "/" + numSteps + "]"
}
if (stage == "vae") {
counter += 51;
counter = totalNumSteps;
}
text += ", " + Math.ceil(timeElapsed) + " secs elapsed.";
document.getElementById("progress-tracker-label").innerHTML = text;
document.getElementById("progress-tracker-progress").value = (counter / totalSteps) * 100;
document.getElementById("progress-tracker-progress").value = (counter / totalNumSteps) * 100;
}
return progressCallback;
}

/**
/**
* Async initialize instance.
*/
async asyncInit() {
Expand All @@ -442,11 +543,11 @@ class StableDiffusionInstance {
this.tvm = tvmInstance;

this.tvm.beginScope();
this.tvm.registerAsyncServerFunc("generate", async (prompt, vaeCycle) => {
this.tvm.registerAsyncServerFunc("generate", async (prompt, schedulerId, vaeCycle) => {
document.getElementById("inputPrompt").value = prompt;
const negPrompt = "";
document.getElementById("negativePrompt").value = "";
await this.pipeline.generate(prompt, negPrompt, this.#getProgressCallback(), vaeCycle);
await this.pipeline.generate(prompt, negPrompt, this.#getProgressCallback(), schedulerId, vaeCycle);
});
this.tvm.registerAsyncServerFunc("clearCanvas", async () => {
this.tvm.clearCanvas();
Expand All @@ -470,8 +571,9 @@ class StableDiffusionInstance {
await this.asyncInit();
const prompt = document.getElementById("inputPrompt").value;
const negPrompt = document.getElementById("negativePrompt").value;
const vaeCycle =document.getElementById("vaeCycle").value;
await this.pipeline.generate(prompt, negPrompt, this.#getProgressCallback(), vaeCycle);
const schedulerId = document.getElementById("schedulerId").value;
const vaeCycle = document.getElementById("vaeCycle").value;
await this.pipeline.generate(prompt, negPrompt, this.#getProgressCallback(), schedulerId, vaeCycle);
} catch (err) {
this.logger("Generate error, " + err.toString());
console.log(err.stack);
Expand All @@ -494,11 +596,11 @@ class StableDiffusionInstance {

localStableDiffusionInst = new StableDiffusionInstance();

tvmjsGlobalEnv.asyncOnGenerate = async function() {
tvmjsGlobalEnv.asyncOnGenerate = async function () {
await localStableDiffusionInst.generate();
};

tvmjsGlobalEnv.asyncOnRPCServerLoad = async function(tvm) {
tvmjsGlobalEnv.asyncOnRPCServerLoad = async function (tvm) {
const inst = new StableDiffusionInstance();
await inst.asyncInitOnRPCServerLoad(tvm);
};
2 changes: 2 additions & 0 deletions web_stable_diffusion/trace/scheduler_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def compute_const_dict() -> Dict[str, List[tvm.nd.NDArray]]:
list_model_output_denom_coeff.append(model_output_denom_coeff.item())

return {
"num_steps": len(timesteps),
"timesteps": timesteps,
"sample_coeff": list_sample_coeff,
"alpha_diff": list_alpha_diff,
Expand Down Expand Up @@ -313,6 +314,7 @@ def compute_const_dict() -> Dict[str, List[tvm.nd.NDArray]]:
list_c2.append(c2.item())

return {
"num_steps": len(timesteps),
"timesteps": timesteps,
"alpha": list_alpha,
"sigma": list_sigma,
Expand Down