Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
295cdb1
Implementation of VAE for Hunyuan 3D 2.1
yousef-rafat Jun 27, 2025
8ecad4c
Added MoE file for the Hunyuan 3D 2.1 Model
yousef-rafat Jun 29, 2025
b445dd3
Update moe.py
yousef-rafat Jun 29, 2025
6515519
Added Pipeline, Conditioner, Diffusion, Scheduler, and Image Processor
yousef-rafat Jul 5, 2025
cd94296
Merge branch 'yousef'
yousef-rafat Jul 5, 2025
b3839ca
Added Pipeline, Conditioner, Diffusion, Scheduler,
yousef-rafat Jul 5, 2025
1746550
fixed some bugs and rewrote OpenCV resize funcs
yousef-rafat Jul 7, 2025
5b24694
testing and small fixes
yousef-rafat Jul 8, 2025
84bcc09
removed trimesh, replacing with native impl.
yousef-rafat Jul 8, 2025
3ac2861
fixed a bug in saving
yousef-rafat Jul 8, 2025
f49e4b5
integerated hunyuan3dv2_1
yousef-rafat Jul 10, 2025
b184a61
some fixes
yousef-rafat Jul 10, 2025
dff570e
file management and removed code redundancy
yousef-rafat Jul 11, 2025
ee65d6e
added dino2 large support and some fixes
yousef-rafat Jul 12, 2025
491a49c
merged vaes and improved surface net
yousef-rafat Jul 13, 2025
db06eeb
compatibility with previous nodes + block replace
yousef-rafat Jul 17, 2025
170c3e0
rm
yousef-rafat Jul 17, 2025
6316bb4
final changes
yousef-rafat Jul 22, 2025
8c78799
Merge branch 'master' into yousef
yousef-rafat Jul 29, 2025
f2a2d6b
style changes
yousef-rafat Jul 30, 2025
6ca9c64
Merge branch 'yousef' of https://github.com/yousef-rafat/ComfyUI into…
yousef-rafat Jul 30, 2025
1d92312
styling
yousef-rafat Jul 30, 2025
c39c3ff
converted to comfy api
yousef-rafat Jul 31, 2025
a1c50b6
device error fix
yousef-rafat Jul 31, 2025
447a46f
fix attention
yousef-rafat Aug 1, 2025
8ede2b6
Merge branch 'master' into yousef
yousef-rafat Aug 31, 2025
195e78a
Merge branch 'master' into yousef
comfyanonymous Sep 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 226 additions & 5 deletions comfy/clip_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,227 @@ def __getitem__(self, key):
def __setitem__(self, key, item):
setattr(self, key, item)

def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):

def cubic_kernel(x, a: float = -0.75):
    """Evaluate the cubic convolution (Keys) kernel at offsets ``x``.

    Uses OpenCV's coefficient ``a = -0.75``; the kernel is nonzero only
    for ``|x| < 2``.
    """
    d = x.abs()
    d2 = d * d
    d3 = d2 * d

    # Piece for |x| <= 1 and piece for 1 < |x| < 2.
    near = (a + 2) * d3 - (a + 3) * d2 + 1
    far = a * d3 - 5 * a * d2 + 8 * a * d - 4 * a

    out = torch.zeros_like(x)
    out = torch.where(d < 2, far, out)
    out = torch.where(d <= 1, near, out)
    return out

def get_indices_weights(in_size, out_size, scale):
    """Return (indices, weights) for 1-D cubic resampling.

    Output pixel centers are mapped back into the source grid with
    OpenCV's half-pixel convention; each output sample draws from four
    source taps with normalized cubic weights.

    Returns:
        indices: LongTensor (out_size, 4), clamped to [0, in_size - 1].
        weights: FloatTensor (out_size, 4), each row summing to 1.
    """
    # Half-pixel center mapping from output space to input space.
    centers = (torch.arange(out_size, dtype=torch.float32) + 0.5) / scale - 0.5
    base = centers.floor().long()

    taps = torch.arange(-1, 3)
    offsets = centers.unsqueeze(1) - (base.unsqueeze(1) + taps)

    weights = cubic_kernel(offsets)
    weights = weights / weights.sum(dim=1, keepdim=True)

    indices = (base.unsqueeze(1) + taps).clamp(0, in_size - 1)
    return indices, weights

def resize_cubic_1d(x, out_size, dim):
    """Cubic-resample a (B, C, H, W) tensor along one spatial dim (2 or 3)."""
    batch, channels, height, width = x.shape
    along_h = dim == 2
    in_size = height if along_h else width

    indices, weights = get_indices_weights(in_size, out_size, out_size / in_size)

    # Flatten so the resampled axis is the last one: rows of length in_size.
    if along_h:
        rows = x.permute(0, 1, 3, 2).reshape(-1, height)
    else:
        rows = x.reshape(-1, width)

    # Gather the 4 contributing taps per output sample and blend them.
    out = (rows[:, indices] * weights.unsqueeze(0)).sum(dim=2)

    if along_h:
        return out.reshape(batch, channels, width, out_size).permute(0, 1, 3, 2)
    return out.reshape(batch, channels, height, out_size)

def resize_cubic(img: torch.Tensor, size: tuple) -> torch.Tensor:
    """Resize using OpenCV-equivalent INTER_CUBIC interpolation, in pure PyTorch.

    Args:
        img: channels-LAST image, (H, W, C) or (B, H, W, C).
        size: target (out_h, out_w).

    Returns:
        Channels-FIRST tensor (B, C, out_h, out_w) — note the layout change;
        the caller is expected to permute back if needed.
    """
    if img.ndim == 3:
        img = img.unsqueeze(0)

    out_h, out_w = size

    # NHWC -> NCHW, then resample each spatial axis separately.
    resized = img.permute(0, 3, 1, 2)
    resized = resize_cubic_1d(resized, out_h, dim=2)
    return resize_cubic_1d(resized, out_w, dim=3)

def resize_area(img: torch.Tensor, size: tuple) -> torch.Tensor:
    """OpenCV INTER_AREA-style resize, vectorized in pure PyTorch.

    Accepts (C, H, W) with C <= 4, (H, W, C) when the first dim is > 4
    (layout is guessed from the first dim — ambiguous for tiny images),
    or (B, C, H, W). Output dtype is float32 and the input rank is
    preserved: a 3-D input comes back 3-D, a 4-D input comes back 4-D.
    (Previously a 4-D input with batch size 1 was incorrectly squeezed
    to 3-D.)

    Args:
        img: image tensor with 3 or 4 dims.
        size: target (out_h, out_w).

    Raises:
        ValueError: if img does not have 3 or 4 dims.
    """
    was_batched = img.ndim == 4
    is_hwc = False

    if img.ndim == 3:
        # Heuristic: a small leading dim is assumed to be channels (CHW).
        if img.shape[0] <= 4:
            img = img.unsqueeze(0)
        else:
            is_hwc = True
            img = img.permute(2, 0, 1).unsqueeze(0)
    elif not was_batched:
        raise ValueError("Expected image with 3 or 4 dims.")

    B, C, H, W = img.shape
    out_h, out_w = size
    scale_y = H / out_h
    scale_x = W / out_w

    device = img.device

    # compute the grid boundaries of each output cell in input coordinates
    y_start = torch.arange(out_h, device=device).float() * scale_y
    y_end = y_start + scale_y
    x_start = torch.arange(out_w, device=device).float() * scale_x
    x_end = x_start + scale_x

    # integer span of input pixels that can contribute to each output pixel
    y_start_int = torch.floor(y_start).long()
    y_end_int = torch.ceil(y_end).long()
    x_start_int = torch.floor(x_start).long()
    x_end_int = torch.ceil(x_end).long()

    # build the weighted sums by iterating over contributing offsets once
    output = torch.zeros((B, C, out_h, out_w), dtype=torch.float32, device=device)
    area = torch.zeros((out_h, out_w), dtype=torch.float32, device=device)

    max_kernel_h = int(torch.max(y_end_int - y_start_int).item())
    max_kernel_w = int(torch.max(x_end_int - x_start_int).item())

    for dy in range(max_kernel_h):
        for dx in range(max_kernel_w):
            # candidate input coordinates at this kernel offset
            y_idx = y_start_int.unsqueeze(1) + dy
            x_idx = x_start_int.unsqueeze(0) + dx

            # clamp indices to image boundaries
            y_idx_clamped = torch.clamp(y_idx, 0, H - 1)
            x_idx_clamped = torch.clamp(x_idx, 0, W - 1)

            # fractional overlap of this input pixel with each output cell
            y_weight = (torch.min(y_end.unsqueeze(1), y_idx_clamped.float() + 1.0) - torch.max(y_start.unsqueeze(1), y_idx_clamped.float())).clamp(min=0)
            x_weight = (torch.min(x_end.unsqueeze(0), x_idx_clamped.float() + 1.0) - torch.max(x_start.unsqueeze(0), x_idx_clamped.float())).clamp(min=0)

            weight = y_weight * x_weight

            y_expand = y_idx_clamped.expand(out_h, out_w)
            x_expand = x_idx_clamped.expand(out_h, out_w)

            pixels = img[:, :, y_expand, x_expand]

            # unsqueeze so the (out_h, out_w) weight broadcasts over (B, C, ...)
            output += pixels * weight.unsqueeze(0).unsqueeze(0)
            area += weight

    # normalize each output pixel by its total covered area
    output /= area.unsqueeze(0).unsqueeze(0)

    if is_hwc:
        return output[0].permute(1, 2, 0)
    if not was_batched:
        # 3-D CHW input: drop the batch dim that was added above.
        return output[0]
    return output

def recenter(image, border_ratio: float = 0.2):
    """Recenter the foreground of an image on a square white canvas.

    The foreground bounding box (taken from the alpha channel, or the whole
    image when no alpha channel exists) is rescaled so its larger side fills
    ``(1 - border_ratio)`` of a square canvas, centered, and composited over
    white.

    Args:
        image: (H, W, C) uint8 tensor, C == 3 or 4 (channels-last).
        border_ratio: fraction of the canvas left as border around the object.

    Returns:
        (size, size, 3) uint8 RGB tensor, where size == max(H, W).

    Raises:
        ValueError: if the foreground mask has zero width or height.
    """
    if image.shape[-1] == 4:
        mask = image[..., 3]
    else:
        # No alpha channel: treat every pixel as foreground and append a
        # fully-opaque alpha so the compositing below works uniformly.
        mask = torch.ones_like(image[..., 0:1]) * 255
        image = torch.concatenate([image, mask], axis=-1)
        mask = mask[..., 0]

    H, W, C = image.shape

    size = max(H, W)
    # Allocate on the input's device (previously implicit CPU, which broke
    # when `image` lived on the GPU).
    result = torch.zeros((size, size, C), dtype=torch.uint8, device=image.device)

    # as_tuple to match numpy behaviour; the first element indexes rows.
    x_coords, y_coords = torch.nonzero(mask, as_tuple=True)

    y_min, y_max = y_coords.min(), y_coords.max()
    x_min, x_max = x_coords.min(), x_coords.max()

    h = x_max - x_min
    w = y_max - y_min

    if h == 0 or w == 0:
        raise ValueError('input image is empty')

    # Scale so the larger bounding-box side fills (1 - border_ratio) of the canvas.
    desired_size = int(size * (1 - border_ratio))
    scale = desired_size / max(h, w)

    h2 = int(h * scale)
    w2 = int(w * scale)

    x2_min = (size - h2) // 2
    x2_max = x2_min + h2

    y2_min = (size - w2) // 2
    y2_max = y2_min + w2

    # note: opencv takes columns first (opposite to pytorch and numpy that take the row first)
    # NOTE(review): resize_area guesses the layout from the first dim; a crop
    # with height <= 4 would be misread as CHW — confirm inputs are larger.
    result[x2_min:x2_max, y2_min:y2_max] = resize_area(image[x_min:x_max, y_min:y_max], (h2, w2))

    # Composite RGB over white using the (rescaled) alpha channel.
    bg = torch.ones((result.shape[0], result.shape[1], 3), dtype=torch.uint8, device=image.device) * 255

    mask = result[..., 3:].to(torch.float32) / 255
    result = result[..., :3] * mask + bg * (1 - mask)

    # (Dead post-processing of `mask` removed: it was never returned.)
    result = result.clip(0, 255).to(torch.uint8)

    return result

def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711],
crop=True, value_range = (-1, 1), border_ratio: float = None, recenter_size: int = 512):

if border_ratio is not None:

image = (image * 255).clamp(0, 255).to(torch.uint8)
image = [recenter(img, border_ratio = border_ratio) for img in image]

image = torch.stack(image, dim = 0)
image = resize_cubic(image, size = (recenter_size, recenter_size))

image = image / 255 * 2 - 1
low, high = value_range

image = (image - low) / (high - low)
image = image.permute(0, 2, 3, 1)

image = image[:, :, :, :3] if image.shape[3] > 3 else image

mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)

image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size):
if crop:
Expand All @@ -29,7 +246,7 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
else:
scale_size = (size, size)

image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bilinear" if border_ratio is not None else "bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
Expand Down Expand Up @@ -71,9 +288,9 @@ def load_sd(self, sd):
def get_sd(self):
return self.model.state_dict()

def encode_image(self, image, crop=True):
def encode_image(self, image, crop=True, border_ratio: float = None):
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop, border_ratio=border_ratio).float()
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)

outputs = Output()
Expand Down Expand Up @@ -136,8 +353,12 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
elif "embeddings.patch_embeddings.projection.weight" in sd:

# Dinov2
elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
else:
return None

Expand Down
33 changes: 26 additions & 7 deletions comfy/image_encoders/dino2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ def __init__(self, dim, dtype, device, operations):
def forward(self, x):
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)

class Dinov2MLP(torch.nn.Module):
    """Standard DINOv2 transformer MLP: Linear -> GELU -> Linear.

    Used by DINOv2 variants configured with ``use_swiglu_ffn = false``
    (e.g. the large model); the hidden width is ``hidden_size * mlp_ratio``.
    """

    def __init__(self, hidden_size: int, dtype, device, operations, mlp_ratio: int = 4):
        super().__init__()

        # mlp_ratio was hard-coded to 4; it is now a parameter (default
        # unchanged) so configs declaring a different "mlp_ratio" can reuse
        # this module without modification.
        hidden_features = int(hidden_size * mlp_ratio)
        self.fc1 = operations.Linear(hidden_size, hidden_features, bias=True, device=device, dtype=dtype)
        self.fc2 = operations.Linear(hidden_features, hidden_size, bias=True, device=device, dtype=dtype)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        hidden_state = self.fc1(hidden_state)
        hidden_state = torch.nn.functional.gelu(hidden_state)
        hidden_state = self.fc2(hidden_state)
        return hidden_state

class SwiGLUFFN(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
Expand All @@ -50,12 +64,15 @@ def forward(self, x):


class Dino2Block(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
super().__init__()
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
if use_swiglu_ffn:
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
else:
self.mlp = Dinov2MLP(dim, dtype, device, operations)
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)

Expand All @@ -66,9 +83,10 @@ def forward(self, x, optimized_attention):


class Dino2Encoder(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
super().__init__()
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
for _ in range(num_layers)])

def forward(self, x, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
Expand All @@ -78,8 +96,8 @@ def forward(self, x, intermediate_output=None):
intermediate_output = len(self.layer) + intermediate_output

intermediate = None
for i, l in enumerate(self.layer):
x = l(x, optimized_attention)
for i, layer in enumerate(self.layer):
x = layer(x, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
Expand Down Expand Up @@ -128,9 +146,10 @@ def __init__(self, config_dict, dtype, device, operations):
dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
layer_norm_eps = config_dict["layer_norm_eps"]
use_swiglu_ffn = config_dict["use_swiglu_ffn"]

self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)

def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
Expand Down
22 changes: 22 additions & 0 deletions comfy/image_encoders/dino2_large.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"hidden_size": 1024,
"use_mask_token": true,
"patch_size": 14,
"image_size": 518,
"num_channels": 3,
"num_attention_heads": 16,
"initializer_range": 0.02,
"attention_probs_dropout_prob": 0.0,
"hidden_dropout_prob": 0.0,
"hidden_act": "gelu",
"mlp_ratio": 4,
"model_type": "dinov2",
"num_hidden_layers": 24,
"layer_norm_eps": 1e-6,
"qkv_bias": true,
"use_swiglu_ffn": false,
"layerscale_value": 1.0,
"drop_path_rate": 0.0,
"image_mean": [0.485, 0.456, 0.406],
"image_std": [0.229, 0.224, 0.225]
}
5 changes: 5 additions & 0 deletions comfy/latent_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,11 @@ class Hunyuan3Dv2(LatentFormat):
latent_dimensions = 1
scale_factor = 0.9990943042622529

class Hunyuan3Dv2_1(LatentFormat):
    # Latent format for the Hunyuan 3D 2.1 shape VAE: a 1-D latent sequence
    # (latent_dimensions = 1) with 64 channels, matching the sibling
    # Hunyuan3Dv2 formats below.
    # NOTE(review): scale_factor presumably derives from the released
    # model's latent statistics — confirm against the upstream checkpoint.
    scale_factor = 1.0039506158752403
    latent_channels = 64
    latent_dimensions = 1

class Hunyuan3Dv2mini(LatentFormat):
latent_channels = 64
latent_dimensions = 1
Expand Down
Loading
Loading