@E-Anlia E-Anlia commented Dec 12, 2025

🧩 What does this PR do?

This PR adds native support for NewBie image models, a Next-DiT based text-to-image architecture, to ComfyUI.

NewBie models are DiT-style (Flow-based) transformers, inspired by Lumina / Next-DiT research, but they are not compatible with existing Lumina or SD-style UNet assumptions.

This PR introduces a dedicated model class and loading path so that NewBie models can be used without modifying or breaking any existing models.

🧠 Why is this needed?

Previously, running NewBie models in ComfyUI required local forks or heavy monkey-patching, often by modifying Lumina-related code paths.

This PR:

Avoids modifying Lumina or any existing model logic

Introduces a clean, isolated NewBie model implementation

Matches the inference behavior that has already been validated in production via custom nodes

🔒 Scope & safety

This PR is intentionally conservative:

✅ No changes to the behavior of existing models

✅ No changes to shared attention or sampling utilities

✅ All NewBie logic is isolated under a new model class

✅ If a checkpoint is not detected as NewBie, behavior is unchanged

🧪 Testing

Verified loading and inference with NewBie image models

Confirmed correct timestep direction and conditioning behavior

Confirmed no regression when running existing models

📌 Notes

This PR does not attempt to refactor or optimize existing model code.
Its goal is solely to provide first-class support for a new DiT-based architecture that is already used by the community.

@GumGum10

I tested this on the latest ComfyUI version and got an error:
TypeError: 'function' object is not iterable

I was able to resolve the error by installing flash-attn2:
https://huggingface.co/ussoewwin/Flash-Attention-2_for_Windows/tree/main

The model now runs without any issues for me, and torch compile also works.
832/1264, 36 steps, 24s [1.45it/s] on a 4090

@bombdefuser-124

Working great here too o/ !

isaac-mcfadyen commented Dec 12, 2025

AFAIK nothing else in Comfy core has a hard dependency on Flash Attention (for example, I prefer SDPA/SageAttention, so I do not even have FA installed). See for example this file, which tries an FA import and falls back to other implementations:

FLASH_ATTENTION_IS_AVAILABLE = False
try:
    from flash_attn import flash_attn_func
    FLASH_ATTENTION_IS_AVAILABLE = True
except ImportError:
    if model_management.flash_attention_enabled():
        logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
        exit(-1)

Maybe this could be refactored to use whatever generic Attention primitives Comfy provides and remove the hard dependency on FA?

for i in range(bsz):
    img = x[i]
    C, H, W = img.size()
    img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)

In this line, expressions such as H // pH and W // pW rely on integer floor division. When H or W is an odd number, the floor division silently discards the remainder, causing a mismatch between the target shape and the actual tensor size, which leads to a runtime error.

This situation commonly occurs after an upscaling step, where the resulting height or width may become odd, and therefore requires explicit handling (e.g. padding or cropping) before patchification.

  File "H:\ComfyUI-aki-v1.7_alt\ComfyUI\comfy\samplers.py", line 214, in _calc_cond_batch_outer
    return executor.execute(model, conds, x_in, timestep, model_options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "H:\ComfyUI-aki-v1.7_alt\ComfyUI\comfy\patcher_extension.py", line 112, in execute
    return self.original(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "H:\ComfyUI-aki-v1.7_alt\ComfyUI\comfy\samplers.py", line 326, in _calc_cond_batch
    output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "H:\ComfyUI-aki-v1.7_alt\ComfyUI\comfy\model_base.py", line 1009, in apply_model
    model_output = self.diffusion_model(xc, t_val, cap_feats, cap_mask, **model_kwargs).float()
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "H:\ComfyUI-aki-v1.7_alt\python\Lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "H:\ComfyUI-aki-v1.7_alt\python\Lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "H:\ComfyUI-aki-v1.7_alt\ComfyUI\comfy\ldm\newbie\model.py", line 932, in forward
    x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input)
                                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "H:\ComfyUI-aki-v1.7_alt\ComfyUI\comfy\ldm\newbie\model.py", line 726, in patchify_and_embed
    img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
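The mismatch behind this traceback can be checked with plain arithmetic, without loading the model. The sketch below compares the number of elements the tensor actually has with the number implied by the `view(C, H // pH, pH, W // pW, pW)` target shape. The helper names, the channel count of 16, and the patch size of 2 are assumptions chosen for illustration.

```python
def view_target_numel(C, H, W, pH, pW):
    """Element count implied by view(C, H // pH, pH, W // pW, pW)."""
    return C * (H // pH) * pH * (W // pW) * pW

def check_patchify_shape(C, H, W, pH=2, pW=2):
    """Return (actual_numel, target_numel, shapes_match) for a C x H x W tensor."""
    actual = C * H * W
    target = view_target_numel(C, H, W, pH, pW)
    return actual, target, actual == target

# Even height: the two counts agree, so the view succeeds.
print(check_patchify_shape(16, 104, 158))
# Odd height: floor division drops one row, so the counts differ and view() raises.
print(check_patchify_shape(16, 105, 158))
```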

@E-Anlia (Author)

Thank you for your review. I am fixing this bug.

@E-Anlia (Author)

The patchify crash is fixed. The old implementation reshaped using H//ph and W//pw, which fails when the image height or width isn't divisible by the patch size (common after upscaling). The updated version pads inputs to the nearest patch-size multiple before patchify and crops back to the original size after unpatchify, so it no longer errors on odd or otherwise non-divisible resolutions.
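The pad-then-crop approach described above can be sketched on a single-channel 2D grid using plain lists. This illustrates the technique and is not the PR's code: the helper names, zero fill, and bottom/right padding are assumptions, since the comment does not state which padding mode the fix uses.

```python
def pad_to_multiple(img, pH, pW, fill=0.0):
    """Pad a 2D list on the bottom/right so both sides are multiples of the patch size."""
    H, W = len(img), len(img[0])
    pad_h, pad_w = (-H) % pH, (-W) % pW  # 0 when already divisible
    padded = [row + [fill] * pad_w for row in img]
    padded += [[fill] * (W + pad_w) for _ in range(pad_h)]
    return padded

def patchify(img, pH, pW):
    """Split into row-major patches, each flattened to pH * pW values."""
    H, W = len(img), len(img[0])
    return [[img[i + di][j + dj] for di in range(pH) for dj in range(pW)]
            for i in range(0, H, pH) for j in range(0, W, pW)]

def unpatchify(patches, H, W, pH, pW):
    """Inverse of patchify for a padded H x W grid."""
    img = [[None] * W for _ in range(H)]
    idx = 0
    for i in range(0, H, pH):
        for j in range(0, W, pW):
            for n, val in enumerate(patches[idx]):
                img[i + n // pW][j + n % pW] = val
            idx += 1
    return img

def crop(img, H, W):
    """Drop the padding again so the output matches the original size."""
    return [row[:W] for row in img[:H]]

# Round trip on a 3 x 5 grid with 2 x 2 patches (neither side is divisible by 2).
original = [[float(r * 5 + c) for c in range(5)] for r in range(3)]
padded = pad_to_multiple(original, 2, 2)
patches = patchify(padded, 2, 2)
restored = crop(unpatchify(patches, len(padded), len(padded[0]), 2, 2), 3, 5)
```

The expression `(-H) % pH` gives the padding needed to reach the next multiple and evaluates to 0 when `H` is already divisible, so divisible inputs pass through unchanged.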

In addition, I refactored the NewBie integration to reuse the Lumina NextDiT backbone (NewBie only keeps the next clip specific logic), normalized the loader interface by sanitizing extra kwargs and correctly propagating device/dtype/operations, and fixed dtype/device mismatches to improve stability and performance during inference.

@AyanamiRei52020

My K sampler reported these errors:

!!! Exception during processing !!! Expected all tensors to be on the same device, but got mat1 is on cuda:0, different from other tensors on cpu (when checking argument in method wrapper_CUDA_addmm)
Traceback (most recent call last):
  File "A:\ComfyUI-aki-v2\ComfyUI\execution.py", line 515, in execute
    output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
  File "A:\ComfyUI-aki-v2\ComfyUI\execution.py", line 329, in get_output_data
    return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
  File "A:\ComfyUI-aki-v2\ComfyUI\execution.py", line 303, in _async_map_node_over_list
    await process_inputs(input_dict, i)
  File "A:\ComfyUI-aki-v2\ComfyUI\execution.py", line 291, in process_inputs
    result = f(**inputs)
  File "A:\ComfyUI-aki-v2\ComfyUI\nodes.py", line 1538, in sample
    return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
  File "A:\ComfyUI-aki-v2\ComfyUI\nodes.py", line 1505, in common_ksampler
    samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\sample.py", line 60, in sample
    samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\samplers.py", line 1163, in sample
    return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\samplers.py", line 1053, in sample
    return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\samplers.py", line 1035, in sample
    output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\patcher_extension.py", line 112, in execute
    return self.original(*args, **kwargs)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\samplers.py", line 997, in outer_sample
    output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\samplers.py", line 980, in inner_sample
    samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\patcher_extension.py", line 112, in execute
    return self.original(*args, **kwargs)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\samplers.py", line 752, in sample
    samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
  File "A:\ComfyUI-aki-v2\python\Lib\site-packages\torch\utils\_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\k_diffusion\sampling.py", line 1429, in sample_res_multistep
    return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=False)
  File "A:\ComfyUI-aki-v2\python\Lib\site-packages\torch\utils\_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\k_diffusion\sampling.py", line 1387, in res_multistep
    denoised = model(x, sigmas[i] * s_in, **extra_args)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\samplers.py", line 401, in __call__
    out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\samplers.py", line 953, in __call__
    return self.outer_predict_noise(*args, **kwargs)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\samplers.py", line 960, in outer_predict_noise
    ).execute(x, timestep, model_options, seed)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\patcher_extension.py", line 112, in execute
    return self.original(*args, **kwargs)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\samplers.py", line 963, in predict_noise
    return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\samplers.py", line 381, in sampling_function
    out = calc_cond_batch(model, conds, x, timestep, model_options)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\samplers.py", line 206, in calc_cond_batch
    return _calc_cond_batch_outer(model, conds, x_in, timestep, model_options)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\samplers.py", line 214, in _calc_cond_batch_outer
    return executor.execute(model, conds, x_in, timestep, model_options)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\patcher_extension.py", line 112, in execute
    return self.original(*args, **kwargs)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\samplers.py", line 326, in _calc_cond_batch
    output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\model_base.py", line 1009, in apply_model
    model_output = self.diffusion_model(xc, t_val, cap_feats, cap_mask, **model_kwargs).float()
  File "A:\ComfyUI-aki-v2\python\Lib\site-packages\torch\nn\modules\module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "A:\ComfyUI-aki-v2\python\Lib\site-packages\torch\nn\modules\module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\ldm\newbie\model.py", line 927, in forward
    t_emb = self.t_embedder(t)
  File "A:\ComfyUI-aki-v2\python\Lib\site-packages\torch\nn\modules\module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "A:\ComfyUI-aki-v2\python\Lib\site-packages\torch\nn\modules\module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "A:\ComfyUI-aki-v2\ComfyUI\comfy\ldm\newbie\model.py", line 83, in forward
    t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
  File "A:\ComfyUI-aki-v2\python\Lib\site-packages\torch\nn\modules\module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "A:\ComfyUI-aki-v2\python\Lib\site-packages\torch\nn\modules\module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "A:\ComfyUI-aki-v2\python\Lib\site-packages\torch\nn\modules\container.py", line 250, in forward
    input = module(input)
  File "A:\ComfyUI-aki-v2\python\Lib\site-packages\torch\nn\modules\module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "A:\ComfyUI-aki-v2\python\Lib\site-packages\torch\nn\modules\module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "A:\ComfyUI-aki-v2\python\Lib\site-packages\torch\nn\modules\linear.py", line 134, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: Expected all tensors to be on the same device, but got mat1 is on cuda:0, different from other tensors on cpu (when checking argument in method wrapper_CUDA_addmm)

DraconicDragon commented Dec 14, 2025

It seems that, for me at least, --fast fp16_accumulation makes the model produce NaNs on the commit before the existing NextDiT code was reused. On the current commit, which includes that refactor, the same CLI flag makes the model output only noise.
Also, Gemma 3 4B and Jina CLIP implementations would be appreciated.

woct0rdho commented Dec 15, 2025

Hi, I'm also trying to add NewBie to ComfyUI in a simpler and more ComfyUI-idiomatic way. I believe it could be the best anime model we have at hand before any anime finetune of Z-Image or Qwen-Image appears.

I think we can just follow the direction of #11172 and make minimal changes to the existing Lumina2 class, so that we don't need to define a new class for NewBie. (Update: After discussing with the NewBie devs, they think a new class for NewBie may make future extensions easier.)

Here is a generated image, with a simple workflow in it, without any custom node.

ComfyUI_00038_

ComfyUI features such as SageAttention and memory swap just work.

An interesting finding: Lumina2 sets axes_lens = [300, 512, 512], while NewBie sets axes_lens = [1024, 512, 512] (and Z-Image sets axes_lens = [1536, 512, 512]). But it's not used in the model, because ComfyUI does not precompute freqs_cis for all lengths.
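To illustrate why a length cap such as axes_lens stops mattering once rotary factors are computed per call, here is a generic RoPE sketch. It is not ComfyUI's implementation: the function name, the head dimension of 8, and theta = 10000 are assumptions. It shows that factors computed on demand match a precomputed table where the table exists, and remain available for positions beyond it.

```python
import cmath

def rope_freqs(dim, positions, theta=10000.0):
    """Complex rotary factors exp(i * pos * theta^(-2k/dim)) for each position."""
    inv_freq = [theta ** (-2.0 * k / dim) for k in range(dim // 2)]
    return [[cmath.exp(1j * pos * f) for f in inv_freq] for pos in positions]

# Precomputed table capped at 300 positions (the cap plays the role of one axes_lens entry).
table = rope_freqs(8, range(300))

# On-demand computation for exactly the positions a request needs.
on_the_fly = rope_freqs(8, [5, 299, 700])

# Positions inside the table agree with it; position 700 has no table entry at all.
same_inside = on_the_fly[0] == table[5] and on_the_fly[1] == table[299]
```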

TODO

Sliding attention in Gemma is not implemented yet, see

logging.warning("Warning: sliding attention not implemented, results may be incorrect")

so the result is not fully correct when the prompt is longer than 1024 tokens.
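The 1024-token threshold follows from how a sliding-window mask relates to a full causal mask. The sketch below uses hypothetical helper names and assumes the convention that token i may attend to token j when j <= i and i - j < window; Gemma's exact convention may differ. Under that assumption the two masks are identical until the sequence grows past the window.

```python
def causal_mask(n):
    """Full causal mask: token i attends to every token j <= i."""
    return [[j <= i for j in range(n)] for i in range(n)]

def sliding_window_mask(n, window):
    """Causal mask restricted to the most recent `window` tokens."""
    return [[j <= i and i - j < window for j in range(n)] for i in range(n)]

# Sequence shorter than the window: both masks are the same.
short_is_same = sliding_window_mask(6, 8) == causal_mask(6)

# Sequence longer than the window: early tokens drop out of view.
long_differs = sliding_window_mask(6, 3) != causal_mask(6)
```

With a window of 1024, skipping the sliding mask therefore changes nothing for prompts of up to 1024 tokens, which matches the caveat above.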

woct0rdho commented Dec 18, 2025

Update: I've implemented Jina CLIP v2 in a ComfyUI-idiomatic way at https://github.com/woct0rdho/ComfyUI/tree/newbie . The architecture is put in a single py file, and the weights are also packaged in a single file at https://huggingface.co/woctordho/comfyui-jina-clip-v2 . It does not depend on Transformers and does not download anything from the internet. I've tested that it produces the same clip_text_pooled as the official Jina CLIP v2 (within the floating point error of bf16).
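For readers who want to run a similar comparison, here is one way to check that two embedding vectors agree to within bf16 precision. It is a sketch under assumptions and not the test actually used here: `to_bf16`, `close_in_bf16`, and the relative tolerance of 2^-7 (the gap between 1.0 and the next bf16 value) are illustrative choices.

```python
import struct

def to_bf16(x):
    """Round a float to the nearest bfloat16 value (ties to even). NaN/inf are not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)   # round-to-nearest-even on the dropped 16 bits
    bits = ((bits + bias) >> 16) << 16   # keep sign, exponent, and 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

BF16_EPS = 2.0 ** -7  # spacing between 1.0 and the next representable bf16 value

def close_in_bf16(a, b, rtol=BF16_EPS):
    """True if every pair of values differs by at most one bf16 step, relatively."""
    return all(abs(x - y) <= rtol * max(abs(x), abs(y)) for x, y in zip(a, b))

reference = [0.5, -2.0, 1.0]
candidate = [to_bf16(0.5003), to_bf16(-2.004), to_bf16(1.004)]
matches = close_in_bf16(reference, candidate)
```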

Here is an image generated with my forked ComfyUI at commit woct0rdho@98b25d4 , with both Gemma and Jina conditioning:

ComfyUI_00039_

@E-Anlia You can copy my code if needed. Or you can let me open a separate PR that adds Jina CLIP v2 to ComfyUI if it's more convenient.

SakanakoChan commented Dec 18, 2025

@woct0rdho Thanks for your great work! I cherry-picked your commit [woct0rdho/ComfyUI@98b25d4] and successfully ran the NewBie model with dual CLIP (Gemma and Jina), but something seems to be wrong in the output image. See below:
rgthree compare _temp_axkmj_00003_
(You can drag this image onto the ComfyUI page to see the workflow.)
When I do a latent upscale or a face detailer pass the same way I did with SDXL models, noise tends to appear, especially around the face. Is this an issue with the NewBie image model itself, a consequence of Gemma's sliding attention not being implemented yet, or simply a workflow that is not compatible with NewBie?
