Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ async def websocket_handler(request):
async def get_root(request):
return web.FileResponse(os.path.join(self.web_root, "index.html"))

@routes.get("/embeddings")
def get_embeddings(self):
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
embed_dir = os.path.join(models_dir, "embeddings")
embeddings = nodes.filter_files_extensions(nodes.recursive_search(embed_dir), nodes.supported_pt_extensions)

return web.json_response(list(map(lambda a: os.path.splitext(a)[0].lower(), embeddings)))

@routes.get("/extensions")
async def get_extensions(request):
files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True)
Expand Down
9 changes: 9 additions & 0 deletions web/scripts/api.js
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,15 @@ class ComfyApi extends EventTarget {
return await resp.json();
}

/**
* Gets a list of embedding names
* @returns An array of script urls to import
*/
async getEmbeddings() {
const resp = await fetch("/embeddings", { cache: "no-store" });
return await resp.json();
}

/**
* Loads node object definitions for the graph
* @returns The node definitions
Expand Down
10 changes: 7 additions & 3 deletions web/scripts/app.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { ComfyWidgets } from "./widgets.js";
import { ComfyUI } from "./ui.js";
import { api } from "./api.js";
import { defaultGraph } from "./defaultGraph.js";
import { getPngMetadata } from "./pnginfo.js";
import { getPngMetadata, importA1111 } from "./pnginfo.js";

class ComfyApp {
constructor() {
Expand Down Expand Up @@ -743,8 +743,12 @@ class ComfyApp {
async handleFile(file) {
if (file.type === "image/png") {
const pngInfo = await getPngMetadata(file);
if (pngInfo && pngInfo.workflow) {
this.loadGraphData(JSON.parse(pngInfo.workflow));
if (pngInfo) {
if (pngInfo.workflow) {
this.loadGraphData(JSON.parse(pngInfo.workflow));
} else if (pngInfo.parameters) {
importA1111(this.graph, pngInfo.parameters);
}
}
} else if (file.type === "application/json" || file.name.endsWith(".json")) {
const reader = new FileReader();
Expand Down
261 changes: 261 additions & 0 deletions web/scripts/pnginfo.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { api } from "./api.js";

export function getPngMetadata(file) {
return new Promise((r) => {
const reader = new FileReader();
Expand Down Expand Up @@ -43,3 +45,262 @@ export function getPngMetadata(file) {
reader.readAsArrayBuffer(file);
});
}

export async function importA1111(graph, parameters) {
const p = parameters.lastIndexOf("\nSteps:");
if (p > -1) {
const embeddings = await api.getEmbeddings();
const opts = parameters
.substr(p)
.split(",")
.reduce((p, n) => {
const s = n.split(":");
p[s[0].trim().toLowerCase()] = s[1].trim();
return p;
}, {});
const p2 = parameters.lastIndexOf("\nNegative prompt:", p);
if (p2 > -1) {
let positive = parameters.substr(0, p2).trim();
let negative = parameters.substring(p2 + 18, p).trim();

const ckptNode = LiteGraph.createNode("CheckpointLoaderSimple");
const clipSkipNode = LiteGraph.createNode("CLIPSetLastLayer");
const positiveNode = LiteGraph.createNode("CLIPTextEncode");
const negativeNode = LiteGraph.createNode("CLIPTextEncode");
const samplerNode = LiteGraph.createNode("KSampler");
const imageNode = LiteGraph.createNode("EmptyLatentImage");
const vaeNode = LiteGraph.createNode("VAEDecode");
const vaeLoaderNode = LiteGraph.createNode("VAELoader");
const saveNode = LiteGraph.createNode("SaveImage");
let hrSamplerNode = null;

const ceil64 = (v) => Math.ceil(v / 64) * 64;

function getWidget(node, name) {
return node.widgets.find((w) => w.name === name);
}

function setWidgetValue(node, name, value, isOptionPrefix) {
const w = getWidget(node, name);
if (isOptionPrefix) {
const o = w.options.values.find((w) => w.startsWith(value));
if (o) {
w.value = o;
} else {
console.warn(`Unknown value '${value}' for widget '${name}'`, node);
w.value = value;
}
} else {
w.value = value;
}
}

function createLoraNodes(clipNode, text, prevClip, prevModel) {
const loras = [];
text = text.replace(/<lora:([^:]+:[^>]+)>/g, function (m, c) {
const s = c.split(":");
const weight = parseFloat(s[1]);
if (isNaN(weight)) {
console.warn("Invalid LORA", m);
} else {
loras.push({ name: s[0], weight });
}
return "";
});

for (const l of loras) {
const loraNode = LiteGraph.createNode("LoraLoader");
graph.add(loraNode);
setWidgetValue(loraNode, "lora_name", l.name, true);
setWidgetValue(loraNode, "strength_model", l.weight);
setWidgetValue(loraNode, "strength_clip", l.weight);
prevModel.node.connect(prevModel.index, loraNode, 0);
prevClip.node.connect(prevClip.index, loraNode, 1);
prevModel = { node: loraNode, index: 0 };
prevClip = { node: loraNode, index: 1 };
}

prevClip.node.connect(1, clipNode, 0);
prevModel.node.connect(0, samplerNode, 0);
if (hrSamplerNode) {
prevModel.node.connect(0, hrSamplerNode, 0);
}

return { text, prevModel, prevClip };
}

function replaceEmbeddings(text) {
return text.replaceAll(
new RegExp(
"\\b(" + embeddings.map((e) => e.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")).join("\\b|\\b") + ")\\b",
"ig"
),
"embedding:$1"
);
}

function popOpt(name) {
const v = opts[name];
delete opts[name];
return v;
}

graph.clear();
graph.add(ckptNode);
graph.add(clipSkipNode);
graph.add(positiveNode);
graph.add(negativeNode);
graph.add(samplerNode);
graph.add(imageNode);
graph.add(vaeNode);
graph.add(vaeLoaderNode);
graph.add(saveNode);

ckptNode.connect(1, clipSkipNode, 0);
clipSkipNode.connect(0, positiveNode, 0);
clipSkipNode.connect(0, negativeNode, 0);
ckptNode.connect(0, samplerNode, 0);
positiveNode.connect(0, samplerNode, 1);
negativeNode.connect(0, samplerNode, 2);
imageNode.connect(0, samplerNode, 3);
vaeNode.connect(0, saveNode, 0);
samplerNode.connect(0, vaeNode, 0);
vaeLoaderNode.connect(0, vaeNode, 1);

const handlers = {
model(v) {
setWidgetValue(ckptNode, "ckpt_name", v, true);
},
"cfg scale"(v) {
setWidgetValue(samplerNode, "cfg", +v);
},
"clip skip"(v) {
setWidgetValue(clipSkipNode, "stop_at_clip_layer", -v);
},
sampler(v) {
let name = v.toLowerCase().replace("++", "pp").replaceAll(" ", "_");
if (name.includes("karras")) {
name = name.replace("karras", "").replace(/_+$/, "");
setWidgetValue(samplerNode, "scheduler", "karras");
} else {
setWidgetValue(samplerNode, "scheduler", "normal");
}
const w = getWidget(samplerNode, "sampler_name");
const o = w.options.values.find((w) => w === name || w === "sample_" + name);
if (o) {
setWidgetValue(samplerNode, "sampler_name", o);
}
},
size(v) {
const wxh = v.split("x");
const w = ceil64(+wxh[0]);
const h = ceil64(+wxh[1]);
const hrUp = popOpt("hires upscale");
const hrSz = popOpt("hires resize");
let hrMethod = popOpt("hires upscaler");

setWidgetValue(imageNode, "width", w);
setWidgetValue(imageNode, "height", h);

if (hrUp || hrSz) {
let uw, uh;
if (hrUp) {
uw = w * hrUp;
uh = h * hrUp;
} else {
const s = hrSz.split("x");
uw = +s[0];
uh = +s[1];
}

let upscaleNode;
let latentNode;

if (hrMethod.startsWith("Latent")) {
latentNode = upscaleNode = LiteGraph.createNode("LatentUpscale");
graph.add(upscaleNode);
samplerNode.connect(0, upscaleNode, 0);

switch (hrMethod) {
case "Latent (nearest-exact)":
hrMethod = "nearest-exact";
break;
}
setWidgetValue(upscaleNode, "upscale_method", hrMethod, true);
} else {
const decode = LiteGraph.createNode("VAEDecodeTiled");
graph.add(decode);
samplerNode.connect(0, decode, 0);
vaeLoaderNode.connect(0, decode, 1);

const upscaleLoaderNode = LiteGraph.createNode("UpscaleModelLoader");
graph.add(upscaleLoaderNode);
setWidgetValue(upscaleLoaderNode, "model_name", hrMethod, true);

const modelUpscaleNode = LiteGraph.createNode("ImageUpscaleWithModel");
graph.add(modelUpscaleNode);
decode.connect(0, modelUpscaleNode, 1);
upscaleLoaderNode.connect(0, modelUpscaleNode, 0);

upscaleNode = LiteGraph.createNode("ImageScale");
graph.add(upscaleNode);
modelUpscaleNode.connect(0, upscaleNode, 0);

const vaeEncodeNode = (latentNode = LiteGraph.createNode("VAEEncodeTiled"));
graph.add(vaeEncodeNode);
upscaleNode.connect(0, vaeEncodeNode, 0);
vaeLoaderNode.connect(0, vaeEncodeNode, 1);
}

setWidgetValue(upscaleNode, "width", ceil64(uw));
setWidgetValue(upscaleNode, "height", ceil64(uh));

hrSamplerNode = LiteGraph.createNode("KSampler");
graph.add(hrSamplerNode);
ckptNode.connect(0, hrSamplerNode, 0);
positiveNode.connect(0, hrSamplerNode, 1);
negativeNode.connect(0, hrSamplerNode, 2);
latentNode.connect(0, hrSamplerNode, 3);
hrSamplerNode.connect(0, vaeNode, 0);
}
},
steps(v) {
setWidgetValue(samplerNode, "steps", +v);
},
seed(v) {
setWidgetValue(samplerNode, "seed", +v);
},
};

for (const opt in opts) {
if (opt in handlers) {
handlers[opt](popOpt(opt));
}
}

if (hrSamplerNode) {
setWidgetValue(hrSamplerNode, "steps", getWidget(samplerNode, "steps").value);
setWidgetValue(hrSamplerNode, "cfg", getWidget(samplerNode, "cfg").value);
setWidgetValue(hrSamplerNode, "scheduler", getWidget(samplerNode, "scheduler").value);
setWidgetValue(hrSamplerNode, "sampler_name", getWidget(samplerNode, "sampler_name").value);
setWidgetValue(hrSamplerNode, "denoise", +(popOpt("denoising strength") || "1"));
}

let n = createLoraNodes(positiveNode, positive, { node: clipSkipNode, index: 0 }, { node: ckptNode, index: 0 });
positive = n.text;
n = createLoraNodes(negativeNode, negative, n.prevClip, n.prevModel);
negative = n.text;

setWidgetValue(positiveNode, "text", replaceEmbeddings(positive));
setWidgetValue(negativeNode, "text", replaceEmbeddings(negative));

graph.arrange();

for (const opt of ["model hash", "ensd"]) {
delete opts[opt];
}

console.warn("Unhandled parameters:", opts);
}
}
}