Add CogView4 model support #7770


Merged: 42 commits, merged Apr 10, 2025

Commits
51717f6
Add CogView4 model probing.
RyanJDick Mar 5, 2025
f21f505
Add CogView4TextEncoderInvocation
RyanJDick Mar 5, 2025
171afc7
Bump diffusers to dev version with CogView4 support.
RyanJDick Mar 5, 2025
6c90cad
WIP - CogView4DenoiseInvocation.
RyanJDick Mar 5, 2025
0a6a2c5
Simplify CogView4 timestep schedule initialization.
RyanJDick Mar 6, 2025
88aa8e6
Completed first pass of CogView4Denoise.
RyanJDick Mar 6, 2025
507d34c
Add CogView4LatentsToImageInvocation.
RyanJDick Mar 6, 2025
4e778fb
Require the cogview4 height/width are multiples of 32. This requireme…
RyanJDick Mar 6, 2025
2b1466e
Add CogView4ModelLoaderInvocation. (Not wired up with frontend yet.)
RyanJDick Mar 6, 2025
2bfe80e
typegen
RyanJDick Mar 6, 2025
cb38fcb
Fix typo in BaseModelTypo.CogView4.
RyanJDick Mar 6, 2025
cdb1df6
Add CogView4 to frontend.
RyanJDick Mar 6, 2025
999a1d0
Update CogView4 starter model entry with approximate bundle size.
RyanJDick Mar 6, 2025
7d481f7
Add CogView4 model loader. And various other fixes to get a CogView4 …
RyanJDick Mar 6, 2025
fdcbbaa
Switch to sequential CFG for CogView4 (for now, until I sort out the …
RyanJDick Mar 6, 2025
5113ee9
Fix bug in CogView4 noise schedule handling that was resulting in low…
RyanJDick Mar 7, 2025
535a287
Add CogView4ImageToLatentsInvocation.
RyanJDick Mar 10, 2025
aff50c4
Simplify CogView4 timesteps schedule generation in preparation for ti…
RyanJDick Mar 10, 2025
90671a6
Update CogView4Denoise to support image-to-image.
RyanJDick Mar 10, 2025
adb5210
Support cfg_scale list in CogView4Denoise.
RyanJDick Mar 10, 2025
19f384b
Consolidate InpaintExtension implementations for SD3 and FLUX.
RyanJDick Mar 10, 2025
37f24c5
Add inpainting to CogView4DenoiseInvocation.
RyanJDick Mar 11, 2025
c09f44a
Add CogView4 VAE approximation for progress images.
RyanJDick Mar 11, 2025
16a1133
typegen
RyanJDick Mar 14, 2025
5c0149e
Fix lint error.
RyanJDick Mar 14, 2025
088a814
add generation modes for cogview linear
maryhipp Mar 17, 2025
ce7562c
build graph for cogview4
Mar 17, 2025
15bd258
create hook for managing entity type enabledness for given base model…
Mar 17, 2025
045fbf6
update available params for cogview4
Mar 17, 2025
885999f
feat(app): add cogview4 default workflows
psychedelicious Mar 18, 2025
aa17ede
feat(ui): add cogview4 and inpainting tags to library
psychedelicious Mar 18, 2025
f0bfa7c
feat(nodes): rename CogView4 nodes to match naming format
psychedelicious Mar 18, 2025
bcfe495
fix(ui): add checks for cogview4's dimension restrictions
psychedelicious Mar 18, 2025
42a358c
refactor(ui): simplify useIsEntityTypeEnabled
psychedelicious Mar 18, 2025
0b862ca
chore(ui): lint
psychedelicious Mar 18, 2025
253bb58
chore(ui): typegen
psychedelicious Mar 18, 2025
574757a
feat(app): update cogview4 t2i workflow w/ form
psychedelicious Mar 18, 2025
90e95ea
feat(app): remove cogview4 inpaint workflow
psychedelicious Mar 18, 2025
efab655
fix(app): add CogView4ConditioningInfo to ObjectSerializerDisk's safe…
psychedelicious Apr 9, 2025
3fdccbe
fix(app): import CogView4Transformer2DModel from the module that expo…
psychedelicious Apr 9, 2025
cac974d
fix(ui): generation_mode metadata not set correctly
psychedelicious Apr 9, 2025
5e3dce1
fix(ui): do not disable image context canvas actions based on selecte…
psychedelicious Apr 9, 2025
2 changes: 2 additions & 0 deletions invokeai/app/api/dependencies.py
@@ -39,6 +39,7 @@
 from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
+    CogView4ConditioningInfo,
     ConditioningFieldData,
     FLUXConditioningInfo,
     SD3ConditioningInfo,
@@ -123,6 +124,7 @@ def initialize(
                     SDXLConditioningInfo,
                     FLUXConditioningInfo,
                     SD3ConditioningInfo,
+                    CogView4ConditioningInfo,
                 ],
                 ephemeral=True,
             ),
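The safe-globals addition above matters because serialized conditioning objects are loaded back with `torch.load` under weights-only semantics, which only unpickles allow-listed classes (see commit efab655). A rough sketch of the mechanism, not the service's exact code:

```python
import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import CogView4ConditioningInfo

# Sketch: with weights_only=True, torch.load refuses to unpickle classes that
# are not allow-listed, so each new conditioning dataclass must be registered.
# The file path here is illustrative.
with torch.serialization.safe_globals([CogView4ConditioningInfo]):
    data = torch.load("conditioning.pt", weights_only=True)
```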
363 changes: 363 additions & 0 deletions invokeai/app/invocations/cogview4_denoise.py

Large diffs are not rendered by default.
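Since this diff is not rendered, here is a minimal sketch of the rectified-flow denoising loop the commits above describe (sequential CFG per commit fdcbbaa, Euler steps over a decreasing timestep schedule). The function and argument names are illustrative assumptions, not the PR's actual code:

```python
# Illustrative sketch only; the real implementation lives in
# invokeai/app/invocations/cogview4_denoise.py (diff not rendered above).
import torch

def denoise_sketch(
    transformer,               # assumed callable: (latents, t, embeds) -> predicted velocity
    latents: torch.Tensor,     # noisy latents, shape (batch, ...)
    pos_embeds: torch.Tensor,  # GLM embeds for the positive prompt
    neg_embeds: torch.Tensor,  # GLM embeds for the negative prompt
    timesteps: list[float],    # decreasing schedule in [1.0, 0.0]
    cfg_scale: float,
) -> torch.Tensor:
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        t = torch.full((latents.shape[0],), t_curr, device=latents.device, dtype=latents.dtype)
        # "Sequential CFG": two separate forward passes rather than one
        # batched conditional/unconditional pass.
        pred_pos = transformer(latents, t, pos_embeds)
        pred_neg = transformer(latents, t, neg_embeds)
        pred = pred_neg + cfg_scale * (pred_pos - pred_neg)
        # Euler step along the rectified-flow trajectory.
        latents = latents + (t_prev - t_curr) * pred
    return latents
```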

69 changes: 69 additions & 0 deletions invokeai/app/invocations/cogview4_image_to_latents.py
@@ -0,0 +1,69 @@
import einops
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    ImageField,
    Input,
    InputField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice

# TODO(ryand): This is effectively a copy of SD3ImageToLatentsInvocation and a subset of ImageToLatentsInvocation. We
# should refactor to avoid this duplication.


@invocation(
    "cogview4_i2l",
    title="Image to Latents - CogView4",
    tags=["image", "latents", "vae", "i2l", "cogview4"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates latents from an image."""

    image: ImageField = InputField(description="The image to encode.")
    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

    @staticmethod
    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
        with vae_info as vae:
            assert isinstance(vae, AutoencoderKL)

            vae.disable_tiling()

            image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
            with torch.inference_mode():
                image_tensor_dist = vae.encode(image_tensor).latent_dist
                # TODO: Use seed to make sampling reproducible.
                latents: torch.Tensor = image_tensor_dist.sample().to(dtype=vae.dtype)

            latents = vae.config.scaling_factor * latents

        return latents

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        image = context.images.get_pil(self.image.image_name)

        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
        if image_tensor.dim() == 3:
            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

        vae_info = context.models.load(self.vae.vae)
        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)

        latents = latents.to("cpu")
        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
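One way the `TODO: Use seed...` above could be addressed: diffusers' `DiagonalGaussianDistribution.sample()` accepts a `torch.Generator`, so a seeded generator makes the encode deterministic. A hedged sketch of how the two sampling lines in `vae_encode` could change, assuming a hypothetical `seed` parameter:

```python
# Hedged sketch: reproducible VAE sampling. `seed` is a hypothetical
# parameter, not part of the PR's invocation.
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(seed)
latents: torch.Tensor = image_tensor_dist.sample(generator=generator).to(dtype=vae.dtype)
```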
86 changes: 86 additions & 0 deletions invokeai/app/invocations/cogview4_latents_to_image.py
@@ -0,0 +1,86 @@
from contextlib import nullcontext

import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from einops import rearrange
from PIL import Image

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    Input,
    InputField,
    LatentsField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.util.devices import TorchDevice

# TODO(ryand): This is effectively a copy of SD3LatentsToImageInvocation and a subset of LatentsToImageInvocation. We
# should refactor to avoid this duplication.


@invocation(
    "cogview4_l2i",
    title="Latents to Image - CogView4",
    tags=["latents", "image", "vae", "l2i", "cogview4"],
    category="latents",
    version="1.0.0",
    classification=Classification.Prototype,
)
class CogView4LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates an image from latents."""

    latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

    def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
        """Estimate the working memory required by the invocation in bytes."""
        out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
        out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
        element_size = next(vae.parameters()).element_size()
        scaling_constant = 2200  # Determined experimentally.
        working_memory = out_h * out_w * element_size * scaling_constant
        return int(working_memory)

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        latents = context.tensors.load(self.latents.latents_name)

        vae_info = context.models.load(self.vae.vae)
        assert isinstance(vae_info.model, (AutoencoderKL))
        estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
        with (
            SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
            vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
        ):
            context.util.signal_progress("Running VAE")
            assert isinstance(vae, (AutoencoderKL))
            latents = latents.to(TorchDevice.choose_torch_device())

            vae.disable_tiling()

            tiling_context = nullcontext()

            # clear memory as vae decode can request a lot
            TorchDevice.empty_cache()

            with torch.inference_mode(), tiling_context:
                # copied from diffusers pipeline
                latents = latents / vae.config.scaling_factor
                img = vae.decode(latents, return_dict=False)[0]

            img = img.clamp(-1, 1)
            img = rearrange(img[0], "c h w -> h w c")  # noqa: F821
            img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

        TorchDevice.empty_cache()

        image_dto = context.images.save(image=img_pil)

        return ImageOutput.build(image_dto)
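To make the `_estimate_working_memory` heuristic concrete: for a 1024x1024 output decoded by an fp16 VAE, the estimate works out to roughly 4.3 GiB. Values below are assumed for illustration:

```python
# Worked example of the working-memory heuristic above.
out_h = out_w = 1024     # output image size in pixels
element_size = 2         # bytes per element for an fp16 VAE
scaling_constant = 2200  # determined experimentally (from the code above)
working_memory = out_h * out_w * element_size * scaling_constant
print(f"{working_memory / 1024**3:.1f} GiB")  # -> 4.3 GiB
```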
55 changes: 55 additions & 0 deletions invokeai/app/invocations/cogview4_model_loader.py
@@ -0,0 +1,55 @@
from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import (
    GlmEncoderField,
    ModelIdentifierField,
    TransformerField,
    VAEField,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import SubModelType


@invocation_output("cogview4_model_loader_output")
class CogView4ModelLoaderOutput(BaseInvocationOutput):
    """CogView4 base model loader output."""

    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
    glm_encoder: GlmEncoderField = OutputField(description=FieldDescriptions.glm_encoder, title="GLM Encoder")
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")


@invocation(
    "cogview4_model_loader",
    title="Main Model - CogView4",
    tags=["model", "cogview4"],
    category="model",
    version="1.0.0",
    classification=Classification.Prototype,
)
class CogView4ModelLoaderInvocation(BaseInvocation):
    """Loads a CogView4 base model, outputting its submodels."""

    model: ModelIdentifierField = InputField(
        description=FieldDescriptions.cogview4_model,
        ui_type=UIType.CogView4MainModel,
        input=Input.Direct,
    )

    def invoke(self, context: InvocationContext) -> CogView4ModelLoaderOutput:
        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
        vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
        glm_tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
        glm_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})

        return CogView4ModelLoaderOutput(
            transformer=TransformerField(transformer=transformer, loras=[]),
            glm_encoder=GlmEncoderField(tokenizer=glm_tokenizer, text_encoder=glm_encoder),
            vae=VAEField(vae=vae),
        )
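The four `model_copy` calls line up with the usual diffusers folder layout for a CogView4 checkpoint. The directory names below follow diffusers conventions and are an assumption about the bundled model, not something this diff states:

```python
# Assumed mapping from submodel types to a diffusers CogView4 pipeline layout:
#   SubModelType.Transformer -> transformer/   (CogView4Transformer2DModel)
#   SubModelType.VAE         -> vae/           (AutoencoderKL)
#   SubModelType.Tokenizer   -> tokenizer/     (PreTrainedTokenizerFast)
#   SubModelType.TextEncoder -> text_encoder/  (GlmModel)
```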
92 changes: 92 additions & 0 deletions invokeai/app/invocations/cogview4_text_encoder.py
@@ -0,0 +1,92 @@
import torch
from transformers import GlmModel, PreTrainedTokenizerFast

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, UIComponent
from invokeai.app.invocations.model import GlmEncoderField
from invokeai.app.invocations.primitives import CogView4ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    CogView4ConditioningInfo,
    ConditioningFieldData,
)
from invokeai.backend.util.devices import TorchDevice

# The CogView4 GLM text encoder max sequence length, set based on the default in diffusers.
COGVIEW4_GLM_MAX_SEQ_LEN = 1024


@invocation(
    "cogview4_text_encoder",
    title="Prompt - CogView4",
    tags=["prompt", "conditioning", "cogview4"],
    category="conditioning",
    version="1.0.0",
    classification=Classification.Prototype,
)
class CogView4TextEncoderInvocation(BaseInvocation):
    """Encodes and preps a prompt for a CogView4 image."""

    prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
    glm_encoder: GlmEncoderField = InputField(
        title="GLM Encoder",
        description=FieldDescriptions.glm_encoder,
        input=Input.Connection,
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> CogView4ConditioningOutput:
        glm_embeds = self._glm_encode(context, max_seq_len=COGVIEW4_GLM_MAX_SEQ_LEN)
        conditioning_data = ConditioningFieldData(conditionings=[CogView4ConditioningInfo(glm_embeds=glm_embeds)])
        conditioning_name = context.conditioning.save(conditioning_data)
        return CogView4ConditioningOutput.build(conditioning_name)

    def _glm_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
        prompt = [self.prompt]

        # TODO(ryand): Add model inputs to the invocation rather than hard-coding.
        with (
            context.models.load(self.glm_encoder.text_encoder).model_on_device() as (_, glm_text_encoder),
            context.models.load(self.glm_encoder.tokenizer).model_on_device() as (_, glm_tokenizer),
        ):
            context.util.signal_progress("Running GLM text encoder")
            assert isinstance(glm_text_encoder, GlmModel)
            assert isinstance(glm_tokenizer, PreTrainedTokenizerFast)

            text_inputs = glm_tokenizer(
                prompt,
                padding="longest",
                max_length=max_seq_len,
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = glm_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
            assert isinstance(text_input_ids, torch.Tensor)
            assert isinstance(untruncated_ids, torch.Tensor)
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = glm_tokenizer.batch_decode(untruncated_ids[:, max_seq_len - 1 : -1])
                context.logger.warning(
                    "The following part of your input was truncated because `max_sequence_length` is set to "
                    f"{max_seq_len} tokens: {removed_text}"
                )

            # Pad the sequence length up to a multiple of 16 (left-padding with the pad token).
            current_length = text_input_ids.shape[1]
            pad_length = (16 - (current_length % 16)) % 16
            if pad_length > 0:
                pad_ids = torch.full(
                    (text_input_ids.shape[0], pad_length),
                    fill_value=glm_tokenizer.pad_token_id,
                    dtype=text_input_ids.dtype,
                    device=text_input_ids.device,
                )
                text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
            prompt_embeds = glm_text_encoder(
                text_input_ids.to(TorchDevice.choose_torch_device()), output_hidden_states=True
            ).hidden_states[-2]

        assert isinstance(prompt_embeds, torch.Tensor)
        return prompt_embeds
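The padding step above left-pads the token ids so the sequence length reaching the transformer is always divisible by 16. A self-contained example of the arithmetic (the pad token id is illustrative; the real one comes from the GLM tokenizer):

```python
import torch

pad_token_id = 0  # illustrative; the real id comes from the GLM tokenizer
text_input_ids = torch.ones((1, 22), dtype=torch.long)   # a 22-token prompt
pad_length = (16 - (text_input_ids.shape[1] % 16)) % 16  # -> 10
pad_ids = torch.full((1, pad_length), pad_token_id, dtype=torch.long)
padded = torch.cat([pad_ids, text_input_ids], dim=1)     # left-padded
print(padded.shape)  # torch.Size([1, 32])
```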
9 changes: 9 additions & 0 deletions invokeai/app/invocations/fields.py
@@ -40,6 +40,7 @@ class UIType(str, Enum, metaclass=MetaEnum):

     # region Model Field Types
     MainModel = "MainModelField"
+    CogView4MainModel = "CogView4MainModelField"
     FluxMainModel = "FluxMainModelField"
     SD3MainModel = "SD3MainModelField"
     SDXLMainModel = "SDXLMainModelField"
@@ -137,6 +138,7 @@ class FieldDescriptions:
     noise = "Noise tensor"
     clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
     t5_encoder = "T5 tokenizer and text encoder"
+    glm_encoder = "GLM (THUDM) tokenizer and text encoder"
     clip_embed_model = "CLIP Embed loader"
     clip_g_model = "CLIP-G Embed loader"
     unet = "UNet (scheduler, LoRAs)"
@@ -151,6 +153,7 @@
     main_model = "Main model (UNet, VAE, CLIP) to load"
     flux_model = "Flux model (Transformer) to load"
     sd3_model = "SD3 model (MMDiTX) to load"
+    cogview4_model = "CogView4 model (Transformer) to load"
     sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
     sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
     onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@@ -290,6 +293,12 @@ class SD3ConditioningField(BaseModel):
     conditioning_name: str = Field(description="The name of conditioning tensor")


+class CogView4ConditioningField(BaseModel):
+    """A conditioning tensor primitive value"""
+
+    conditioning_name: str = Field(description="The name of conditioning tensor")
+
+
 class ConditioningField(BaseModel):
     """A conditioning tensor primitive value"""
6 changes: 3 additions & 3 deletions invokeai/app/invocations/flux_denoise.py
@@ -33,7 +33,6 @@
 from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
 from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
 from invokeai.backend.flux.denoise import denoise
-from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
 from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
 from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
 from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
@@ -53,6 +52,7 @@
 from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
@@ -295,10 +295,10 @@ def _run_diffusion(
         assert packed_h * packed_w == x.shape[1]

         # Prepare inpaint extension.
-        inpaint_extension: InpaintExtension | None = None
+        inpaint_extension: RectifiedFlowInpaintExtension | None = None
         if inpaint_mask is not None:
             assert init_latents is not None
-            inpaint_extension = InpaintExtension(
+            inpaint_extension = RectifiedFlowInpaintExtension(
                 init_latents=init_latents,
                 inpaint_mask=inpaint_mask,
                 noise=noise,
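This change follows from commit 19f384b, which consolidates FLUX's and SD3's separate inpaint extensions into a shared RectifiedFlowInpaintExtension (also used by CogView4). The core merge step of such an extension looks roughly like the sketch below; the function name and exact blend are assumptions, not the PR's verbatim code:

```python
import torch

def merge_intermediate_latents_with_init_latents(
    latents: torch.Tensor,       # current denoised latents
    init_latents: torch.Tensor,  # clean latents from the init image
    inpaint_mask: torch.Tensor,  # 1.0 where we regenerate, 0.0 where we keep
    noise: torch.Tensor,         # the noise used to seed denoising
    t: float,                    # current timestep in [0, 1]; 1.0 is pure noise
) -> torch.Tensor:
    # Re-noise the init latents to the current point on the rectified-flow
    # trajectory (x_t = t * noise + (1 - t) * x_0), then keep them wherever
    # the mask says "preserve".
    noised_init = t * noise + (1.0 - t) * init_latents
    return inpaint_mask * latents + (1.0 - inpaint_mask) * noised_init
```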
4 changes: 4 additions & 0 deletions invokeai/app/invocations/metadata.py
@@ -152,6 +152,10 @@ def invoke(self, context: InvocationContext) -> MetadataOutput:
     "sd3_img2img",
     "sd3_inpaint",
     "sd3_outpaint",
+    "cogview4_txt2img",
+    "cogview4_img2img",
+    "cogview4_inpaint",
+    "cogview4_outpaint",
 ]